@@ -94,34 +94,43 @@ This holds options for expression evaluation, such as evaluation backend.
94
94
- `buffer::Union{ArrayBuffer,Nothing}`: If not `nothing`, use this buffer for evaluation.
95
95
This should be an instance of `ArrayBuffer` which has an `array` field and an
96
96
`index` field used to iterate which buffer slot to use.
97
+ - `use_fused::Val{U}=Val(true)`: If `Val{true}`, use fused kernels for faster
98
+ evaluation. Setting this to `Val{false}` will skip the fused kernels, meaning that
99
+ you would only need to overload `deg0_eval`, `deg1_eval` and `deg2_eval` for custom
100
+ evaluation.
97
101
"""
98
- struct EvalOptions{T,B,E,BUF<: Union{ArrayBuffer,Nothing} }
102
+ struct EvalOptions{T,B,E,BUF<: Union{ArrayBuffer,Nothing} ,U }
99
103
turbo:: Val{T}
100
104
bumper:: Val{B}
101
105
early_exit:: Val{E}
102
106
buffer:: BUF
107
+ use_fused:: Val{U}
103
108
end
104
109
105
110
@unstable function EvalOptions (;
106
111
turbo:: Union{Bool,Val} = Val (false ),
107
112
bumper:: Union{Bool,Val} = Val (false ),
108
113
early_exit:: Union{Bool,Val} = Val (true ),
109
114
buffer:: Union{ArrayBuffer,Nothing} = nothing ,
115
+ use_fused:: Union{Bool,Val} = Val (true ),
110
116
)
111
117
v_turbo = _to_bool_val (turbo)
112
118
v_bumper = _to_bool_val (bumper)
113
119
v_early_exit = _to_bool_val (early_exit)
120
+ v_use_fused = _to_bool_val (use_fused)
114
121
115
122
if v_bumper isa Val{true }
116
123
@assert buffer === nothing
117
124
end
118
125
119
- return EvalOptions (v_turbo, v_bumper, v_early_exit, buffer)
126
+ return EvalOptions (v_turbo, v_bumper, v_early_exit, buffer, v_use_fused )
120
127
end
121
128
122
129
@unstable @inline _to_bool_val (x:: Bool ) = x ? Val (true ) : Val (false )
123
130
@inline _to_bool_val (:: Val{T} ) where {T} = Val (T:: Bool )
124
131
132
+ @inline use_fused (eval_options:: EvalOptions ) = eval_options. use_fused isa Val{true }
133
+
125
134
_copy (x) = copy (x)
126
135
_copy (:: Nothing ) = nothing
127
136
function Base. copy (eval_options:: EvalOptions )
@@ -130,6 +139,7 @@ function Base.copy(eval_options::EvalOptions)
130
139
bumper= eval_options. bumper,
131
140
early_exit= eval_options. early_exit,
132
141
buffer= _copy (eval_options. buffer),
142
+ use_fused= eval_options. use_fused,
133
143
)
134
144
end
135
145
@@ -433,19 +443,20 @@ end
433
443
end
434
444
end
435
445
return quote
446
+ fused = use_fused (eval_options)
436
447
return Base. Cartesian. @nif (
437
448
$ nbin,
438
449
i -> i == op_idx, # COV_EXCL_LINE
439
450
i -> let op = operators. binops[i] # COV_EXCL_LINE
440
- if get_child (tree, 1 ). degree == 0 && get_child (tree, 2 ). degree == 0
451
+ if fused && get_child (tree, 1 ). degree == 0 && get_child (tree, 2 ). degree == 0
441
452
deg2_l0_r0_eval (tree, cX, op, eval_options)
442
- elseif get_child (tree, 2 ). degree == 0
453
+ elseif fused && get_child (tree, 2 ). degree == 0
443
454
result_l = _eval_tree_array (get_child (tree, 1 ), cX, operators, eval_options)
444
455
! result_l. ok && return result_l
445
456
@return_on_nonfinite_array (eval_options, result_l. x)
446
457
# op(x, y), where y is a constant or variable but x is not.
447
458
deg2_r0_eval (tree, result_l. x, cX, op, eval_options)
448
- elseif get_child (tree, 1 ). degree == 0
459
+ elseif fused && get_child (tree, 1 ). degree == 0
449
460
result_r = _eval_tree_array (get_child (tree, 2 ), cX, operators, eval_options)
450
461
! result_r. ok && return result_r
451
462
@return_on_nonfinite_array (eval_options, result_r. x)
@@ -487,19 +498,22 @@ end
487
498
# This @nif lets us generate an if statement over choice of operator,
488
499
# which means the compiler will be able to completely avoid type inference on operators.
489
500
return quote
501
+ fused = use_fused (eval_options)
490
502
Base. Cartesian. @nif (
491
503
$ nuna,
492
504
i -> i == op_idx, # COV_EXCL_LINE
493
505
i -> let op = operators. unaops[i] # COV_EXCL_LINE
494
- if get_child (tree, 1 ). degree == 2 &&
506
+ if fused &&
507
+ get_child (tree, 1 ). degree == 2 &&
495
508
get_child (get_child (tree, 1 ), 1 ). degree == 0 &&
496
509
get_child (get_child (tree, 1 ), 2 ). degree == 0
497
510
# op(op2(x, y)), where x, y, z are constants or variables.
498
511
l_op_idx = get_child (tree, 1 ). op
499
512
dispatch_deg1_l2_ll0_lr0_eval (
500
513
tree, cX, op, l_op_idx, operators. binops, eval_options
501
514
)
502
- elseif get_child (tree, 1 ). degree == 1 &&
515
+ elseif fused &&
516
+ get_child (tree, 1 ). degree == 1 &&
503
517
get_child (get_child (tree, 1 ), 1 ). degree == 0
504
518
# op(op2(x)), where x is a constant or variable.
505
519
l_op_idx = get_child (tree, 1 ). op
0 commit comments