Skip to content

Conversation

@ChrisRackauckas-Claude
Copy link

Fixes #1282

Summary

Previously, GaussAdjoint with ZygoteVJP() would fail when used with in-place ODE functions, throwing a MethodError: no method matching fiip(::Vector{Float64}, ::Vector{Float64}, ::Float64) because it was calling the in-place function f(du, u, p, t) with only 3 arguments as if it were out-of-place f(u, p, t).

Changes

  • Added check for in-place vs out-of-place ODE functions using SciMLBase.isinplace()
  • For in-place functions: Uses Zygote.Buffer to allow mutation during forward pass while remaining differentiable
  • For out-of-place functions: Keeps existing behavior
  • Added comprehensive tests for both in-place and out-of-place cases

Technical Details

The key insight is using Zygote.Buffer which enables Zygote to differentiate through in-place functions by:

  1. Allowing controlled mutation during the forward pass
  2. Returning an immutable copy via copy(du_buf) for the backward pass
  3. Letting Zygote build the pullback for the values, not the mutation operations

Testing

  • ✅ New test file test/gauss_zygote_inplace.jl with comprehensive coverage
  • ✅ Tests both in-place and out-of-place ODE functions
  • ✅ Verified with the original failing example from issue Adjoint fails when using ZygoteVJP #1282
  • ✅ Tests pass on Julia 1.10.10 (LTS) and 1.11.7

Test Output

Julia version: 1.11.7
✓ Basic solve with GaussAdjoint succeeded
Computing gradients...
✓ Gradient computation succeeded!
  du0 = [-44.39290603871477, -8.472174656702613]
  dp = [0.1944861128441424, -158.5656423260541, 75.24591831397161, -345.0325681954895]

🤖 Generated with Claude Code

Fixes SciML#1282

Previously, `GaussAdjoint` with `ZygoteVJP()` would fail when used with
in-place ODE functions, throwing a `MethodError` because it was calling
the in-place function `f(du, u, p, t)` with only 3 arguments as if it
were out-of-place `f(u, p, t)`.

This fix:
- Checks if the ODE function is in-place using `SciMLBase.isinplace()`
- For in-place functions, creates a `Zygote.Buffer` to allow mutation
  during the forward pass while remaining differentiable
- For out-of-place functions, keeps the existing behavior

The use of `Zygote.Buffer` enables Zygote to differentiate through
in-place functions by allowing controlled mutation during the forward
pass and returning an immutable copy for the backward pass.

Added comprehensive tests for both in-place and out-of-place ODE
functions with `GaussAdjoint(autojacvec = ZygoteVJP())`.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Adjoint fails when using ZygoteVJP

2 participants