Skip to content

Commit d64d2c5

Browse files
committed
Working frule for polar
1 parent 78cfd67 commit d64d2c5

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal))
216216
vS = Diagonal(diagview(S)[1:minmn])
217217
vVᴴ = view(Vᴴ, 1:minmn, :)
218218
vdU = view(dU, :, 1:minmn)
219-
vdS = Diagonal(diagview(dS)[1:minmn])
219+
vdS = view(dS, 1:minmn, 1:minmn)
220220
vdVᴴ = view(dVᴴ, 1:minmn, :)
221221
dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
222222
end
@@ -243,7 +243,18 @@ for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal))
243243
U, dU = arrayify(U_, dU_)
244244
S, dS = arrayify(S_, dS_)
245245
Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_)
246-
(dU, dS, dVᴴ) = svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...)
246+
minmn = min(size(A)...)
247+
if ($f == svd_compact!) # compact
248+
svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...)
249+
else # full
250+
vU = view(U, :, 1:minmn)
251+
vS = view(S, 1:minmn, 1:minmn)
252+
vVᴴ = view(Vᴴ, 1:minmn, :)
253+
vdU = view(dU, :, 1:minmn)
254+
vdS = view(dS, 1:minmn, 1:minmn)
255+
vdVᴴ = view(dVᴴ, 1:minmn, :)
256+
svd_pushforward!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ); kwargs...)
257+
end
247258
return USVᴴ_dUSVᴴ
248259
end
249260
end

src/pushforwards/polar.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,25 @@
1-
function left_polar_pushforward! end
2-
function right_polar_pushforward! end
1+
function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)
2+
W, P = WP
3+
ΔW, ΔP = ΔWP
4+
aWdA = adjoint(W) * ΔA
5+
= sylvester(P, P, -(aWdA - adjoint(aWdA)))
6+
= (Diagonal(ones(eltype(W), size(W, 1))) - W*adjoint(W))*ΔA*inv(P)
7+
ΔW .= W *+
8+
ΔP .= aWdA -*P
9+
MatrixAlgebraKit.zero!(ΔA)
10+
return (ΔW, ΔP)
11+
end
12+
13+
function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)
14+
P, Wᴴ = PWᴴ
15+
ΔP, ΔWᴴ = ΔPWᴴ
16+
dAW = ΔA * adjoint(Wᴴ)
17+
= sylvester(P, P, -(dAW - adjoint(dAW)))
18+
ImW = (Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ)
19+
@show size(P), size(ΔA), size(ImW), size(Wᴴ)
20+
= inv(P)*ΔA*ImW
21+
ΔWᴴ .=* Wᴴ +
22+
ΔP .= dAW - P *
23+
MatrixAlgebraKit.zero!(ΔA)
24+
return (ΔWᴴ, ΔP)
25+
end

0 commit comments

Comments
 (0)