Skip to content

Commit b3a5fdc

Browse files
committed
Fix turbo macro for isfinite checks
1 parent 224a567 commit b3a5fdc

File tree

3 files changed

+47
-18
lines changed

3 files changed

+47
-18
lines changed

src/EvaluateEquation.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ function deg1_l2_ll0_lr0_eval(
213213
cumulator = Array{T,1}(undef, n)
214214
@maybe_turbo turbo for j in indices((cX, cumulator), (2, 1))
215215
x_l = op_l(val_ll, cX[feature_lr, j])::T
216-
x = isfinite(x_l) ? op(x_l)::T : T(Inf) # These will get discovered by _eval_tree_array at end.
216+
x = isfinite(x_l) ? op(x_l)::T : T(Inf)
217217
cumulator[j] = x
218218
end
219219
return (cumulator, true)
@@ -273,6 +273,7 @@ function deg1_l1_ll0_eval(
273273
end
274274
end
275275

276+
# op(x, y) for x and y variable/constant
276277
function deg2_l0_r0_eval(
277278
tree::Node{T},
278279
cX::AbstractMatrix{T},
@@ -320,6 +321,7 @@ function deg2_l0_r0_eval(
320321
return (cumulator, true)
321322
end
322323

324+
# op(x, y) for x variable/constant, y arbitrary
323325
function deg2_l0_eval(
324326
tree::Node{T},
325327
cX::AbstractMatrix{T},
@@ -349,6 +351,7 @@ function deg2_l0_eval(
349351
return (cumulator, true)
350352
end
351353

354+
# op(x, y) for x arbitrary, y variable/constant
352355
function deg2_r0_eval(
353356
tree::Node{T},
354357
cX::AbstractMatrix{T},
@@ -520,9 +523,9 @@ function eval(current_node)
520523
function eval_tree_array(
521524
tree::Node, cX::AbstractArray, operators::GenericOperatorEnum; throw_errors::Bool=true
522525
)
523-
!throw_errors && return _eval_tree_array(tree, cX, operators, Val(false))
526+
!throw_errors && return _eval_tree_array_generic(tree, cX, operators, Val(false))
524527
try
525-
return _eval_tree_array(tree, cX, operators, Val(true))
528+
return _eval_tree_array_generic(tree, cX, operators, Val(true))
526529
catch e
527530
tree_s = string_tree(tree, operators)
528531
error_msg = "Failed to evaluate tree $(tree_s)."
@@ -537,7 +540,7 @@ function eval_tree_array(
537540
end
538541
end
539542

540-
function _eval_tree_array(
543+
function _eval_tree_array_generic(
541544
tree::Node{T1},
542545
cX::AbstractArray{T2,N},
543546
operators::GenericOperatorEnum,
@@ -554,13 +557,13 @@ function _eval_tree_array(
554557
end
555558
end
556559
elseif tree.degree == 1
557-
return deg1_eval(tree, cX, vals[tree.op], operators, Val(throw_errors))
560+
return deg1_eval_generic(tree, cX, vals[tree.op], operators, Val(throw_errors))
558561
else
559-
return deg2_eval(tree, cX, vals[tree.op], operators, Val(throw_errors))
562+
return deg2_eval_generic(tree, cX, vals[tree.op], operators, Val(throw_errors))
560563
end
561564
end
562565

563-
function deg1_eval(
566+
function deg1_eval_generic(
564567
tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors}
565568
) where {op_idx,throw_errors}
566569
left, complete = eval_tree_array(tree.l, cX, operators)
@@ -570,7 +573,7 @@ function deg1_eval(
570573
return op(left), true
571574
end
572575

573-
function deg2_eval(
576+
function deg2_eval_generic(
574577
tree, cX, ::Val{op_idx}, operators::GenericOperatorEnum, ::Val{throw_errors}
575578
) where {op_idx,throw_errors}
576579
left, complete = eval_tree_array(tree.l, cX, operators)

src/Utils.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,22 @@ function _remove_type_assertions(ex::Expr)
1212
return Expr(ex.head, map(_remove_type_assertions, ex.args)...)
1313
end
1414
end
15+
_remove_type_assertions(ex) = ex
1516

16-
function _remove_type_assertions(ex)
17-
return ex
17+
"""Replace instances of (isfinite(x) ? op(x) : T(Inf)) with op(x)"""
18+
function _remove_isfinite(ex::Expr)
19+
if (
20+
ex.head == :if &&
21+
length(ex.args) == 3 &&
22+
ex.args[1].head == :call &&
23+
ex.args[1].args[1] == :isfinite
24+
)
25+
return _remove_isfinite(ex.args[2])
26+
else
27+
return Expr(ex.head, map(_remove_isfinite, ex.args)...)
28+
end
1829
end
30+
_remove_isfinite(ex) = ex
1931

2032
"""
2133
@maybe_turbo use_turbo expression
@@ -26,6 +38,7 @@ This will also remove all type assertions from the expression.
2638
macro maybe_turbo(turboflag, ex)
2739
# Thanks @jlapeyre https://discourse.julialang.org/t/optional-macro-invocation/18588
2840
clean_ex = _remove_type_assertions(ex)
41+
clean_ex = _remove_isfinite(clean_ex)
2942
turbo_ex = Expr(:macrocall, Symbol("@turbo"), LineNumberNode(@__LINE__), clean_ex)
3043
simple_ex = Expr(
3144
:macrocall,

test/test_evaluation.jl

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@ using Test
44
include("test_params.jl")
55

66
# Test simple evaluations:
7-
operators = OperatorEnum(;
8-
default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, sin)
9-
)
10-
117
functions = [
128
# deg2_l0_r0_eval
139
(x1, x2, x3) -> x1 * x2,
@@ -39,7 +35,9 @@ functions = [
3935
(x1, x2, x3) -> (sin(cos(sin(cos(x1) * x3) * 3.0) * -0.5) + 2.0) * 5.0,
4036
]
4137

42-
for turbo in [false, true], T in [Float16, Float32, Float64], fnc in functions
38+
for turbo in [false, true],
39+
T in [Float16, Float32, Float64],
40+
(i_func, fnc) in enumerate(functions)
4341

4442
# Float16 not implemented:
4543
turbo && T == Float16 && continue
@@ -53,7 +51,11 @@ for turbo in [false, true], T in [Float16, Float32, Float64], fnc in functions
5351
nodefnc = fnc
5452
end
5553

56-
local tree, X
54+
local tree, operators, X
55+
operators = OperatorEnum(;
56+
default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, sin)
57+
)
58+
@extend_operators operators
5759
tree = nodefnc(Node("x1"), Node("x2"), Node("x3"))
5860
tree = convert(Node{T}, tree)
5961

@@ -65,14 +67,25 @@ for turbo in [false, true], T in [Float16, Float32, Float64], fnc in functions
6567
true_y = realfnc.(X[1, :], X[2, :], X[3, :])
6668

6769
zero_tolerance = (T == Float16 ? 1e-4 : 1e-6)
68-
@test all(abs.(test_y .- true_y) / N .< zero_tolerance)
70+
try
71+
@test all(abs.(test_y .- true_y) / N .< zero_tolerance)
72+
catch
73+
println("Test for type $T and turbo=$turbo and function $i_func $tree failed.")
74+
mse = sum((x,) -> x^2, test_y .- true_y) / N
75+
mean = sum(test_y) / N
76+
stdev = sqrt(sum((x,) -> x^2, true_y .- mean) / N)
77+
println("Relative error: $(mse / stdev)")
78+
end
6979
end
7080

7181
for turbo in [false, true], T in [Float16, Float32, Float64]
7282
turbo && T == Float16 && continue
7383
# Test specific branches of evaluation code:
7484
# op(op(<constant>))
75-
local tree
85+
local tree, operators
86+
operators = OperatorEnum(;
87+
default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, sin)
88+
)
7689
tree = Node(1, Node(1, Node(; val=3.0f0)))
7790
@test repr(tree) == "cos(cos(3.0))"
7891
tree = convert(Node{T}, tree)

0 commit comments

Comments
 (0)