Skip to content

Commit 891ae98

Browse files
committed
Some pushforwards progress
1 parent d64d2c5 commit 891ae98

File tree

8 files changed

+153
-196
lines changed

8 files changed

+153
-196
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ ChainRulesTestUtils = "1"
3232
CUDA = "5"
3333
GenericLinearAlgebra = "0.3.19"
3434
GenericSchur = "0.5.6"
35-
Enzyme = "0.13.77"
35+
Enzyme = "0.13.96"
3636
EnzymeTestUtils = "0.2.3"
3737
JET = "0.9, 0.10"
3838
LinearAlgebra = "1"

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 3 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using MatrixAlgebraKit: diagview, inv_safe, eig_trunc!, eigh_trunc!
55
using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pushforward!, lq_pushforward!
66
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pushforward!, lq_null_pushforward!
77
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pushforward!, eigh_pushforward!
8+
using MatrixAlgebraKit: svd_pushforward!
89
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pushforward!, right_polar_pushforward!
910
using Enzyme
1011
using Enzyme.EnzymeCore
@@ -179,23 +180,7 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
179180
) where {RT}
180181
ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
181182
shadow = if EnzymeRules.needs_shadow(config)
182-
U, S, Vᴴ = ret
183-
V = adjoint(Vᴴ)
184-
∂S = Diagonal(diag(real.(U' * A.dval * V)))
185-
m, n = size(A.val)
186-
F = one(eltype(S)) ./ ((diagview(S).^2)' .- (diagview(S) .^ 2))
187-
diagview(F) .= zero(eltype(F))
188-
invSdiag = zeros(eltype(S), length(S.diag))
189-
for i in 1:length(S.diag)
190-
@inbounds invSdiag[i] = inv(diagview(S)[i])
191-
end
192-
invS = Diagonal(invSdiag)
193-
∂U = U * (F .* (U' * A.dval * V * S + S * Vᴴ * A.dval' * U)) + (diagm(ones(eltype(U), m)) - U*U') * A.dval * V * invS
194-
#∂Vᴴ = (FSdS' * Vᴴ) + (invS * U' * A.dval * (diagm(ones(eltype(U), size(V, 2))) - Vᴴ*V))
195-
∂V = V * (F .* (S * U' * A.dval * V + Vᴴ * A.dval' * U * S)) + (diagm(ones(eltype(V), n)) - V*Vᴴ) * A.dval' * U * invS
196-
∂Vᴴ = similar(Vᴴ)
197-
adjoint!(∂Vᴴ, ∂V)
198-
(∂U, ∂S, ∂Vᴴ)
183+
svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval)
199184
else
200185
nothing
201186
end
@@ -221,46 +206,7 @@ function EnzymeRules.forward(config::EnzymeRules.FwdConfig,
221206
) where {RT}
222207
ret = EnzymeRules.needs_primal(config) || EnzymeRules.needs_shadow(config) ? func.val(A.val, USVᴴ.val; kwargs...) : nothing
223208
shadow = if EnzymeRules.needs_shadow(config)
224-
fatU, fatS, fatVᴴ = ret
225-
∂Ufat = zeros(eltype(fatU), size(fatU))
226-
∂Sfat = zeros(eltype(fatS), size(fatS))
227-
∂Vᴴfat = zeros(eltype(fatVᴴ), size(fatVᴴ))
228-
m, n = size(A.val)
229-
minmn = min(m, n)
230-
#U = view(fatU, :, 1:minmn)
231-
#S = Diagonal(diagview(fatS))
232-
#Vᴴ = view(fatVᴴ, 1:minmn, :)
233-
U = fatU
234-
S = fatS
235-
Vᴴ = fatVᴴ
236-
V = adjoint(Vᴴ)
237-
∂S = Diagonal(diag(real.(U' * A.dval * V)))
238-
diagview(∂Sfat) .= diagview(∂S)
239-
m, n = size(A.val)
240-
F = one(eltype(S)) ./ ((diagview(S).^2)' .- (diagview(S) .^ 2))
241-
diagview(F) .= zero(eltype(F))
242-
invSdiag = zeros(eltype(S), size(S))
243-
for ix in diagind(S)
244-
@inbounds invSdiag[ix] = inv(S[ix])
245-
end
246-
invS = invSdiag
247-
#FSdS = F .* (∂S * S .+ S * ∂S)
248-
∂U = U * (F .* (U' * A.dval * V * S + S * Vᴴ * A.dval' * U)) + (diagm(ones(eltype(U), m)) - U*U') * A.dval * V * invS
249-
#view(∂Ufat, :, 1:minmn) .= view(∂U, :, :)
250-
∂Ufat .= ∂U
251-
252-
253-
#∂Vᴴ = (FSdS' * Vᴴ) + (invS * U' * A.dval * (diagm(ones(eltype(U), size(V, 2))) - Vᴴ*V))
254-
∂V = V * (F .* (S * U' * A.dval * V + Vᴴ * A.dval' * U * S)) + (diagm(ones(eltype(V), n)) - V*Vᴴ) * A.dval' * U * invS
255-
∂Vᴴ = similar(Vᴴ)
256-
adjoint!(∂Vᴴ, ∂V)
257-
#view(∂Vᴴfat, 1:minmn, :) .= view(∂Vᴴ, :, :)
258-
∂Vᴴfat .= ∂Vᴴ
259-
#=view(∂Ufat, :, minmn+1:m) .= zero(eltype(fatU))
260-
view(∂Vᴴfat, minmn+1:n, :) .= zero(eltype(fatVᴴ))
261-
view(∂Sfat, minmn+1:m, :) .= zero(eltype(fatVᴴ))
262-
view(∂Sfat, :, minmn+1:n) .= zero(eltype(fatVᴴ))=#
263-
(∂Ufat, ∂Sfat, ∂Vᴴfat)
209+
svd_pushforward!(A.dval, A.val, ret, USVᴴ.dval)
264210
else
265211
nothing
266212
end

src/pushforwards/eig.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
function eig_pushforward!(dA, A, DV, dDV; kwargs...)
1+
function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...)
22
D, V = DV
3-
dD, dV = dDV
4-
∂K = inv(V) * dA * V
5-
∂Kdiag = diagview(∂K)
6-
dD.diag .= ∂Kdiag
7-
∂K ./= transpose(diagview(D)) .- diagview(D)
8-
fill!(∂Kdiag, zero(eltype(D)))
9-
mul!(dV, V, ∂K, 1, 0)
10-
dA .= zero(eltype(dA))
11-
return dDV
3+
ΔD, ΔV = ΔDV
4+
iVΔAV = inv(V) * ΔA * V
5+
diagview(ΔD) .= diagview(iVΔAV)
6+
F = 1 ./ (transpose(diagview(D)) .- diagview(D))
7+
fill!(diagview(F), zero(eltype(F)))
8+
= F .* iVΔAV
9+
mul!(ΔV, V, , 1, 0)
10+
zero!(ΔA)
11+
return ΔDV
1212
end

src/pushforwards/eigh.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
function eigh_pushforward!(dA, A, DV, dDV; kwargs...)
2+
D, V = DV
3+
dD, dV = dDV
24
tmpV = V \ dA
35
∂K = tmpV * V
46
∂Kdiag = diag(∂K)

src/pushforwards/polar.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)
1515
ΔP, ΔWᴴ = ΔPWᴴ
1616
dAW = ΔA * adjoint(Wᴴ)
1717
= sylvester(P, P, -(dAW - adjoint(dAW)))
18-
ImW = (Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ)
19-
@show size(P), size(ΔA), size(ImW), size(Wᴴ)
20-
= inv(P)*ΔA*ImW
18+
= inv(P)*ΔA*(Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ)
2119
ΔWᴴ .=* Wᴴ +
2220
ΔP .= dAW - P *
2321
MatrixAlgebraKit.zero!(ΔA)

src/pushforwards/qr.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,23 +38,22 @@ function qr_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pull
3838
dR11 .= Rtmp * R11
3939
dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11
4040
dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12)
41-
dQ2 .= -Q1 * (Q1' * dQ2)
4241
if size(Q2, 2) > 0
42+
dQ2 .= -Q1 * (Q1' * Q2)
4343
dQ2 .+= Q2 * (Q2' * dQ2)
4444
end
4545
if m3 > 0 && size(Q, 2) > minmn
4646
# only present for qr_full or rank-deficient qr_compact
4747
Q′ = view(Q, :, 1:minmn)
48-
println("minmn $minmn m $m")
4948
Q3 = view(Q, :, minmn+1:m)
5049
#dQ3 .= Q′ * (Q′' * Q3)
5150
dQ3 .= Q3
5251
end
53-
#=if !isempty(dR22)
52+
if !isempty(dR22)
5453
_, r22 = qr_full(dA2 - dQ1*R12 - Q1*dR12, MatrixAlgebraKit.LAPACK_HouseholderQR(; positive=true))
5554
dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2))
56-
end=#
55+
end
5756
return (dQ, dR)
5857
end
5958

60-
function qr_null_pushforward!(dA, A, QR, dQR; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol) end
59+
function qr_null_pushforward!(dA, A, N, dN; tol::Real=MatrixAlgebraKit.default_pullback_gaugetol(N), rank_atol::Real=tol, gauge_atol::Real=tol) end

src/pushforwards/svd.jl

Lines changed: 75 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,83 @@
1-
function svd_pushforward!(dA, A, USVᴴ, dUSVᴴ;
1+
function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ;
22
tol::Real = default_pullback_gaugetol(USVᴴ[2]),
33
rank_atol::Real = tol,
44
degeneracy_atol::Real = tol,
55
gauge_atol::Real = tol
66
)
7-
U, S, Vᴴ = USVᴴ
8-
dU, dS, dVᴴ = dUSVᴴ
9-
V = adjoint(Vᴴ)
10-
UdAV = U' * dA * V
11-
copyto!(diagview(dS), diag(real.(UdAV)))
12-
m, n = size(A)
13-
F = one(eltype(S)) ./ (diagview(S)' .- diagview(S))
14-
G = one(eltype(S)) ./ (diagview(S)' .+ diagview(S))
7+
U, Smat, Vᴴ = USVᴴ
8+
m, n = size(U, 1), size(Vᴴ, 2)
9+
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
10+
minmn = min(m, n)
11+
S = diagview(Smat)
12+
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
13+
r = searchsortedlast(S, rank_atol; rev = true) # rank
14+
15+
vΔU = view(ΔU, :, 1:r)
16+
vΔS = view(ΔS, 1:r, 1:r)
17+
vΔVᴴ = view(ΔVᴴ, 1:r, :)
18+
19+
vU = view(U, :, 1:r)
20+
vS = view(S, 1:r)
21+
vSmat = view(Smat, 1:r, 1:r)
22+
vVᴴ = view(Vᴴ, 1:r, :)
23+
24+
# compact region
25+
vV = adjoint(vVᴴ)
26+
UΔAV = vU' * ΔA * vV
27+
copyto!(diagview(vΔS), diag(real.(UΔAV)))
28+
F = one(eltype(S)) ./ (transpose(vS) .- vS)
29+
G = one(eltype(S)) ./ (transpose(vS) .+ vS)
1530
diagview(F) .= zero(eltype(F))
16-
invSdiag = zeros(eltype(S), length(diagview(S)))
17-
for i in 1:length(diagview(S))
18-
@inbounds invSdiag[i] = inv(diagview(S)[i])
31+
hUΔAV = F .* (UΔAV + UΔAV') ./ 2
32+
aUΔAV = G .* (UΔAV - UΔAV') ./ 2
33+
= hUΔAV + aUΔAV
34+
= hUΔAV - aUΔAV
35+
36+
# check gauge condition
37+
@assert isantihermitian(K̇)
38+
@assert isantihermitian(Ṁ)
39+
K̇diag = diagview(K̇)
40+
for i in 1:length(K̇diag)
41+
@assert K̇diag[i] (im/2) * imag(diagview(UΔAV)[i])/S[i]
1942
end
20-
invS = Diagonal(invSdiag)
21-
#∂U = U * (F .* (U' * dA * V * S + S * Vᴴ * dA' * U)) + (LinearAlgebra.diagm(ones(eltype(U), m)) - U*U') * dA * V * invS
22-
#∂V = V * (F .* (S * U' * dA * V + Vᴴ * dA' * U * S)) + (LinearAlgebra.diagm(ones(eltype(V), n)) - V*Vᴴ) * dA' * U * invS
23-
hUdAV = F .* project_hermitian(UdAV)
24-
aUdAV = G .* project_antihermitian(UdAV)
25-
∂U = U * (hUdAV + aUdAV)
26-
∂U += (LinearAlgebra.diagm(ones(eltype(U), m)) - U*U') * dA * V * invS
27-
∂V = V * (hUdAV - aUdAV)
28-
∂V += (LinearAlgebra.diagm(ones(eltype(U), n)) - V*V') * dA' * U * invS
29-
copyto!(dU, ∂U)
30-
adjoint!(dVᴴ, ∂V)
31-
dA .= zero(eltype(A))
32-
return (dU, dS, dVᴴ)
43+
44+
∂U = vU *
45+
∂V = vV *
46+
# full component
47+
if size(U, 2) > minmn && size(Vᴴ, 1) > minmn
48+
Uperp = view(U, :, minmn+1:m)
49+
Vᴴperp = view(Vᴴ, minmn+1:n, :)
50+
51+
aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp)
52+
53+
UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2)))
54+
fill!(UÃÃV, 0)
55+
view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV
56+
view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV'
57+
rhs = vcat( adjoint(Uperp, ΔA, V), Vᴴperp * ΔA' * U)
58+
superKM = -sylvester(UÃÃV, Smat, rhs)
59+
K̇perp = view(superKM, 1:size(aUAV, 2))
60+
Ṁperp = view(superKM, size(aUAV, 2)+1:size(aUAV, 1)+size(aUAV, 2))
61+
∂U .+= Uperp * K̇perp
62+
∂V .+= Vperp * Ṁperp
63+
else
64+
ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU*vU')
65+
ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV*vVᴴ)
66+
upper = ImUU * ΔA * vV
67+
lower = ImVV * ΔA' * vU
68+
rhs = vcat(upper, lower)
69+
70+
= ImUU * A * ImVV
71+
ÃÃ = similar(A, (m + n, m + n))
72+
fill!(ÃÃ, 0)
73+
view(ÃÃ, (1:m), m .+ (1:n)) .=
74+
view(ÃÃ, m .+ (1:n), 1:m ) .='
75+
76+
superLN = -sylvester(ÃÃ, vSmat, rhs)
77+
∂U += view(superLN, 1:size(upper, 1), :)
78+
∂V += view(superLN, size(upper, 1)+1:size(upper,1)+size(lower,1), :)
79+
end
80+
copyto!(vΔU, ∂U)
81+
adjoint!(vΔVᴴ, ∂V)
82+
return (ΔU, ΔS, ΔVᴴ)
3383
end

0 commit comments

Comments
 (0)