diff --git a/ext/ReferenceFrameRotationsZygoteExt.jl b/ext/ReferenceFrameRotationsZygoteExt.jl index 6356e16..d63e5d9 100644 --- a/ext/ReferenceFrameRotationsZygoteExt.jl +++ b/ext/ReferenceFrameRotationsZygoteExt.jl @@ -10,13 +10,14 @@ using ReferenceFrameRotations using ForwardDiff using Zygote.ChainRulesCore: ChainRulesCore -import Zygote.ChainRulesCore: NoTangent +import Zygote.ChainRulesCore: NoTangent, unthunk function ChainRulesCore.rrule(::Type{<:DCM}, data::NTuple{9, T}) where {T} y = DCM(data) function DCM_pullback(Δ) - return (NoTangent(), Tuple(Δ)) + Δ_unthunked = unthunk(Δ) + return (NoTangent(), Tuple(Δ_unthunked)) end return y, DCM_pullback @@ -26,8 +27,9 @@ function ChainRulesCore.rrule(::typeof(orthonormalize), dcm::DCM) y = orthonormalize(dcm) function orthonormalize_pullback(Δ) + Δ_unthunked = unthunk(Δ) jac = ForwardDiff.jacobian(orthonormalize, dcm) - return (NoTangent(), reshape(vcat(Δ...)' * jac, 3, 3)) + return (NoTangent(), reshape(vcat(Δ_unthunked...)' * jac, 3, 3)) end return y, orthonormalize_pullback