Skip to content

Commit 9698614

Browse files
committed
WIP Enzyme and Mooncake rules
1 parent 4fbc3bf commit 9698614

File tree

23 files changed

+2540
-88
lines changed

23 files changed

+2540
-88
lines changed

Project.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,27 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1111
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1212
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
13+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
14+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1315

1416
[extensions]
1517
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
1618
MatrixAlgebraKitAMDGPUExt = "AMDGPU"
1719
MatrixAlgebraKitCUDAExt = "CUDA"
20+
MatrixAlgebraKitEnzymeExt = "Enzyme"
21+
MatrixAlgebraKitMooncakeExt = "Mooncake"
1822

1923
[compat]
2024
AMDGPU = "2"
2125
Aqua = "0.6, 0.7, 0.8"
2226
ChainRulesCore = "1"
2327
ChainRulesTestUtils = "1"
2428
CUDA = "5"
29+
Enzyme = "0.13.77"
30+
EnzymeTestUtils = "0.2.3"
2531
JET = "0.9, 0.10"
2632
LinearAlgebra = "1"
33+
Mooncake = "0.4.167"
2734
SafeTestsets = "0.1"
2835
StableRNGs = "1"
2936
Test = "1"
@@ -34,12 +41,14 @@ julia = "1.10"
3441
[extras]
3542
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3643
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
44+
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
3745
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
46+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3847
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3948
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
4049
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4150
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
4251
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4352

4453
[targets]
45-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU"]
54+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "Enzyme", "EnzymeTestUtils", "Mooncake"]

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 713 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
module MatrixAlgebraKitMooncakeExt
2+
3+
using Mooncake
4+
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
5+
using MatrixAlgebraKit
6+
using MatrixAlgebraKit: inv_safe, diagview
7+
using MatrixAlgebraKit: svd_pullfwd!
8+
using MatrixAlgebraKit: qr_pullback!, lq_pullback!, qr_pullfwd!, lq_pullfwd!
9+
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!, qr_null_pullfwd!, lq_null_pullfwd!
10+
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_pullfwd!, eigh_pullfwd!
11+
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pullfwd!, right_polar_pullfwd!
12+
using LinearAlgebra
13+
14+
# two-argument factorizations like LQ, QR, EIG
15+
for (f, pb, pf, adj) in ((qr_full!, qr_pullback!, qr_pullfwd!, :dqr_adjoint),
16+
(qr_compact!, qr_pullback!, qr_pullfwd!, :dqr_adjoint),
17+
(lq_full!, lq_pullback!, lq_pullfwd!, :dlq_adjoint),
18+
(lq_compact!, lq_pullback!, lq_pullfwd!, :dlq_adjoint),
19+
(eig_full!, eig_pullback!, eig_pullfwd!, :deig_adjoint),
20+
(eigh_full!, eigh_pullback!, eigh_pullfwd!, :deigh_adjoint),
21+
(left_polar!, left_polar_pullback!, left_polar_pullfwd!, :dleft_polar_adjoint),
22+
(right_polar!, right_polar_pullback!, right_polar_pullfwd!, :dright_polar_adjoint),
23+
)
24+
25+
@eval begin
26+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
27+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
28+
A, dA = arrayify(A_dA)
29+
dA .= zero(eltype(A))
30+
args = Mooncake.primal(args_dargs)
31+
dargs = Mooncake.tangent(args_dargs)
32+
arg1, darg1 = arrayify(args[1], dargs[1])
33+
arg2, darg2 = arrayify(args[2], dargs[2])
34+
function $adj(::Mooncake.NoRData)
35+
dA = $pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...)
36+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
37+
end
38+
args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
39+
darg1 .= zero(eltype(arg1))
40+
darg2 .= zero(eltype(arg2))
41+
return Mooncake.CoDual(args, dargs), $adj
42+
end
43+
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
44+
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual{<:AbstractMatrix}, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
45+
A, dA = arrayify(A_dA)
46+
args = Mooncake.primal(args_dargs)
47+
args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
48+
dargs = Mooncake.tangent(args_dargs)
49+
arg1, darg1 = arrayify(args[1], dargs[1])
50+
arg2, darg2 = arrayify(args[2], dargs[2])
51+
darg1, darg2 = $pf(dA, A, (arg1, arg2), (darg1, darg2))
52+
dA .= zero(eltype(A))
53+
return Mooncake.Dual(args, dargs)
54+
end
55+
end
56+
end
57+
58+
for (f, pb, pf, adj) in ((qr_null!, qr_null_pullback!, qr_null_pullfwd!, :dqr_null_adjoint),
59+
(lq_null!, lq_null_pullback!, lq_null_pullfwd!, :dlq_null_adjoint),
60+
)
61+
@eval begin
62+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
63+
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, arg_darg::CoDual{<:AbstractMatrix}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
64+
A, dA = arrayify(A_dA)
65+
Ac = MatrixAlgebraKit.copy_input(lq_full, A)
66+
arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg))
67+
arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
68+
function $adj(::Mooncake.NoRData)
69+
dA .= zero(eltype(A))
70+
$pb(dA, A, arg, darg; kwargs...)
71+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
72+
end
73+
return arg_darg, $adj
74+
end
75+
#forward mode not implemented yet
76+
end
77+
end
78+
79+
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
80+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
81+
function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...)
82+
# compute primal
83+
D_ = Mooncake.primal(D_dD)
84+
dD_ = Mooncake.tangent(D_dD)
85+
A_ = Mooncake.primal(A_dA)
86+
dA_ = Mooncake.tangent(A_dA)
87+
A, dA = arrayify(A_, dA_)
88+
D, dD = arrayify(D_, dD_)
89+
nD, V = eig_full(A, alg_dalg.primal; kwargs...)
90+
91+
# update tangent
92+
tmp = V \ dA
93+
dD .= diagview(tmp * V)
94+
dA .= zero(eltype(dA))
95+
return Mooncake.Dual(nD.diag, dD_)
96+
end
97+
98+
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
99+
# compute primal
100+
D_ = Mooncake.primal(D_dD)
101+
dD_ = Mooncake.tangent(D_dD)
102+
A_ = Mooncake.primal(A_dA)
103+
dA_ = Mooncake.tangent(A_dA)
104+
A, dA = arrayify(A_, dA_)
105+
D, dD = arrayify(D_, dD_)
106+
dA .= zero(eltype(dA))
107+
# update primal
108+
DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
109+
V = DV[2]
110+
dD .= zero(eltype(D))
111+
function deig_vals_adjoint(::Mooncake.NoRData)
112+
PΔV = V' \ Diagonal(dD)
113+
if eltype(dA) <: Real
114+
ΔAc = PΔV * V'
115+
dA .+= real.(ΔAc)
116+
else
117+
mul!(dA, PΔV, V', 1, 0)
118+
end
119+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
120+
end
121+
return Mooncake.CoDual(DV[1].diag, dD_), deig_vals_adjoint
122+
end
123+
#=
124+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_full!), AbstractMatrix, Tuple{<:Diagonal, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
125+
function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.eigh_full!)}, A_dA::CoDual{<:AbstractMatrix}, DV_dDV::CoDual{<:Tuple{<:Diagonal, <:AbstractMatrix}}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
126+
A, dA = arrayify(A_dA)
127+
dA .= zero(eltype(A))
128+
DV = Mooncake.primal(DV_dDV)
129+
dDV = Mooncake.tangent(DV_dDV)
130+
D, dD = arrayify(DV[1], dDV[1])
131+
V, dV = arrayify(DV[2], dDV[2])
132+
function deigh_adjoint(::Mooncake.NoRData)
133+
dA = MatrixAlgebraKit.eigh_pullback!(dA, A, (D, V), (dD, dV); kwargs...)
134+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
135+
end
136+
DV = eigh_full!(A, DV, Mooncake.primal(alg_dalg); kwargs...)
137+
return Mooncake.CoDual(DV, dDV), deigh_adjoint
138+
end
139+
140+
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.eigh_full!), AbstractMatrix, Tuple{<:Diagonal, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
141+
function Mooncake.frule!!(::Dual{typeof(MatrixAlgebraKit.eigh_full!)}, A_dA::Dual, DV_dDV::Dual, alg_dalg::Dual; kwargs...)
142+
A, dA = arrayify(A_dA)
143+
DV = Mooncake.primal(DV_dDV)
144+
dDV = Mooncake.tangent(DV_dDV)
145+
D, dD = arrayify(DV[1], dDV[1])
146+
V, dV = arrayify(DV[2], dDV[2])
147+
(D, V) = eigh_full!(A, DV, Mooncake.primal(alg_dalg); kwargs...)
148+
(dD, dV) = eigh_pullfwd!(dA, A, (D, V), (dD, dV); kwargs...)
149+
return Mooncake.Dual(DV, dDV)
150+
end
151+
=#
152+
153+
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
154+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
155+
function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual; kwargs...)
156+
# compute primal
157+
D_ = Mooncake.primal(D_dD)
158+
dD_ = Mooncake.tangent(D_dD)
159+
A_ = Mooncake.primal(A_dA)
160+
dA_ = Mooncake.tangent(A_dA)
161+
A, dA = arrayify(A_, dA_)
162+
D, dD = arrayify(D_, dD_)
163+
nD, V = eigh_full(A, alg_dalg.primal; kwargs...)
164+
# update tangent
165+
tmp = inv(V) * dA * V
166+
dD .= real.(diagview(tmp))
167+
D .= nD.diag
168+
dA .= zero(eltype(dA))
169+
return D_dD
170+
end
171+
172+
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
173+
# compute primal
174+
D_ = Mooncake.primal(D_dD)
175+
dD_ = Mooncake.tangent(D_dD)
176+
A_ = Mooncake.primal(A_dA)
177+
dA_ = Mooncake.tangent(A_dA)
178+
A, dA = arrayify(A_, dA_)
179+
D, dD = arrayify(D_, dD_)
180+
DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
181+
function deigh_vals_adjoint(::Mooncake.NoRData)
182+
mul!(dA, DV[2] * Diagonal(real(dD)), DV[2]', 1, 0)
183+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
184+
end
185+
return Mooncake.CoDual(DV[1].diag, dD_), deigh_vals_adjoint
186+
end
187+
188+
189+
for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal))
190+
@eval begin
191+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
192+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual; kwargs...)
193+
A, dA = arrayify(A_dA)
194+
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
195+
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
196+
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
197+
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
198+
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
199+
USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...)
200+
function dsvd_adjoint(::Mooncake.NoRData)
201+
dA .= zero(eltype(A))
202+
minmn = min(size(A)...)
203+
if size(U, 2) == size(Vᴴ, 1) == minmn # compact
204+
dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
205+
else # full
206+
vU = view(U, :, 1:minmn)
207+
vS = Diagonal(diagview(S)[1:minmn])
208+
vVᴴ = view(Vᴴ, 1:minmn, :)
209+
vdU = view(dU, :, 1:minmn)
210+
vdS = Diagonal(diagview(dS)[1:minmn])
211+
vdVᴴ = view(dVᴴ, 1:minmn, :)
212+
dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (vdU, vdS, vdVᴴ))
213+
end
214+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
215+
end
216+
return Mooncake.CoDual(USVᴴ, dUSVᴴ), dsvd_adjoint
217+
end
218+
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
219+
function Mooncake.frule!!(::Dual{<:typeof($f)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual; kwargs...)
220+
# compute primal
221+
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
222+
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
223+
A_ = Mooncake.primal(A_dA)
224+
dA_ = Mooncake.tangent(A_dA)
225+
A, dA = arrayify(A_, dA_)
226+
$f(A, USVᴴ, alg_dalg.primal; kwargs...)
227+
228+
# update tangents
229+
U_, S_, Vᴴ_ = USVᴴ
230+
dU_, dS_, dVᴴ_ = dUSVᴴ
231+
U, dU = arrayify(U_, dU_)
232+
S, dS = arrayify(S_, dS_)
233+
Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_)
234+
(dU, dS, dVᴴ) = svd_pullfwd!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ); kwargs...)
235+
return USVᴴ_dUSVᴴ
236+
end
237+
end
238+
end
239+
240+
@is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
241+
function Mooncake.frule!!(::Dual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual; kwargs...)
242+
# compute primal
243+
S_ = Mooncake.primal(S_dS)
244+
dS_ = Mooncake.tangent(S_dS)
245+
A_ = Mooncake.primal(A_dA)
246+
dA_ = Mooncake.tangent(A_dA)
247+
A, dA = arrayify(A_, dA_)
248+
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
249+
250+
# update tangent
251+
S, dS = arrayify(S_, dS_)
252+
copyto!(dS, diag(real.(Vᴴ * dA' * U)))
253+
copyto!(S, diagview(nS))
254+
dA .= zero(eltype(dA))
255+
return Mooncake.Dual(nS.diag, dS)
256+
end
257+
258+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
259+
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...)
260+
# compute primal
261+
S_ = Mooncake.primal(S_dS)
262+
dS_ = Mooncake.tangent(S_dS)
263+
A_ = Mooncake.primal(A_dA)
264+
dA_ = Mooncake.tangent(A_dA)
265+
A, dA = arrayify(A_, dA_)
266+
S, dS = arrayify(S_, dS_)
267+
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
268+
S .= diagview(nS)
269+
dS .= zero(eltype(S))
270+
function dsvd_vals_adjoint(::Mooncake.NoRData)
271+
dA .= U * Diagonal(dS) * Vᴴ
272+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
273+
end
274+
return S_dS, dsvd_vals_adjoint
275+
end
276+
277+
end

src/MatrixAlgebraKit.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,11 @@ include("pullbacks/eigh.jl")
111111
include("pullbacks/svd.jl")
112112
include("pullbacks/polar.jl")
113113

114+
include("pullfwds/qr.jl")
115+
include("pullfwds/lq.jl")
116+
include("pullfwds/eig.jl")
117+
include("pullfwds/eigh.jl")
118+
include("pullfwds/polar.jl")
119+
include("pullfwds/svd.jl")
120+
114121
end

src/common/view.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# diagind: provided by LinearAlgebra.jl
2-
diagview(D::Diagonal) = D.diag
2+
diagview(D::Diagonal) = D.diag
33
diagview(D::AbstractMatrix) = view(D, diagind(D))
44

55
# triangularind

src/implementations/eigh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real =
1919
end
2020

2121
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm)
22-
check_hermitian(A, alg)
22+
#check_hermitian(A, alg)
2323
D, V = DV
2424
m = size(A, 1)
2525
@assert D isa Diagonal && V isa AbstractMatrix

0 commit comments

Comments
 (0)