Skip to content

Commit f07d4d5

Browse files
committed
Fix diagview and format
1 parent 9091c56 commit f07d4d5

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -359,11 +359,11 @@ for (f!, f_full!, pb!) in (
359359
) where {RT}
360360
cache_A = nothing
361361
cache_D = nothing
362-
nD, V = MatrixAlgebraKit.initialize_output($f_full!, A.val, alg.val)
363-
nD, V = $f_full!(A.val, (nD, V), alg.val)
362+
nD, V = MatrixAlgebraKit.initialize_output($f_full!, A.val, alg.val)
363+
nD, V = $f_full!(A.val, (nD, V), alg.val)
364364
copy!(D.val, diagview(nD))
365-
primal = EnzymeRules.needs_primal(config) ? D.val : nothing
366-
shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing
365+
primal = EnzymeRules.needs_primal(config) ? D.val : nothing
366+
shadow = EnzymeRules.needs_shadow(config) ? D.dval : nothing
367367
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_D, V))
368368
end
369369
function EnzymeRules.reverse(
@@ -379,9 +379,9 @@ for (f!, f_full!, pb!) in (
379379
cache_A, cache_D, V = cache
380380
Dval = !isnothing(cache_D) ? cache_D : D.val
381381
Aval = !isnothing(cache_A) ? cache_A : A.val
382-
∂D = isa(D, Const) ? nothing : D.dval
382+
∂D = isa(D, Const) ? nothing : D.dval
383383
if !isa(A, Const) && !isa(D, Const)
384-
$pb!(A.dval, Aval, (Dval, V), (∂D, nothing))
384+
$pb!(A.dval, Aval, (Diagonal(Dval), V), (Diagonal(∂D), nothing))
385385
end
386386
!isa(D, Const) && make_zero!(D.dval)
387387
return (nothing, nothing, nothing)

0 commit comments

Comments
 (0)