Skip to content

Commit b114eb0

Browse files
authored
Merge pull request #128 from SymbolicML/skip-fused-kernels
feat: option to skip fused kernels
2 parents b705e46 + ecb6117 commit b114eb0

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

src/DynamicExpressions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ import .StringsModule: get_op_name, get_pretty_op_name
8181
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
8282
@reexport import .EvaluateModule:
8383
eval_tree_array, differentiable_eval_tree_array, EvalOptions
84-
import .EvaluateModule: ArrayBuffer
84+
import .EvaluateModule: ArrayBuffer, ResultOk
8585
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
8686
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
8787
@reexport import .SimplifyModule: combine_operators, simplify_tree!

src/Evaluate.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,34 +94,43 @@ This holds options for expression evaluation, such as evaluation backend.
9494
- `buffer::Union{ArrayBuffer,Nothing}`: If not `nothing`, use this buffer for evaluation.
9595
This should be an instance of `ArrayBuffer` which has an `array` field and an
9696
`index` field used to iterate which buffer slot to use.
97+
- `use_fused::Val{U}=Val(true)`: If `Val{true}`, use fused kernels for faster
98+
evaluation. Setting this to `Val{false}` will skip the fused kernels, meaning that
99+
you would only need to overload `deg0_eval`, `deg1_eval` and `deg2_eval` for custom
100+
evaluation.
97101
"""
98-
struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing}}
102+
struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing},U}
99103
turbo::Val{T}
100104
bumper::Val{B}
101105
early_exit::Val{E}
102106
buffer::BUF
107+
use_fused::Val{U}
103108
end
104109

105110
@unstable function EvalOptions(;
106111
turbo::Union{Bool,Val}=Val(false),
107112
bumper::Union{Bool,Val}=Val(false),
108113
early_exit::Union{Bool,Val}=Val(true),
109114
buffer::Union{ArrayBuffer,Nothing}=nothing,
115+
use_fused::Union{Bool,Val}=Val(true),
110116
)
111117
v_turbo = _to_bool_val(turbo)
112118
v_bumper = _to_bool_val(bumper)
113119
v_early_exit = _to_bool_val(early_exit)
120+
v_use_fused = _to_bool_val(use_fused)
114121

115122
if v_bumper isa Val{true}
116123
@assert buffer === nothing
117124
end
118125

119-
return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer)
126+
return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer, v_use_fused)
120127
end
121128

122129
@unstable @inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false)
123130
@inline _to_bool_val(::Val{T}) where {T} = Val(T::Bool)
124131

132+
@inline use_fused(eval_options::EvalOptions) = eval_options.use_fused isa Val{true}
133+
125134
_copy(x) = copy(x)
126135
_copy(::Nothing) = nothing
127136
function Base.copy(eval_options::EvalOptions)
@@ -130,6 +139,7 @@ function Base.copy(eval_options::EvalOptions)
130139
bumper=eval_options.bumper,
131140
early_exit=eval_options.early_exit,
132141
buffer=_copy(eval_options.buffer),
142+
use_fused=eval_options.use_fused,
133143
)
134144
end
135145

@@ -433,19 +443,20 @@ end
433443
end
434444
end
435445
return quote
446+
fused = use_fused(eval_options)
436447
return Base.Cartesian.@nif(
437448
$nbin,
438449
i -> i == op_idx, # COV_EXCL_LINE
439450
i -> let op = operators.binops[i] # COV_EXCL_LINE
440-
if get_child(tree, 1).degree == 0 && get_child(tree, 2).degree == 0
451+
if fused && get_child(tree, 1).degree == 0 && get_child(tree, 2).degree == 0
441452
deg2_l0_r0_eval(tree, cX, op, eval_options)
442-
elseif get_child(tree, 2).degree == 0
453+
elseif fused && get_child(tree, 2).degree == 0
443454
result_l = _eval_tree_array(get_child(tree, 1), cX, operators, eval_options)
444455
!result_l.ok && return result_l
445456
@return_on_nonfinite_array(eval_options, result_l.x)
446457
# op(x, y), where y is a constant or variable but x is not.
447458
deg2_r0_eval(tree, result_l.x, cX, op, eval_options)
448-
elseif get_child(tree, 1).degree == 0
459+
elseif fused && get_child(tree, 1).degree == 0
449460
result_r = _eval_tree_array(get_child(tree, 2), cX, operators, eval_options)
450461
!result_r.ok && return result_r
451462
@return_on_nonfinite_array(eval_options, result_r.x)
@@ -487,19 +498,22 @@ end
487498
# This @nif lets us generate an if statement over choice of operator,
488499
# which means the compiler will be able to completely avoid type inference on operators.
489500
return quote
501+
fused = use_fused(eval_options)
490502
Base.Cartesian.@nif(
491503
$nuna,
492504
i -> i == op_idx, # COV_EXCL_LINE
493505
i -> let op = operators.unaops[i] # COV_EXCL_LINE
494-
if get_child(tree, 1).degree == 2 &&
506+
if fused &&
507+
get_child(tree, 1).degree == 2 &&
495508
get_child(get_child(tree, 1), 1).degree == 0 &&
496509
get_child(get_child(tree, 1), 2).degree == 0
497510
# op(op2(x, y)), where x, y, z are constants or variables.
498511
l_op_idx = get_child(tree, 1).op
499512
dispatch_deg1_l2_ll0_lr0_eval(
500513
tree, cX, op, l_op_idx, operators.binops, eval_options
501514
)
502-
elseif get_child(tree, 1).degree == 1 &&
515+
elseif fused &&
516+
get_child(tree, 1).degree == 1 &&
503517
get_child(get_child(tree, 1), 1).degree == 0
504518
# op(op2(x)), where x is a constant or variable.
505519
l_op_idx = get_child(tree, 1).op

0 commit comments

Comments
 (0)