From 06cabdf8e464d46f40813c8ac638f0a81e5b8722 Mon Sep 17 00:00:00 2001 From: Alban Gossard Date: Thu, 17 Jul 2025 15:36:30 +0200 Subject: [PATCH 1/3] add repack/canonicalize in vec_pjac! to support SciMLStructs --- src/gauss_adjoint.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 2c176d85f..639f346f9 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -495,14 +495,15 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) vtmp4 .= λ Enzyme.remake_zero!(tmp3) Enzyme.remake_zero!(out) - + + dp = isscimlstructure(p) ? repack(out) : out if SciMLBase.isinplace(sol.prob.f) Enzyme.remake_zero!(tmp6) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), - Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t)) + Enzyme.Const(y), Enzyme.Duplicated(p, dp), Enzyme.Const(t)) else function g(du, u, p, t) du .= f(u, p, t) @@ -512,7 +513,10 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(g, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), - Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t)) + Enzyme.Const(y), Enzyme.Duplicated(p, dp), Enzyme.Const(t)) + end + if isscimlstructure(p) + out .+= canonicalize(Tunable(), dp)[1] end elseif sensealg.autojacvec isa MooncakeVJP _, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ) From af830bef9070865172cc7264b8250cc0aaf51563 Mon Sep 17 00:00:00 2001 From: Alban Gossard Date: Thu, 17 Jul 2025 22:16:38 +0200 Subject: [PATCH 2/3] fix typo in vec_pjac! of GaussAdjoint for SciMLStructures --- src/gauss_adjoint.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 639f346f9..5310a4f91 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -516,7 +516,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) Enzyme.Const(y), Enzyme.Duplicated(p, dp), Enzyme.Const(t)) end if isscimlstructure(p) - out .+= canonicalize(Tunable(), dp)[1] + out .= canonicalize(Tunable(), dp)[1] end elseif sensealg.autojacvec isa MooncakeVJP _, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ) From 02a867408ba2730cba3fb27bb75f139fbd52cae5 Mon Sep 17 00:00:00 2001 From: Alban Gossard Date: Thu, 17 Jul 2025 22:47:43 +0200 Subject: [PATCH 3/3] add test of GaussAdjoint with EnzymeVJP and SciMLStructs --- test/scimlstructures_interface.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/scimlstructures_interface.jl b/test/scimlstructures_interface.jl index 303fa7a9c..043495077 100644 --- a/test/scimlstructures_interface.jl +++ b/test/scimlstructures_interface.jl @@ -159,3 +159,4 @@ end run_diff(initialize()) @test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint())[1].ps) @test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint(autojacvec = false))[1].ps) +@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint(autojacvec = EnzymeVJP()))[1].ps)