Skip to content

Commit e3aa296

Browse files
authored
Merge pull request #97 from SymbolicML/prevent-julia-bugs
Increase purity of `@generated` functions during tests with DispatchDoctor
2 parents 67bfab0 + 54acd34 commit e3aa296

File tree

8 files changed

+51
-54
lines changed

8 files changed

+51
-54
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.19.0"
4+
version = "0.19.1"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/DynamicExpressions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ import .NodeModule:
6767
get_scalar_constants,
6868
set_scalar_constants!
6969
@reexport import .StringsModule: string_tree, print_tree
70+
import .StringsModule: get_op_name
7071
@reexport import .OperatorEnumModule: AbstractOperatorEnum
7172
@reexport import .OperatorEnumConstructionModule:
7273
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!

src/Evaluate.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,7 @@ struct EvalOptions{T,B,E}
5252
early_exit::Val{E}
5353
end
5454

55-
@stable(
56-
default_mode = "disable",
57-
default_union_limit = 2,
58-
@inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false)
59-
)
55+
@unstable @inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false)
6056
@inline _to_bool_val(x::Val{T}) where {T} = Val(T::Bool)
6157

6258
@unstable function EvalOptions(;
@@ -74,11 +70,11 @@ end
7470
throw(ArgumentError("Invalid keyword argument(s): $(keys(deprecated_kws))"))
7571
end
7672
if !isempty(deprecated_kws)
77-
@assert eval_options === nothing "Cannot use both `eval_options` and deprecated flags `turbo` and `bumper`."
78-
Base.depwarn(
79-
"The `turbo` and `bumper` keyword arguments are deprecated. Please use `eval_options` instead.",
80-
:eval_tree_array,
73+
@assert(
74+
eval_options === nothing,
75+
"Cannot use both `eval_options` and deprecated flags `turbo` and `bumper`."
8176
)
77+
# TODO: We don't do a depwarn as it can GREATLY bottleneck the search speed.
8278
end
8379
if eval_options !== nothing
8480
return eval_options
@@ -183,8 +179,10 @@ function eval_tree_array(
183179
return eval_tree_array(tree, cX, operators; kws...)
184180
end
185181

186-
get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U)
187-
get_nbin(::Type{<:OperatorEnum{B}}) where {B} = counttuple(B)
182+
# These are marked unstable due to issues discussed on
183+
# https://github.com/JuliaLang/julia/issues/55147
184+
@unstable get_nuna(::Type{<:OperatorEnum{B,U}}) where {B,U} = counttuple(U)
185+
@unstable get_nbin(::Type{<:OperatorEnum{B}}) where {B} = counttuple(B)
188186

189187
function _eval_tree_array(
190188
tree::AbstractExpressionNode{T},

src/Strings.jl

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,6 @@ using ..UtilsModule: deprecate_varmap
44
using ..OperatorEnumModule: AbstractOperatorEnum
55
using ..NodeModule: AbstractExpressionNode, tree_mapreduce
66

7-
const OP_NAMES = Base.ImmutableDict(
8-
"safe_log" => "log",
9-
"safe_log2" => "log2",
10-
"safe_log10" => "log10",
11-
"safe_log1p" => "log1p",
12-
"safe_acosh" => "acosh",
13-
"safe_sqrt" => "sqrt",
14-
"safe_pow" => "^",
15-
)
16-
177
function dispatch_op_name(::Val{deg}, ::Nothing, idx)::Vector{Char} where {deg}
188
if deg == 1
199
return vcat(collect("unary_operator["), collect(string(idx)), [']'])
@@ -23,34 +13,38 @@ function dispatch_op_name(::Val{deg}, ::Nothing, idx)::Vector{Char} where {deg}
2313
end
2414
function dispatch_op_name(::Val{deg}, operators::AbstractOperatorEnum, idx) where {deg}
2515
if deg == 1
26-
return get_op_name(operators.unaops[idx])::Vector{Char}
16+
return collect(get_op_name(operators.unaops[idx])::String)
2717
else
28-
return get_op_name(operators.binops[idx])::Vector{Char}
18+
return collect(get_op_name(operators.binops[idx])::String)
2919
end
3020
end
3121

32-
@generated function get_op_name(op::F)::Vector{Char} where {F}
22+
const OP_NAME_CACHE = (; x=Dict{UInt64,String}(), lock=Threads.SpinLock())
23+
24+
function get_op_name(op::F) where {F}
25+
h = hash(op)
26+
lock(OP_NAME_CACHE.lock)
3327
try
34-
# Bit faster to just cache the name of the operator:
35-
op_s = if F <: Broadcast.BroadcastFunction
36-
string(F.parameters[1].instance) * '.'
37-
else
38-
string(F.instance)
28+
cache = OP_NAME_CACHE.x
29+
if haskey(cache, h)
30+
return cache[h]
3931
end
40-
if length(op_s) == 2 && op_s[1] in ('+', '-', '*', '/', '^') && op_s[2] == '.'
41-
op_s = '.' * op_s[1]
42-
end
43-
out = collect(get(OP_NAMES, op_s, op_s))
44-
return :($out)
45-
catch
46-
end
47-
return quote
48-
op_s = typeof(op) <: Broadcast.BroadcastFunction ? string(op.f) * '.' : string(op)
49-
if length(op_s) == 2 && op_s[1] in ('+', '-', '*', '/', '^') && op_s[2] == '.'
50-
op_s = '.' * op_s[1]
32+
op_s = if op isa Broadcast.BroadcastFunction
33+
base_op_s = string(op.f)
34+
if length(base_op_s) == 1 && first(base_op_s) in ('+', '-', '*', '/', '^')
35+
# Like `.+`
36+
string('.', base_op_s)
37+
else
38+
# Like `cos.`
39+
string(base_op_s, '.')
40+
end
41+
else
42+
string(op)
5143
end
52-
out = collect(get(OP_NAMES, op_s, op_s))
53-
return out
44+
cache[h] = op_s
45+
return op_s
46+
finally
47+
unlock(OP_NAME_CACHE.lock)
5448
end
5549
end
5650

src/Utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
module UtilsModule
33

44
using MacroTools: postwalk, @capture, splitdef, combinedef
5+
using DispatchDoctor: @unstable
56

67
# Returns two arrays
78
macro return_on_false2(flag, retval, retval2)
@@ -124,7 +125,7 @@ function deprecate_varmap(variable_names, varMap, func_name)
124125
return variable_names
125126
end
126127

127-
counttuple(::Type{<:NTuple{N,Any}}) where {N} = N
128+
@unstable counttuple(::Type{<:NTuple{N,Any}}) where {N} = N
128129

129130
"""
130131
Undefined

test/test_deprecations.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,6 @@ if VERSION >= v"1.9"
4545
)
4646
end
4747

48-
# Old usage of evaluation options
49-
if VERSION >= v"1.9-"
50-
ex = Expression(Node{Float64}(; feature=1))
51-
@test_logs (:warn, r"The `turbo` and `bumper` keyword arguments are deprecated.*") (ex(
52-
randn(Float64, 1, 10), OperatorEnum(); turbo=true
53-
))
54-
end
55-
5648
# Test deprecated modules
5749
logs = @capture_err begin
5850
@eval using DynamicExpressions.EquationModule

test/test_params.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using DynamicExpressions
2+
import DynamicExpressions as DE
23

34
maximum_residual = 1e-2
45

@@ -26,6 +27,13 @@ maximum_residual = 1e-2
2627
greater(x, y) = (x > y)
2728

2829
custom_cos(x) = cos(x)^2
30+
31+
DE.get_op_name(::typeof(safe_log)) = "log"
32+
DE.get_op_name(::typeof(safe_log2)) = "log2"
33+
DE.get_op_name(::typeof(safe_log10)) = "log10"
34+
DE.get_op_name(::typeof(safe_log1p)) = "log1p"
35+
DE.get_op_name(::typeof(safe_acosh)) = "acosh"
36+
DE.get_op_name(::typeof(safe_sqrt)) = "sqrt"
2937
end
3038

3139
HEADER_GUARD_TEST_PARAMS = true

test/test_print.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Test
22
using DynamicExpressions
3+
import DynamicExpressions as DE
34
import Compat: Returns
45

56
include("test_params.jl")
@@ -32,8 +33,10 @@ for unaop in [safe_log, safe_log2, safe_log10, safe_log1p, safe_sqrt, safe_acosh
3233
@test string_tree(minitree, opts) == replace(string(unaop), "safe_" => "") * "(x1)"
3334
end
3435

35-
!(@isdefined safe_pow) &&
36-
@eval safe_pow(x::T, y::T) where {T<:Number} = (x < 0 && y != round(y)) ? T(NaN) : x^y
36+
@isdefined(safe_pow) || @eval begin
37+
safe_pow(x::T, y::T) where {T<:Number} = (x < 0 && y != round(y)) ? T(NaN) : x^y
38+
DE.get_op_name(::typeof(safe_pow)) = "^"
39+
end
3740
for binop in [safe_pow, ^]
3841
opts = OperatorEnum(;
3942
default_params..., binary_operators=(+, *, /, -, binop), unary_operators=(cos,)

0 commit comments

Comments
 (0)