Skip to content

Commit 34271cb

Browse files
authored
Merge pull request #11 from SymbolicML/snoopcompile
Improve precompilation with SnoopCompile.jl
2 parents 57bbea0 + 0434c92 commit 34271cb

File tree

8 files changed

+269
-77
lines changed

8 files changed

+269
-77
lines changed

.github/workflows/CI.yml

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,27 +27,25 @@ jobs:
2727
fail-fast: false
2828
matrix:
2929
julia-version:
30-
- '1.8.2'
30+
- '1.8.3'
3131
os:
3232
- ubuntu-latest
33+
- windows-latest
34+
- macOS-latest
35+
include:
36+
- os: ubuntu-latest
37+
julia-version: '1.7.2'
38+
- os: ubuntu-latest
39+
julia-version: '1.6.7'
3340

3441
steps:
35-
- uses: actions/checkout@v1.0.0
42+
- uses: actions/checkout@v2
3643
- name: "Set up Julia"
37-
uses: julia-actions/setup-julia@v1.6.0
44+
uses: julia-actions/setup-julia@v1
3845
with:
3946
version: ${{ matrix.julia-version }}
40-
- name: Cache dependencies
41-
uses: actions/cache@v1 # Thanks FromFile.jl
42-
env:
43-
cache-name: cache-artifacts
44-
with:
45-
path: ~/.julia/artifacts
46-
key: ${{ runner.os }}-build-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
47-
restore-keys: |
48-
${{ runner.os }}-build-${{ env.cache-name }}-
49-
${{ runner.os }}-build-
50-
${{ runner.os }}-
47+
- name: "Cache artifacts"
48+
uses: julia-actions/cache@v1
5149
- name: "Build package"
5250
uses: julia-actions/julia-buildpkg@v1
5351
- name: "Run tests"

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.4.1"
4+
version = "0.4.2"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
99
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
12+
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
1213
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1314
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
1415
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1516

1617
[compat]
1718
LoopVectorization = "0.12"
1819
Reexport = "1"
20+
SnoopPrecompile = "1"
1921
SymbolicUtils = "0.19"
2022
Zygote = "0.6"
2123
julia = "1.6"

src/DynamicExpressions.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,7 @@ macro ignore(args...) end
4343
# To get LanguageServer to register library within tests
4444
@ignore include("../test/runtests.jl")
4545

46+
include("precompile.jl")
47+
do_precompilation(; mode=:precompile)
48+
4649
end

src/OperatorEnumConstruction.jl

Lines changed: 5 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -221,51 +221,20 @@ function OperatorEnum(;
221221
binary_operators = Function[op for op in binary_operators]
222222
unary_operators = Function[op for op in unary_operators]
223223

224-
if enable_autodiff
225-
diff_binary_operators = Function[]
226-
diff_unary_operators = Function[]
224+
diff_binary_operators = Function[]
225+
diff_unary_operators = Function[]
227226

228-
test_inputs = Float32.(LinRange(-100, 100, 99))
229-
# Create grid over [-100, 100]^2:
230-
test_inputs_xy = Array{Float32}(undef, 2, 99^2)
231-
row = 1
232-
for x in test_inputs, y in test_inputs
233-
test_inputs_xy[:, row] .= [x, y]
234-
row += 1
235-
end
227+
if enable_autodiff
236228
for op in binary_operators
237229
diff_op(x, y) = gradient(op, x, y)
238-
239-
test_output = diff_op.(test_inputs_xy[1, :], test_inputs_xy[2, :])
240-
gradient_exists = all((x) -> x !== nothing, Iterators.flatten(test_output))
241-
if gradient_exists
242-
push!(diff_binary_operators, diff_op)
243-
else
244-
@warn "Automatic differentiation has been turned off, since operator $(op) does not have well-defined gradients."
245-
enable_autodiff = false
246-
break
247-
end
230+
push!(diff_binary_operators, diff_op)
248231
end
249-
250232
for op in unary_operators
251233
diff_op(x) = gradient(op, x)[1]
252-
test_output = diff_op.(test_inputs)
253-
gradient_exists = all((x) -> x !== nothing, test_output)
254-
if gradient_exists
255-
push!(diff_unary_operators, diff_op)
256-
else
257-
@warn "Automatic differentiation has been turned off, since operator $(op) does not have well-defined gradients."
258-
enable_autodiff = false
259-
break
260-
end
234+
push!(diff_unary_operators, diff_op)
261235
end
262236
end
263237

264-
if !enable_autodiff
265-
diff_binary_operators = Function[]
266-
diff_unary_operators = Function[]
267-
end
268-
269238
operators = OperatorEnum(
270239
binary_operators, unary_operators, diff_binary_operators, diff_unary_operators
271240
)

src/precompile.jl

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import SnoopPrecompile: @precompile_all_calls, @precompile_setup
2+
3+
macro ignore_domain_error(ex)
4+
return esc(
5+
quote
6+
try
7+
$ex
8+
catch e
9+
if !(e isa DomainError)
10+
rethrow(e)
11+
end
12+
end
13+
end,
14+
)
15+
end
16+
17+
"""
18+
test_all_combinations(; binary_operators, unary_operators, turbo, types)
19+
20+
Test all combinations of the given operators and types. Useful for precompilation.
21+
"""
22+
function test_all_combinations(; binary_operators, unary_operators, turbo, types)
23+
for binops in binary_operators,
24+
unaops in unary_operators,
25+
use_turbo in turbo,
26+
T in types
27+
28+
length(binops) == 0 && length(unaops) == 0 && continue
29+
T == Float16 && use_turbo && continue
30+
31+
X = rand(T, 3, 10)
32+
operators = OperatorEnum(;
33+
binary_operators=binops,
34+
unary_operators=unaops,
35+
define_helper_functions=false,
36+
enable_autodiff=true,
37+
)
38+
x = Node(T; feature=1)
39+
c = Node(T; val=one(T))
40+
41+
# Trivial:
42+
for l in (x, c)
43+
@ignore_domain_error eval_tree_array(l, X, operators; turbo=use_turbo)
44+
for variable in (true, false)
45+
@ignore_domain_error eval_grad_tree_array(
46+
l, X, operators; turbo=use_turbo, variable
47+
)
48+
end
49+
end
50+
51+
# Binary operators
52+
for i in eachindex(binops), l in (x, c), r in (x, c)
53+
tree = Node(i, l, r)
54+
tree = convert(Node{T}, tree)
55+
@ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo)
56+
for variable in (true, false)
57+
@ignore_domain_error eval_grad_tree_array(
58+
l, X, operators; turbo=use_turbo, variable
59+
)
60+
end
61+
end
62+
63+
# Unary operators
64+
for j in eachindex(unaops), k in eachindex(unaops), l in (x, c)
65+
tree = Node(j, l)
66+
tree = convert(Node{T}, tree)
67+
@ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo)
68+
for variable in (true, false)
69+
@ignore_domain_error eval_grad_tree_array(
70+
l, X, operators; turbo=use_turbo, variable
71+
)
72+
end
73+
74+
tree = Node(j, Node(k, l))
75+
tree = convert(Node{T}, tree)
76+
@ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo)
77+
for variable in (true, false)
78+
@ignore_domain_error eval_grad_tree_array(
79+
l, X, operators; turbo=use_turbo, variable
80+
)
81+
end
82+
end
83+
84+
# Both operators
85+
for i in eachindex(binary_operators),
86+
j1 in eachindex(unary_operators),
87+
j2 in eachindex(unary_operators),
88+
l in (x, c),
89+
r in (x, c)
90+
91+
tree = Node(i, Node(j1, l), Node(j2, r))
92+
tree = convert(Node{T}, tree)
93+
@ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo)
94+
for variable in (true, false)
95+
@ignore_domain_error eval_grad_tree_array(
96+
l, X, operators; turbo=use_turbo, variable
97+
)
98+
end
99+
100+
tree = Node(j1, Node(i, l, r))
101+
tree = convert(Node{T}, tree)
102+
@ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo)
103+
for variable in (true, false)
104+
@ignore_domain_error eval_grad_tree_array(
105+
l, X, operators; turbo=use_turbo, variable
106+
)
107+
end
108+
end
109+
end
110+
return nothing
111+
end
112+
113+
function test_functions_on_trees(::Type{T}, operators) where {T}
114+
local x, c, tree
115+
for T1 in [Float16, Float32, Float64]
116+
x = Node(T1; feature=1)
117+
c = Node(T1; val=T1(1.0))
118+
tree = Node(
119+
2,
120+
Node(1, Node(1, Node(2, x, c), Node(3, c, Node(1, x)))),
121+
Node(3, Node(1, Node(4, x, x))),
122+
)
123+
end
124+
tree = convert(Node{T}, tree)
125+
for preserve_topology in [true, false]
126+
tree = copy_node(tree; preserve_topology)
127+
set_node!(tree, copy_node(tree; preserve_topology))
128+
end
129+
130+
string_tree(tree, operators)
131+
count_nodes(tree)
132+
count_constants(tree)
133+
count_depth(tree)
134+
index_constants(tree)
135+
has_operators(tree)
136+
has_constants(tree)
137+
get_constants(tree)
138+
set_constants(tree, get_constants(tree))
139+
combine_operators(tree, operators)
140+
simplify_tree(tree, operators)
141+
return nothing
142+
end
143+
144+
macro maybe_precompile_setup(mode, ex)
145+
precompile_ex = Expr(
146+
:macrocall, Symbol("@precompile_setup"), LineNumberNode(@__LINE__), ex
147+
)
148+
return quote
149+
if $(esc(mode)) == :compile
150+
$(esc(ex))
151+
elseif $(esc(mode)) == :precompile
152+
$(esc(precompile_ex))
153+
else
154+
error("Invalid value for mode: " * show($(esc(mode))))
155+
end
156+
end
157+
end
158+
159+
macro maybe_precompile_all_calls(mode, ex)
160+
precompile_ex = Expr(
161+
:macrocall, Symbol("@precompile_all_calls"), LineNumberNode(@__LINE__), ex
162+
)
163+
return quote
164+
if $(esc(mode)) == :compile
165+
$(esc(ex))
166+
elseif $(esc(mode)) == :precompile
167+
$(esc(precompile_ex))
168+
else
169+
error("Invalid value for mode: " * show($(esc(mode))))
170+
end
171+
end
172+
end
173+
174+
"""`mode=:precompile` will use `@precompile_*` directives; `mode=:compile` runs."""
175+
function do_precompilation(; mode=:precompile)
176+
@maybe_precompile_setup mode begin
177+
binary_operators = [[+, -, *, /, ^]]
178+
unary_operators = [[sin, cos, exp, log, sqrt, abs]]
179+
turbo = [true, false]
180+
types = [Float32, Float64]
181+
@maybe_precompile_all_calls mode begin
182+
test_all_combinations(;
183+
binary_operators=binary_operators,
184+
unary_operators=unary_operators,
185+
turbo=turbo,
186+
types=types,
187+
)
188+
end
189+
operators = OperatorEnum(;
190+
binary_operators=binary_operators[1],
191+
unary_operators=unary_operators[1],
192+
define_helper_functions=false,
193+
)
194+
# Want to precompile all above calls.
195+
types = [Float16, Float32, Float64]
196+
for T in types
197+
@maybe_precompile_all_calls mode begin
198+
test_functions_on_trees(T, operators)
199+
end
200+
end
201+
end
202+
end

test/test_evaluation.jl

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -111,30 +111,40 @@ for turbo in [false, true], T in [Float16, Float32, Float64]
111111
@test isnan(tree(X; turbo=turbo)[1])
112112
end
113113

114-
# And, with generic operator enum, this should be an actual error:
115-
operators = GenericOperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
116-
x1 = Node(Float64; feature=1)
117-
tree = sin(x1 / 0.0)
118-
X = randn(Float32, 10);
119-
@noinline stack = try
120-
tree(X)[1]
121-
@test false
122-
catch e
123-
@test isa(e, ErrorException)
124-
# Check that "Failed to evaluate" is in the message:
125-
@test occursin("Failed to evaluate", e.msg)
126-
current_exceptions()
127-
end;
128-
@test length(stack) == 2
129-
@test isa(stack[1].exception, DomainError)
130-
131-
# If a method is not defined, we should get a nothing:
132-
X = randn(Float32, 1, 10);
133-
@test tree(X; throw_errors=false) === nothing
134-
# or a MethodError:
135-
try
136-
tree(X; throw_errors=true)
137-
@test false
138-
catch e
139-
@test isa(current_exceptions()[1].exception, MethodError)
114+
# Check if julia version >= 1.7:
115+
if VERSION >= v"1.7"
116+
# And, with generic operator enum, this should be an actual error:
117+
operators = GenericOperatorEnum(;
118+
binary_operators=[+, -, *, /], unary_operators=[cos, sin]
119+
)
120+
x1 = Node(Float64; feature=1)
121+
tree = sin(x1 / 0.0)
122+
X = randn(Float32, 10)
123+
local stack
124+
try
125+
tree(X)[1]
126+
@test false
127+
catch e
128+
@test e isa ErrorException
129+
# Check that "Failed to evaluate" is in the message:
130+
@test occursin("Failed to evaluate", e.msg)
131+
stack = current_exceptions()
132+
end
133+
@test length(stack) == 2
134+
@test stack[1].exception isa DomainError
135+
136+
# If a method is not defined, we should get a nothing:
137+
X = randn(Float32, 1, 10)
138+
@test tree(X; throw_errors=false) === nothing
139+
# or a MethodError:
140+
try
141+
tree(X; throw_errors=true)
142+
@test false
143+
catch e
144+
@test e isa ErrorException
145+
@test occursin("Failed to evaluate", e.msg)
146+
stack = current_exceptions()
147+
end
148+
@test length(stack) == 2
149+
@test stack[1].exception isa MethodError
140150
end

0 commit comments

Comments
 (0)