@@ -254,6 +254,15 @@ xaccum(ir) = nothing
254254xaccum (ir, x) = x
255255xaccum (ir, xs... ) = push! (ir, xcall (Zygote, :accum , xs... ))
256256
257+ function passthrough_expr (ex:: Expr )
258+ # Metadata we want to preserve
259+ isexpr (ex, GlobalRef, :call , :isdefined , :inbounds , :meta , :loopinfo ) && return true
260+ # ccalls and more that are safe to preserve/required for proper operation:
261+ # - jl_set_task_threadpoolid: added in 1.9 for @spawn
262+ isexpr (ex, :foreigncall ) && unwrapquote (ex. args[1 ]) in (:jl_set_task_threadpoolid ,) && return true
263+ return false
264+ end
265+
257266function adjoint (pr:: Primal )
258267 ir, sigs = adjointcfg (pr)
259268 for b in reverse (blocks (pr. ir))
@@ -278,10 +287,9 @@ function adjoint(pr::Primal)
278287 end
279288 elseif ex isa Core. PiNode
280289 grads[ex. val] = grads[v]
281- elseif isexpr (ex, GlobalRef, :call , :isdefined , :inbounds , :meta , :loopinfo )
282- elseif isexpr (ex)
290+ elseif isexpr (ex) && ! passthrough_expr (ex)
283291 push! (rb, stmt (xcall (Base, :error , """
284- Can't differentiate $(ex. head) expression.
292+ Can't differentiate $(ex. head) expression $ex .
285293 You might want to check the Zygote limitations documentation.
286294 https://fluxml.ai/Zygote.jl/latest/limitations
287295 """ ),
0 commit comments