Skip to content

Commit f8f9678

Browse files
committed
Expand testing
1 parent 76655d3 commit f8f9678

File tree

2 files changed

+29
-20
lines changed

2 files changed

+29
-20
lines changed

src/Equation.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,18 +226,12 @@ function (::Type{N})(
226226
op::Integer, l::AbstractExpressionNode{T}
227227
) where {T,N<:AbstractExpressionNode}
228228
@assert l isa N
229-
if !(N isa UnionAll)
230-
@warn "Ignoring specified type parameters in binary operator constructor."
231-
end
232229
return constructorof(N)(1, false, nothing, 0, op, l)
233230
end
234231
function (::Type{N})(
235232
op::Integer, l::AbstractExpressionNode{T1}, r::AbstractExpressionNode{T2}
236233
) where {T1,T2,N<:AbstractExpressionNode}
237234
@assert l isa N && r isa N
238-
if !(N isa UnionAll)
239-
@warn "Ignoring specified type parameters in binary operator constructor."
240-
end
241235
# Get highest type:
242236
if T1 != T2
243237
T = promote_type(T1, T2)

test/test_tree_construction.jl

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,32 @@ for unaop in [cos, exp, safe_log, safe_log2, safe_log10, safe_sqrt, relu, gamma,
9090
end
9191
end
9292

93-
# We also test whether we can set a node equal to another node:
94-
operators = OperatorEnum(; default_params...)
95-
tree = Node(Float64; feature=1)
96-
tree2 = exp(Node(Float64; feature=2) / 3.2) + Node(Float64; feature=1) * 2.0
97-
98-
# Test printing works:
99-
io = IOBuffer()
100-
print(io, tree2)
101-
s = String(take!(io))
102-
@test s == "exp(x2 / 3.2) + (x1 * 2.0)"
103-
104-
set_node!(tree, tree2)
105-
@test tree !== tree2
106-
@test repr(tree) == repr(tree2)
93+
@testset "Set a node equal to another node" begin
94+
operators = OperatorEnum(; default_params...)
95+
tree = Node(Float64; feature=1)
96+
tree2 = exp(Node(Float64; feature=2) / 3.2) + Node(Float64; feature=1) * 2.0
97+
98+
# Test printing works:
99+
io = IOBuffer()
100+
print(io, tree2)
101+
s = String(take!(io))
102+
@test s == "exp(x2 / 3.2) + (x1 * 2.0)"
103+
104+
set_node!(tree, tree2)
105+
@test tree !== tree2
106+
@test repr(tree) == repr(tree2)
107+
end
108+
109+
@testset "Miscellaneous" begin
110+
operators = OperatorEnum(; default_params...)
111+
for N in (Node, GraphNode)
112+
tree = N{ComplexF64}(; val=1)
113+
@test typeof(tree.val) === ComplexF64
114+
115+
x = N{BigFloat}(; feature=1)
116+
@test_throws AssertionError N{Float32}(1, x)
117+
@test N{BigFloat}(1, x) == N(1, x)
118+
@test typeof(N(1, x, N{Float32}(; val=1))) === N{BigFloat}
119+
@test typeof(N(1, N{Float32}(; val=1), x)) === N{BigFloat}
120+
end
121+
end

0 commit comments

Comments
 (0)