Skip to content

Commit f7a0362

Browse files
committed
added type constrains for expression evaluation
1 parent 082ae93 commit f7a0362

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

src/api/expressions.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@ Julia expression wrapper
33
"""
44
struct Expression
55
expr::Expr
6-
syms::Vector{Symbol}
6+
syms::Dict{Symbol,Int}
7+
end
8+
function Expression(ex::Expr)
9+
syms = Dict( s=>i for (i,s) in pairs(sort!(symbols(ex))) )
10+
Expression(ex, syms)
711
end
8-
Expression(ex::Expr) = Expression(ex, sort!(symbols(ex)))
912
show(io::IO, e::Expression) = infix(io, e.expr)
10-
(e::Expression)(val) = evaluate(val, e.expr, e.syms)
13+
(e::Expression)(vals::T...) where {T} = evaluate(e.expr, e.syms, vals...)
1114

1215
function symbols(ex::Expr)
1316
syms = Symbol[]
@@ -18,6 +21,11 @@ function symbols(ex::Expr)
1821
unique!(syms)
1922
end
2023

24+
function compile(ex::Expr, params::Vector{Symbol})
25+
tprm = Expr(:tuple, params...)
26+
Expr(:->, tprm, ex)
27+
end
28+
2129
height(ex) = isa(ex, Expr) ? maximum( height(e) for e in ex.args )+1 : 0
2230
nodes(ex) = !isa(ex, Expr) ? 1 : length(ex.args) > 0 ? sum( nodes(e) for e in ex.args ) : 0
2331
length(ex) = nodes(ex)
@@ -79,15 +87,18 @@ isdiv(ex) = ex in [/, div, pdiv, aq]
7987
isexpr(ex) = isa(ex, Expr)
8088
issym(ex) = isa(ex, Symbol)
8189

82-
function evaluate(val, ex::Expr, psyms::Vector{Symbol})
90+
function evaluate(ex::Expr, psyms::Dict{Symbol,Int}, vals::T...)::T where {T}
8391
exprm = ex.args
84-
exvals = (isexpr(nex) || issym(nex) ? evaluate(val, nex, psyms) : nex for nex in exprm[2:end])
92+
exvals = (isexpr(nex) || issym(nex) ? evaluate(nex, psyms, vals...) : nex for nex in exprm[2:end])
8593
exprm[1](exvals...)
8694
end
8795

88-
function evaluate(val, ex::Symbol, psyms::Vector{Symbol})
89-
pidx = findfirst(isequal(ex), psyms)
90-
val[pidx]
96+
function evaluate(ex::Symbol, psyms::Dict{Symbol,Int}, vals::T...)::T where {T}
97+
pidx = get(psyms, ex, 0)
98+
if pidx == 0
99+
@error "Undefined symbol: $ex"
100+
end
101+
return T(vals[pidx])
91102
end
92103

93104

@@ -114,9 +125,9 @@ function simplifybinary!(root)
114125
root = 1
115126
elseif (fn == (*) || isdiv(fn)) && (iszeronum(op1) || iszeronum(op2)) # look for 0: x*0 = 0*x = 0
116127
root = 0
117-
elseif (fn == (+) || fn == (-)) && iszeronum(op2) # look for 0: x ± 0 = 0
128+
elseif (fn == (+) || fn == (-)) && iszeronum(op2) # look for 0: x ± 0 = x
118129
root = op1
119-
elseif fn == (+) && iszeronum(op1) # look for 0: 0 ± x = 0
130+
elseif fn == (+) && iszeronum(op1) # look for 0: 0 + x = x
120131
root = op2
121132
elseif (fn == (*) || isdiv(fn)) && isonenum(op2) # x*1 = x || x/1 = x
122133
root = op1

test/gp.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,23 +82,27 @@
8282
ys = fitfun.(rg)
8383
function fitobj(expr)
8484
ex = Evolutionary.Expression(expr)
85-
sum(v->isnan(v) ? 1.0 : v, abs2.(ys - ex.(rg)) )/length(rg)
85+
#ex = Evolutionary.Expression(expr, Dict(:x=>1))
86+
yy = ex.(rg)
87+
sum(v->isnan(v) ? 1.0 : v, abs2.(ys .- yy) )/length(rg)
8688
end
8789

8890
Random.seed!(rng, 42)
8991
res = Evolutionary.optimize(fitobj,
9092
TreeGP(25, Terminal[:x, randn], Function[+,-,*,Evolutionary.aq],
9193
mindepth=1,
9294
maxdepth=3,
93-
simplify = simplify!,
95+
simplify = Evolutionary.simplify!,
9496
optimizer = GA(
9597
selection = tournament(3),
96-
mutationRate = 0.2,
98+
mutationRate = 0.1,
9799
crossoverRate = 0.9,
100+
ε = 0.1
98101
),
99102
),
100-
Evolutionary.Options(show_trace=false, rng=rng, iterations=50)
103+
Evolutionary.Options(show_trace=true, rng=rng, iterations=50)
101104
)
102105
@test minimum(res) < 1.1
103106

104107
end
108+

0 commit comments

Comments
 (0)