|
1 | | -function svd_pushforward!(dA, A, USVᴴ, dUSVᴴ; |
| 1 | +function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; |
2 | 2 | tol::Real = default_pullback_gaugetol(USVᴴ[2]), |
3 | 3 | rank_atol::Real = tol, |
4 | 4 | degeneracy_atol::Real = tol, |
5 | 5 | gauge_atol::Real = tol |
6 | 6 | ) |
7 | | - U, S, Vᴴ = USVᴴ |
8 | | - dU, dS, dVᴴ = dUSVᴴ |
9 | | - V = adjoint(Vᴴ) |
10 | | - UdAV = U' * dA * V |
11 | | - copyto!(diagview(dS), diag(real.(UdAV))) |
12 | | - m, n = size(A) |
13 | | - F = one(eltype(S)) ./ (diagview(S)' .- diagview(S)) |
14 | | - G = one(eltype(S)) ./ (diagview(S)' .+ diagview(S)) |
| 7 | + U, Smat, Vᴴ = USVᴴ |
| 8 | + m, n = size(U, 1), size(Vᴴ, 2) |
| 9 | + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) |
| 10 | + minmn = min(m, n) |
| 11 | + S = diagview(Smat) |
| 12 | + ΔU, ΔS, ΔVᴴ = ΔUSVᴴ |
| 13 | + r = searchsortedlast(S, rank_atol; rev = true) # rank |
| 14 | + |
| 15 | + vΔU = view(ΔU, :, 1:r) |
| 16 | + vΔS = view(ΔS, 1:r, 1:r) |
| 17 | + vΔVᴴ = view(ΔVᴴ, 1:r, :) |
| 18 | + |
| 19 | + vU = view(U, :, 1:r) |
| 20 | + vS = view(S, 1:r) |
| 21 | + vSmat = view(Smat, 1:r, 1:r) |
| 22 | + vVᴴ = view(Vᴴ, 1:r, :) |
| 23 | + |
| 24 | + # compact region |
| 25 | + vV = adjoint(vVᴴ) |
| 26 | + UΔAV = vU' * ΔA * vV |
| 27 | + copyto!(diagview(vΔS), diag(real.(UΔAV))) |
| 28 | + F = one(eltype(S)) ./ (transpose(vS) .- vS) |
| 29 | + G = one(eltype(S)) ./ (transpose(vS) .+ vS) |
15 | 30 | diagview(F) .= zero(eltype(F)) |
16 | | - invSdiag = zeros(eltype(S), length(diagview(S))) |
17 | | - for i in 1:length(diagview(S)) |
18 | | - @inbounds invSdiag[i] = inv(diagview(S)[i]) |
| 31 | + hUΔAV = F .* (UΔAV + UΔAV') ./ 2 |
| 32 | + aUΔAV = G .* (UΔAV - UΔAV') ./ 2 |
| 33 | + K̇ = hUΔAV + aUΔAV |
| 34 | + Ṁ = hUΔAV - aUΔAV |
| 35 | + |
| 36 | + # check gauge condition |
| 37 | + @assert isantihermitian(K̇) |
| 38 | + @assert isantihermitian(Ṁ) |
| 39 | + K̇diag = diagview(K̇) |
| 40 | + for i in 1:length(K̇diag) |
| 41 | + @assert K̇diag[i] ≈ (im/2) * imag(diagview(UΔAV)[i])/S[i] |
19 | 42 | end |
20 | | - invS = Diagonal(invSdiag) |
21 | | - #∂U = U * (F .* (U' * dA * V * S + S * Vᴴ * dA' * U)) + (LinearAlgebra.diagm(ones(eltype(U), m)) - U*U') * dA * V * invS |
22 | | - #∂V = V * (F .* (S * U' * dA * V + Vᴴ * dA' * U * S)) + (LinearAlgebra.diagm(ones(eltype(V), n)) - V*Vᴴ) * dA' * U * invS |
23 | | - hUdAV = F .* project_hermitian(UdAV) |
24 | | - aUdAV = G .* project_antihermitian(UdAV) |
25 | | - ∂U = U * (hUdAV + aUdAV) |
26 | | - ∂U += (LinearAlgebra.diagm(ones(eltype(U), m)) - U*U') * dA * V * invS |
27 | | - ∂V = V * (hUdAV - aUdAV) |
28 | | - ∂V += (LinearAlgebra.diagm(ones(eltype(U), n)) - V*V') * dA' * U * invS |
29 | | - copyto!(dU, ∂U) |
30 | | - adjoint!(dVᴴ, ∂V) |
31 | | - dA .= zero(eltype(A)) |
32 | | - return (dU, dS, dVᴴ) |
| 43 | + |
| 44 | + ∂U = vU * K̇ |
| 45 | + ∂V = vV * Ṁ |
| 46 | + # full component |
| 47 | + if size(U, 2) > minmn && size(Vᴴ, 1) > minmn |
| 48 | + Uperp = view(U, :, minmn+1:m) |
| 49 | + Vᴴperp = view(Vᴴ, minmn+1:n, :) |
| 50 | + |
| 51 | + aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp) |
| 52 | + |
| 53 | + UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2))) |
| 54 | + fill!(UÃÃV, 0) |
| 55 | + view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV |
| 56 | + view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV' |
| 57 | + rhs = vcat( adjoint(Uperp, ΔA, V), Vᴴperp * ΔA' * U) |
| 58 | + superKM = -sylvester(UÃÃV, Smat, rhs) |
| 59 | + K̇perp = view(superKM, 1:size(aUAV, 2)) |
| 60 | + Ṁperp = view(superKM, size(aUAV, 2)+1:size(aUAV, 1)+size(aUAV, 2)) |
| 61 | + ∂U .+= Uperp * K̇perp |
| 62 | + ∂V .+= Vperp * Ṁperp |
| 63 | + else |
| 64 | + ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU*vU') |
| 65 | + ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV*vVᴴ) |
| 66 | + upper = ImUU * ΔA * vV |
| 67 | + lower = ImVV * ΔA' * vU |
| 68 | + rhs = vcat(upper, lower) |
| 69 | + |
| 70 | + Ã = ImUU * A * ImVV |
| 71 | + ÃÃ = similar(A, (m + n, m + n)) |
| 72 | + fill!(ÃÃ, 0) |
| 73 | + view(ÃÃ, (1:m), m .+ (1:n)) .= Ã |
| 74 | + view(ÃÃ, m .+ (1:n), 1:m ) .= Ã' |
| 75 | + |
| 76 | + superLN = -sylvester(ÃÃ, vSmat, rhs) |
| 77 | + ∂U += view(superLN, 1:size(upper, 1), :) |
| 78 | + ∂V += view(superLN, size(upper, 1)+1:size(upper,1)+size(lower,1), :) |
| 79 | + end |
| 80 | + copyto!(vΔU, ∂U) |
| 81 | + adjoint!(vΔVᴴ, ∂V) |
| 82 | + return (ΔU, ΔS, ΔVᴴ) |
33 | 83 | end |
0 commit comments