Skip to content

Commit e4fc02c

Browse files
authored
Merge pull request #1280 from FluxML/bc/passthrough-threading-ccall
Passthrough safe ccalls in threading code
2 parents 99d5a38 + 7d0376a commit e4fc02c

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/compiler/reverse.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,15 @@ xaccum(ir) = nothing
254254
xaccum(ir, x) = x
255255
xaccum(ir, xs...) = push!(ir, xcall(Zygote, :accum, xs...))
256256

257+
function passthrough_expr(ex::Expr)
258+
# Metadata we want to preserve
259+
isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo) && return true
260+
# ccalls and more that are safe to preserve/required for proper operation:
261+
# - jl_set_task_threadpoolid: added in 1.9 for @spawn
262+
isexpr(ex, :foreigncall) && unwrapquote(ex.args[1]) in (:jl_set_task_threadpoolid,) && return true
263+
return false
264+
end
265+
257266
function adjoint(pr::Primal)
258267
ir, sigs = adjointcfg(pr)
259268
for b in reverse(blocks(pr.ir))
@@ -278,10 +287,9 @@ function adjoint(pr::Primal)
278287
end
279288
elseif ex isa Core.PiNode
280289
grads[ex.val] = grads[v]
281-
elseif isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo)
282-
elseif isexpr(ex)
290+
elseif isexpr(ex) && !passthrough_expr(ex)
283291
push!(rb, stmt(xcall(Base, :error, """
284-
Can't differentiate $(ex.head) expression.
292+
Can't differentiate $(ex.head) expression $ex.
285293
You might want to check the Zygote limitations documentation.
286294
https://fluxml.ai/Zygote.jl/latest/limitations
287295
"""),

0 commit comments

Comments
 (0)