From 3eb58c5b35f38c5a01bad96e553810d681f37eb5 Mon Sep 17 00:00:00 2001 From: Jordan Murphy Date: Fri, 1 Aug 2025 21:57:26 +0300 Subject: [PATCH] fix; modify zygote ext to handle thunks --- ext/ReferenceFrameRotationsZygoteExt.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ext/ReferenceFrameRotationsZygoteExt.jl b/ext/ReferenceFrameRotationsZygoteExt.jl index 6356e16..d63e5d9 100644 --- a/ext/ReferenceFrameRotationsZygoteExt.jl +++ b/ext/ReferenceFrameRotationsZygoteExt.jl @@ -10,13 +10,14 @@ using ReferenceFrameRotations using ForwardDiff using Zygote.ChainRulesCore: ChainRulesCore -import Zygote.ChainRulesCore: NoTangent +import Zygote.ChainRulesCore: NoTangent, unthunk function ChainRulesCore.rrule(::Type{<:DCM}, data::NTuple{9, T}) where {T} y = DCM(data) function DCM_pullback(Δ) - return (NoTangent(), Tuple(Δ)) + Δ_unthunked = unthunk(Δ) + return (NoTangent(), Tuple(Δ_unthunked)) end return y, DCM_pullback @@ -26,8 +27,9 @@ function ChainRulesCore.rrule(::typeof(orthonormalize), dcm::DCM) y = orthonormalize(dcm) function orthonormalize_pullback(Δ) + Δ_unthunked = unthunk(Δ) jac = ForwardDiff.jacobian(orthonormalize, dcm) - return (NoTangent(), reshape(vcat(Δ...)' * jac, 3, 3)) + return (NoTangent(), reshape(vcat(Δ_unthunked...)' * jac, 3, 3)) end return y, orthonormalize_pullback