Skip to content

Commit e56c3ae

Browse files
committed
Add equality function for Node type
1 parent 6cdd2a5 commit e56c3ae

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

src/Equation.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,4 +396,34 @@ function Base.hash(tree::Node{T})::UInt where {T}
396396
end
397397
end
398398

399+
function is_equal(a::Node{T}, b::Node{T})::Bool where {T}
400+
if a.degree == 0
401+
b.degree != 0 && return false
402+
if a.constant
403+
!(b.constant) && return false
404+
return a.val::T == b.val::T
405+
else
406+
b.constant && return false
407+
return a.feature == b.feature
408+
end
409+
elseif a.degree == 1
410+
b.degree != 1 && return false
411+
a.op != b.op && return false
412+
return is_equal(a.l, b.l)
413+
else
414+
b.degree != 2 && return false
415+
a.op != b.op && return false
416+
return is_equal(a.l, b.l) && is_equal(a.r, b.r)
417+
end
418+
end
419+
420+
function Base.:(==)(a::Node{T}, b::Node{T})::Bool where {T}
421+
return is_equal(a, b)
422+
end
423+
424+
function Base.:(==)(a::Node{T1}, b::Node{T2})::Bool where {T1,T2}
425+
T = promote_type(T1, T2)
426+
return is_equal(convert(Node{T}, a), convert(Node{T}, b))
427+
end
428+
399429
end

test/test_equality.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using DynamicExpressions
2+
using Test
3+
4+
operators = OperatorEnum(;
5+
binary_operators=[+, *, -, /], unary_operators=[sin, cos, exp, log]
6+
)
7+
8+
# Create a big expression, using those operators:
9+
x1 = Node(; feature=1)
10+
x2 = Node(; feature=2)
11+
x3 = Node(; feature=3)
12+
13+
tree = x1 + x2 * x3 - log(x2 * 3.2) + 1.5 * cos(x2 / x1)
14+
same_tree = x1 + x2 * x3 - log(x2 * 3.2) + 1.5 * cos(x2 / x1)
15+
@test tree == same_tree
16+
17+
copied_tree = copy_node(tree; preserve_topology=true)
18+
@test tree == copied_tree
19+
20+
copied_tree2 = copy_node(tree; preserve_topology=false)
21+
@test tree == copied_tree2
22+
23+
modifed_tree = x1 + x2 * x1 - log(x2 * 3.2) + 1.5 * cos(x2 / x1)
24+
@test tree != modifed_tree
25+
modifed_tree2 = x1 + x2 * x3 - log(x2 * 3.1) + 1.5 * cos(x2 / x1)
26+
@test tree != modifed_tree2
27+
modifed_tree3 = x1 + x2 * x3 - exp(x2 * 3.2) + 1.5 * cos(x2 / x1)
28+
@test tree != modifed_tree3
29+
modified_tree4 = x1 + x2 * x3 - log(x2 * 3.2) + 1.5 * cos(x2 * x1)
30+
@test tree != modified_tree4
31+
32+
# Order matters!
33+
modified_tree5 = 1.5 * cos(x2 * x1) + x1 + x2 * x3 - log(x2 * 3.2)
34+
@test tree != modified_tree5
35+
36+
# Type should not matter if equivalent in the promoted type:
37+
f64_tree = x1 + x2 * x3 - log(x2 * 3.0) + 1.5 * cos(x2 / x1)
38+
f32_tree = x1 + x2 * x3 - log(x2 * 3.0f0) + 1.5f0 * cos(x2 / x1)
39+
@test typeof(f64_tree) == Node{Float64}
40+
@test typeof(f32_tree) == Node{Float32}
41+
42+
@test convert(Node{Float64}, f32_tree) == f64_tree
43+
44+
@test f64_tree == f32_tree

0 commit comments

Comments
 (0)