Skip to content

Commit 22379f6

Browse files
committed
Add dummy reverse modes for truncated methods
1 parent 9698614 commit 22379f6

File tree

2 files changed

+86
-40
lines changed

2 files changed

+86
-40
lines changed

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module MatrixAlgebraKitEnzymeExt
22

33
using MatrixAlgebraKit
4-
using MatrixAlgebraKit: diagview, inv_safe
4+
using MatrixAlgebraKit: diagview, inv_safe, eig_trunc!, eigh_trunc!
55
using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pullfwd!, lq_pullfwd!
66
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pullfwd!, lq_null_pullfwd!
77
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pullfwd!, eigh_pullfwd!
@@ -321,7 +321,7 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
321321
A::Annotation{<:AbstractMatrix},
322322
USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
323323
ϵ::Annotation{Vector{T}},
324-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
324+
alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
325325
kwargs...,
326326
) where {RT, T<:Real}
327327
# form cache if needed
@@ -350,7 +350,7 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
350350
A::Annotation{<:AbstractMatrix},
351351
USVᴴ::Annotation{<:Tuple{<:AbstractMatrix, <:AbstractMatrix, <:AbstractMatrix}},
352352
ϵ::Annotation{Vector{T}},
353-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
353+
alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
354354
kwargs...) where {RT, T<:Real}
355355
cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache
356356
U, S, Vᴴ = cache_USVᴴ
@@ -363,7 +363,7 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
363363
make_zero!(USVᴴ.dval)
364364
end
365365
if !isa(ϵ, Const)
366-
ϵ.dval .= zero(T)
366+
make_zero!(ϵ.dval)
367367
end
368368
return (nothing, nothing, nothing, nothing)
369369
end
@@ -476,16 +476,17 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
476476
::Type{RT},
477477
A::Annotation{<:AbstractMatrix},
478478
DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
479-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
479+
ϵ::Annotation{Vector{T}},
480+
alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
480481
kwargs...,
481-
) where {RT}
482+
) where {RT, T}
482483
# form cache if needed
483484
cache_A = copy(A.val)
484-
eigh_full!(A.val, DV.val, alg.val.alg)
485+
MatrixAlgebraKit.eigh_full!(A.val, DV.val, alg.val.alg)
485486
cache_DV = copy.(DV.val)
486487
DV′, ind = MatrixAlgebraKit.truncate(eigh_trunc!, DV.val, alg.val.trunc)
487-
ϵ = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
488-
primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ) : nothing
488+
ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
489+
primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
489490
shadow_DV = if !isa(A, Const) && !isa(DV, Const)
490491
dD, dV = DV.dval
491492
dDtrunc = Diagonal(diagview(dD)[ind])
@@ -494,17 +495,18 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
494495
else
495496
(nothing, nothing)
496497
end
497-
shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., zero(ϵ)) : nothing
498+
shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., [zero(T)]) : nothing
498499
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind))
499500
end
500501
function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
501502
func::Const{typeof(eigh_trunc!)},
502-
dret,
503+
::Type{RT},
503504
cache,
504505
A::Annotation{<:AbstractMatrix},
505506
DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
507+
ϵ::Annotation{Vector{T}},
506508
alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
507-
kwargs...)
509+
kwargs...) where {RT, T}
508510
cache_A, cache_DV, cache_dDVtrunc, ind = cache
509511
D, V = cache_DV
510512
dD, dV = cache_dDVtrunc
@@ -515,24 +517,28 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
515517
if !isa(DV, Const)
516518
make_zero!(DV.dval)
517519
end
520+
if !isa(ϵ, Const)
521+
make_zero!.dval)
522+
end
518523
return (nothing, nothing, nothing, nothing)
519524
end
520-
#=
525+
521526
function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
522527
func::Const{typeof(eig_trunc!)},
523528
::Type{RT},
524529
A::Annotation{<:AbstractMatrix},
525530
DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
526-
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm};
531+
ϵ::Annotation{Vector{T}},
532+
alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
527533
kwargs...,
528-
) where {RT}
534+
) where {RT, T}
529535
# form cache if needed
530536
cache_A = copy(A.val)
531537
eig_full!(A.val, DV.val, alg.val.alg)
532538
cache_DV = copy.(DV.val)
533539
DV′, ind = MatrixAlgebraKit.truncate(eig_trunc!, DV.val, alg.val.trunc)
534-
ϵ = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
535-
primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ) : nothing
540+
ϵ.val[1] = MatrixAlgebraKit.truncation_error!(diagview(DV.val[1]), ind)
541+
primal = EnzymeRules.needs_primal(config) ? (DV′..., ϵ.val) : nothing
536542
shadow_DV = if !isa(A, Const) && !isa(DV, Const)
537543
dD, dV = DV.dval
538544
dDtrunc = Diagonal(diagview(dD)[ind])
@@ -541,30 +547,34 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
541547
else
542548
(nothing, nothing)
543549
end
544-
shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., zero(ϵ)) : nothing
550+
shadow = EnzymeRules.needs_shadow(config) ? (shadow_DV..., [zero(T)]) : nothing
545551
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_DV, shadow_DV, ind))
546552
end
547553
function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
548554
func::Const{typeof(eig_trunc!)},
549-
dret,
555+
::Type{RT},
550556
cache,
551557
A::Annotation{<:AbstractMatrix},
552558
DV::Annotation{<:Tuple{<:Diagonal, <:AbstractMatrix}},
559+
ϵ::Annotation{Vector{T}},
553560
alg::Const{<:MatrixAlgebraKit.TruncatedAlgorithm};
554-
kwargs...)
561+
kwargs...) where {RT, T}
555562
cache_A, cache_DV, cache_dDVtrunc, ind = cache
556563
D, V = cache_DV
557564
dD, dV = cache_dDVtrunc
558565
if !isa(A, Const) && !isa(DV, Const)
559566
A.dval .= zero(eltype(A.val))
560-
A.dval .= MatrixAlgebraKit.eigh_pullback!(A.dval, A.val, (D, V), (dD, dV), ind; kwargs...)
567+
A.dval .= MatrixAlgebraKit.eig_pullback!(A.dval, A.val, (D, V), (dD, dV), ind; kwargs...)
561568
end
562569
if !isa(DV, Const)
563570
make_zero!(DV.dval)
564571
end
572+
if !isa(ϵ, Const)
573+
make_zero!.dval)
574+
end
565575
return (nothing, nothing, nothing, nothing)
566576
end
567-
=#
577+
568578
function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1},
569579
func::Const{typeof(eigh_full!)},
570580
::Type{RT},

test/enzyme.jl

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,19 @@ end
113113
end
114114
end
115115

116+
function MatrixAlgebraKit.eig_trunc!(A, DV, ϵ::Vector{T}, alg::MatrixAlgebraKit.TruncatedAlgorithm) where {T}
117+
D, V = eig_full!(A, DV, alg.alg)
118+
DVtrunc, ind = MatrixAlgebraKit.truncate(eig_trunc!, (D, V), alg.trunc)
119+
ϵ[1] = MatrixAlgebraKit.truncation_error!(diagview(D), ind)
120+
return DVtrunc..., ϵ
121+
end
122+
function dummy_eig_trunc(A, ϵ::Vector{T}, alg::TruncatedAlgorithm) where {T}
123+
Ac = MatrixAlgebraKit.copy_input(MatrixAlgebraKit.eig_trunc, A)
124+
DV = MatrixAlgebraKit.initialize_output(eig_trunc!, A, alg)
125+
Dtrunc, Vtrunc, ϵ = MatrixAlgebraKit.eig_trunc!(Ac, DV, ϵ, alg)
126+
return Dtrunc, Vtrunc, ϵ
127+
end
128+
116129
@timedtestset "EIG AD Rules with eltype $T" for T in ETs
117130
rng = StableRNG(12345)
118131
m = 19
@@ -129,16 +142,16 @@ end
129142
test_reverse(eig_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV)))
130143
test_reverse(eig_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag))
131144
end
132-
@testset "reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,)
145+
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
133146
for r in 1:4:m
134147
truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
135148
ind = MatrixAlgebraKit.findtruncated(diagview(D), truncalg.trunc)
136149
Dtrunc = Diagonal(diagview(D)[ind])
137150
Vtrunc = V[:, ind]
138151
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
139152
ΔVtrunc = ΔV[:, ind]
140-
# broken right now due to Enzyme
141-
#test_reverse(eig_trunc!, RT, (A, TA), ((D, V), TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=Base.RefValue((ΔDtrunc, ΔVtrunc, zero(real(T)))))
153+
ϵ = [zero(real(T))]
154+
test_reverse(dummy_eig_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔDtrunc, ΔVtrunc, [zero(real(T))]))
142155
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
143156
dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
144157
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -149,8 +162,8 @@ end
149162
Vtrunc = V[:, ind]
150163
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
151164
ΔVtrunc = ΔV[:, ind]
152-
# broken right now due to Enzyme
153-
#test_reverse(eig_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=Base.RefValue((ΔDtrunc, ΔVtrunc, zero(real(T)))))
165+
ϵ = [zero(real(T))]
166+
test_reverse(dummy_eig_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔDtrunc, ΔVtrunc, [zero(real(T))]))
154167
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
155168
dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
156169
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -168,10 +181,18 @@ function copy_eigh_vals(A; kwargs...)
168181
eigh_vals(A; kwargs...)
169182
end
170183

171-
function copy_eigh_trunc!(A; kwargs...)
184+
function MatrixAlgebraKit.eigh_trunc!(A, DV, ϵ::Vector{T}, alg::MatrixAlgebraKit.TruncatedAlgorithm) where {T}
185+
D, V = eigh_full!(A, DV, alg.alg)
186+
DVtrunc, ind = MatrixAlgebraKit.truncate(eigh_trunc!, (D, V), alg.trunc)
187+
ϵ[1] = MatrixAlgebraKit.truncation_error!(diagview(D), ind)
188+
return DVtrunc..., ϵ
189+
end
190+
function dummy_eigh_trunc(A, ϵ::Vector{T}, alg::TruncatedAlgorithm) where {T}
172191
A = (A + A')/2
173-
DV = MatrixAlgebraKit.initialize_output(eigh_trunc!, A, kwargs[:alg])
174-
eigh_trunc!(A, DV; kwargs...)
192+
Ac = MatrixAlgebraKit.copy_input(MatrixAlgebraKit.eigh_trunc, A)
193+
DV = MatrixAlgebraKit.initialize_output(eigh_trunc!, A, alg)
194+
Dtrunc, Vtrunc, ϵ = MatrixAlgebraKit.eigh_trunc!(Ac, DV, ϵ, alg)
195+
return Dtrunc, Vtrunc, ϵ
175196
end
176197

177198
@timedtestset "EIGH AD Rules with eltype $T" for T in ETs
@@ -194,11 +215,11 @@ end
194215
@testset "forward: RT $RT, TA $TA" for RT in (Const, Duplicated,), TA in (Const, Duplicated,)
195216
test_forward(copy_eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol)
196217
end
197-
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
218+
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
198219
test_reverse(copy_eigh_full, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔD2), copy(ΔV)))
199220
test_reverse(copy_eigh_vals, RT, (copy(A), TA); fkwargs=(alg=alg,), atol=atol, rtol=rtol, output_tangent=copy(ΔD2.diag))
200221
end
201-
@testset "reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,)
222+
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
202223
for r in 1:4:m
203224
Ddiag = diagview(D)
204225
truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
@@ -207,8 +228,8 @@ end
207228
Vtrunc = V[:, ind]
208229
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
209230
ΔVtrunc = ΔV[:, ind]
210-
# broken right now due to Enzyme
211-
#test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔDtrunc), copy(ΔVtrunc), zero(real(T))))
231+
ϵ = [zero(real(T))]
232+
test_reverse(dummy_eigh_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔDtrunc, ΔVtrunc, [zero(real(T))]))
212233
dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
213234
dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
214235
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -220,15 +241,28 @@ end
220241
Vtrunc = V[:, ind]
221242
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
222243
ΔVtrunc = ΔV[:, ind]
223-
# broken right now due to Enzyme
224-
#test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(copy(ΔDtrunc), copy(ΔVtrunc), zero(real(T))))
244+
ϵ = [zero(real(T))]
245+
test_reverse(dummy_eigh_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔDtrunc, ΔVtrunc, [zero(real(T))]))
225246
dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
226247
dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
227248
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
228249
end
229250
end
230251
end
231252

253+
function MatrixAlgebraKit.svd_trunc!(A, USVᴴ, ϵ::Vector{T}, alg::MatrixAlgebraKit.TruncatedAlgorithm) where {T}
254+
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
255+
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
256+
ϵ[1] = MatrixAlgebraKit.truncation_error!(diagview(S), ind)
257+
return USVᴴtrunc..., ϵ
258+
end
259+
function dummy_svd_trunc(A, ϵ::Vector{T}, alg::TruncatedAlgorithm) where {T}
260+
Ac = MatrixAlgebraKit.copy_input(MatrixAlgebraKit.svd_trunc, A)
261+
USVᴴ = MatrixAlgebraKit.initialize_output(svd_trunc!, A, alg)
262+
Utrunc, Strunc, Vᴴtrunc, ϵ = MatrixAlgebraKit.svd_trunc!(Ac, USVᴴ, ϵ, alg)
263+
return Utrunc, Strunc, Vᴴtrunc, ϵ
264+
end
265+
232266
@timedtestset "SVD AD Rules with eltype $T" for T in ETs
233267
rng = StableRNG(12345)
234268
m = 19
@@ -239,7 +273,7 @@ end
239273
@testset for alg in (LAPACK_QRIteration(),
240274
LAPACK_DivideAndConquer(),
241275
)
242-
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
276+
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
243277
@testset "svd_compact" begin
244278
U, S, Vᴴ = svd_compact(A)
245279
ΔU = randn(rng, T, m, minmn)
@@ -250,7 +284,7 @@ end
250284
test_reverse(svd_compact, RT, (A, TA); atol=atol, rtol=rtol, fkwargs=(alg=alg,), output_tangent=(ΔU, ΔS, ΔVᴴ), fdm=fdm)
251285
end
252286
end
253-
@testset "reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,)
287+
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
254288
@testset "svd_trunc" begin
255289
for r in 1:4:minmn
256290
U, S, Vᴴ = svd_compact(A)
@@ -269,7 +303,8 @@ end
269303
ΔVᴴtrunc = ΔVᴴ[ind, :]
270304
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
271305
# broken due to Enzyme
272-
#test_reverse(svd_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm=fdm)
306+
ϵ = [zero(real(T))]
307+
test_reverse(dummy_svd_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]), fdm=fdm)
273308
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), copy(A), (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
274309
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), copy(A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
275310
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -290,7 +325,8 @@ end
290325
ΔVᴴtrunc = ΔVᴴ[ind, :]
291326
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range=1e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
292327
# broken due to Enzyme
293-
#test_reverse(svd_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm)
328+
ϵ = [zero(real(T))]
329+
test_reverse(dummy_svd_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]), fdm=fdm)
294330
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), copy(A), (copy(U), copy(S), copy(Vᴴ)), (copy(ΔUtrunc), copy(ΔStrunc), copy(ΔVᴴtrunc)), ind)
295331
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), copy(A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
296332
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)

0 commit comments

Comments
 (0)