From 42e3ae04c8dbab0d06a89dcf603d44cf154f9b09 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 8 Apr 2025 23:13:11 -0400 Subject: [PATCH 1/3] fix vec(y) in withjacobian --- src/lib/grad.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lib/grad.jl b/src/lib/grad.jl index 92be71d34..cd3ca6e47 100644 --- a/src/lib/grad.jl +++ b/src/lib/grad.jl @@ -178,14 +178,16 @@ julia> withjacobian(cumsum, [1,2,3]) ``` """ function withjacobian(f, args...) - y, back = pullback(_jvec∘f, args...) + y, back1 = pullback(f, args...) + yvec, back2 = pullback(_jvec, y) + back = dy -> back1(back2(dy)[1]) out = map(args) do x T = promote_type(eltype(x), eltype(y)) dx = x isa AbstractArray ? similar(x, T, length(y), length(x)) : x isa Number ? similar(y, T, length(y)) : nothing end - delta = _eyelike(y) + delta = _eyelike(yvec) for k in LinearIndices(y) grads = back(delta[:,k]) for (dx, grad) in zip(out, grads) From c38b0196ffc66f3b8bc626f51695105cd7040585 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 8 Apr 2025 23:19:50 -0400 Subject: [PATCH 2/3] add a test --- test/utils_tests.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/utils_tests.jl b/test/utils_tests.jl index 691455491..1341d9993 100644 --- a/test/utils_tests.jl +++ b/test/utils_tests.jl @@ -77,6 +77,11 @@ end j6 = jacobian((x,y) -> abs2.(x .* y), [1+im, 2], 3+4im) @test j6[1][1,:] ≈ g6[1] @test j6[2][1] ≈ g6[2] + + # https://github.com/FluxML/Zygote.jl/issues/1506 + y7, g7 = Zygote.withjacobian(identity, rand(2, 3)); + @test size(y7) == (2,3) + @test only(g7) == I end @testset "jacobian(loss, ::Params)" begin From fb43e5c303188d31be0c4607799c511d456eadaf Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 8 Apr 2025 23:55:03 -0400 Subject: [PATCH 3/3] fix --- src/lib/grad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/grad.jl b/src/lib/grad.jl index cd3ca6e47..50670bf43 100644 --- a/src/lib/grad.jl +++ b/src/lib/grad.jl @@ -188,7 +188,7 @@ function withjacobian(f, args...) nothing end delta = _eyelike(yvec) - for k in LinearIndices(y) + for k in LinearIndices(yvec) grads = back(delta[:,k]) for (dx, grad) in zip(out, grads) dx isa AbstractArray || continue