Skip to content

Commit efcb33e

Browse files
committed
Fix implementations of eigen and eigvals
1 parent d813023 commit efcb33e

File tree

4 files changed

+150
-67
lines changed

4 files changed

+150
-67
lines changed

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,28 @@ end
2424
@inline static_dual_eval(::Type{T}, f::F, x::StaticArray) where {T,F} = f(dualize(T, x))
2525

2626
# To fix method ambiguity issues:
27-
function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
28-
return ForwardDiff._eigvals(A)
27+
function LinearAlgebra.eigvals(A::Symmetric{Dual{T,V,N}, <:StaticArrays.StaticMatrix}) where {T,V<:Real,N}
28+
return ForwardDiff._eigvals_hermitian(A)
2929
end
30-
function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
31-
return ForwardDiff._eigen(A)
30+
function LinearAlgebra.eigvals(A::Hermitian{Dual{T,V,N}, <:StaticArrays.StaticMatrix}) where {T,V<:Real,N}
31+
return ForwardDiff._eigvals_hermitian(A)
32+
end
33+
function LinearAlgebra.eigvals(A::Hermitian{Complex{Dual{T,V,N}}, <:StaticArrays.StaticMatrix}) where {T,V<:Real,N}
34+
return ForwardDiff._eigvals_hermitian(A)
35+
end
36+
37+
function LinearAlgebra.eigen(A::Symmetric{Dual{T,V,N}, <:StaticArrays.StaticMatrix}) where {T,V<:Real,N}
38+
return ForwardDiff._eigen_hermitian(A)
39+
end
40+
function LinearAlgebra.eigen(A::Hermitian{Dual{T,V,N}, <:StaticArrays.StaticMatrix}) where {T,V<:Real,N}
41+
return ForwardDiff._eigen_hermitian(A)
42+
end
43+
function LinearAlgebra.eigen(A::Hermitian{Complex{Dual{T,V,N}}, <:StaticArrays.StaticMatrix}) where {T,V<:Real,N}
44+
return ForwardDiff._eigen_hermitian(A)
3245
end
3346

3447
# For `MMatrix` we can use the in-place method
35-
ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDiff._lyap_div!(A, λ)
48+
ForwardDiff._lyap_div_zero_diag!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDiff._lyap_div_zero_diag!(A, λ)
3649

3750
# Gradient
3851
@inline ForwardDiff.gradient(f::F, x::StaticArray) where {F} = vector_mode_gradient(f, x)

src/dual.jl

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -735,68 +735,86 @@ end
735735
return (Dual{T}(sd, cd * π * partials(d)), Dual{T}(cd, -sd * π * partials(d)))
736736
end
737737

738-
# Symmetric eigvals #
738+
# eigen values and vectors of Hermitian matrices #
739739
#-------------------#
740740

741-
# To be able to reuse this default definition in the StaticArrays extension
742-
# (has to be re-defined to avoid method ambiguity issues)
743-
# we forward the call to an internal method that can be shared and reused
744-
LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N} = _eigvals(A)
745-
function _eigvals(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
746-
λ,Q = eigen(Symmetric(value.(parent(A))))
747-
parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N)
748-
Dual{Tg}.(λ, tuple.(parts...))
741+
# Extract structured matrices of primal values and partials
742+
_structured_value(A::Symmetric{Dual{T,V,N}}) where {T,V,N} = Symmetric(map(value, parent(A)), A.uplo === 'U' ? :U : :L)
743+
_structured_value(A::Hermitian{Dual{T,V,N}}) where {T,V,N} = Hermitian(map(value, parent(A)), A.uplo === 'U' ? :U : :L)
744+
_structured_value(A::Hermitian{Complex{Dual{T,V,N}}}) where {T,V,N} = Hermitian(map(z -> splat(complex)(map(value, reim(z))), parent(A)), A.uplo === 'U' ? :U : :L)
745+
_structured_value(A::SymTridiagonal{Dual{T,V,N}}) where {T,V,N} = SymTridiagonal(map(value, A.dv), map(value, A.ev))
746+
747+
_structured_partials(A::Symmetric{Dual{T,V,N}}, j::Int) where {T,V,N} = Symmetric(partials.(parent(A), j), A.uplo === 'U' ? :U : :L)
748+
_structured_partials(A::Hermitian{Dual{T,V,N}}, j::Int) where {T,V,N} = Hermitian(partials.(parent(A), j), A.uplo === 'U' ? :U : :L)
749+
function _structured_partials(A::Hermitian{Complex{Dual{T,V,N}}}, j::Int) where {T,V,N}
750+
return Hermitian(complex.(partials.(real.(parent(A)), j), partials.(imag.(parent(A)), j)), A.uplo === 'U' ? :U : :L)
749751
end
752+
_structured_partials(A::SymTridiagonal{Dual{T,V,N}}, j::Int) where {T,V,N} = SymTridiagonal(partials.(A.dv, j), partials.(A.ev, j))
750753

751-
function LinearAlgebra.eigvals(A::Hermitian{<:Complex{<:Dual{Tg,T,N}}}) where {Tg,T<:Real,N}
752-
λ,Q = eigen(Hermitian(value.(real.(parent(A))) .+ im .* value.(imag.(parent(A)))))
753-
parts = ntuple(j -> diag(real.(Q' * (getindex.(partials.(real.(A)) .+ im .* partials.(imag.(A)), j)) * Q)), N)
754-
Dual{Tg}.(λ, tuple.(parts...))
754+
# Convert arrays of primal values and partials to arrays of Duals
755+
function _to_duals(::Val{T}, values::AbstractArray{<:Real}, partials::Tuple{Vararg{AbstractArray{<:Real}}}) where {T}
756+
return Dual{T}.(values, tuple.(partials...))
757+
end
758+
function _to_duals(::Val{T}, values::AbstractArray{<:Complex}, partials::Tuple{Vararg{AbstractArray{<:Complex}}}) where {T}
759+
return complex.(
760+
Dual{T}.(real.(values), Base.Fix1(map, real).(tuple.(partials...))),
761+
Dual{T}.(imag.(values), Base.Fix1(map, imag).(tuple.(partials...))),
762+
)
755763
end
756764

757-
function LinearAlgebra.eigvals(A::SymTridiagonal{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
758-
λ,Q = eigen(SymTridiagonal(value.(parent(A).dv),value.(parent(A).ev)))
759-
parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N)
760-
Dual{Tg}.(λ, tuple.(parts...))
765+
# We forward the call to an internal method that can be shared and reused
766+
LinearAlgebra.eigvals(A::Symmetric{Dual{T,V,N}}) where {T,V<:Real,N} = _eigvals_hermitian(A)
767+
LinearAlgebra.eigvals(A::Hermitian{Dual{T,V,N}}) where {T,V<:Real,N} = _eigvals_hermitian(A)
768+
LinearAlgebra.eigvals(A::Hermitian{Complex{Dual{T,V,N}}}) where {T,V<:Real,N} = _eigvals_hermitian(A)
769+
LinearAlgebra.eigvals(A::SymTridiagonal{Dual{T,V,N}}) where {T,V<:Real,N} = _eigvals_hermitian(A)
770+
771+
# Eigenvalues of Hermitian-structured matrices
772+
const DualMatrixRealComplex{T,V<:Real,N} = Union{AbstractMatrix{Dual{T,V,N}}, AbstractMatrix{Complex{Dual{T,V,N}}}}
773+
function _eigvals_hermitian(A::DualMatrixRealComplex{T,<:Real,N}) where {T,N}
774+
F = eigen(_structured_value(A))
775+
λ = F.values
776+
Q = F.vectors
777+
parts = ntuple(j -> real(diag(Q' * _structured_partials(A, j) * Q)), N)
778+
return _to_duals(Val(T), λ, parts)
761779
end
762780

763-
# A ./ (λ' .- λ) but with diag special cased
781+
# A ./ (λ' .- λ) but with diagonal elements zeroed out
764782
# Default out-of-place method
765-
function _lyap_div!!(A::AbstractMatrix, λ::AbstractVector)
783+
function _lyap_div_zero_diag!!(A::AbstractMatrix, λ::AbstractVector)
766784
return map(
767-
(a, b, idx) -> a / (idx[1] == idx[2] ? oneunit(b) : b),
785+
(a, b, idx) -> idx[1] == idx[2] ? zero(a) / oneunit(b) : a / b,
768786
A,
769787
λ' .- λ,
770788
CartesianIndices(A),
771789
)
772790
end
773791
# For `Matrix` (and e.g. `StaticArrays.MMatrix`) we can use an in-place method
774-
_lyap_div!!(A::Matrix, λ::AbstractVector) = _lyap_div!(A, λ)
775-
function _lyap_div!(A::AbstractMatrix, λ::AbstractVector)
792+
_lyap_div_zero_diag!!(A::Matrix, λ::AbstractVector) = _lyap_div_zero_diag!(A, λ)
793+
function _lyap_div_zero_diag!(A::AbstractMatrix, λ::AbstractVector)
776794
for (j,μ) in enumerate(λ), (k,λ) in enumerate(λ)
777-
if k ≠ j
795+
if k == j
796+
A[k, j] = zero(A[k, j])
797+
else
778798
A[k,j] /= μ - λ
779799
end
780800
end
781801
A
782802
end
783803

784-
# To be able to reuse this default definition in the StaticArrays extension
785-
# (has to be re-defined to avoid method ambiguity issues)
786-
# we forward the call to an internal method that can be shared and reused
787-
LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N} = _eigen(A)
788-
function _eigen(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
789-
λ = eigvals(A)
790-
_,Q = eigen(Symmetric(value.(parent(A))))
791-
parts = ntuple(j -> Q*_lyap_div!!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
792-
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
793-
end
794-
795-
function LinearAlgebra.eigen(A::SymTridiagonal{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N}
796-
λ = eigvals(A)
797-
_,Q = eigen(SymTridiagonal(value.(parent(A))))
798-
parts = ntuple(j -> Q*_lyap_div!!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N)
799-
Eigen(λ,Dual{Tg}.(Q, tuple.(parts...)))
804+
# We forward the call to an internal method that can be shared and reused
805+
LinearAlgebra.eigen(A::Symmetric{Dual{T,V,N}}) where {T,V<:Real,N} = _eigen_hermitian(A)
806+
LinearAlgebra.eigen(A::Hermitian{Dual{T,V,N}}) where {T,V<:Real,N} = _eigen_hermitian(A)
807+
LinearAlgebra.eigen(A::Hermitian{Complex{Dual{T,V,N}}}) where {T,V<:Real,N} = _eigen_hermitian(A)
808+
LinearAlgebra.eigen(A::SymTridiagonal{Dual{T,V,N}}) where {T,V<:Real,N} = _eigen_hermitian(A)
809+
810+
function _eigen_hermitian(A::DualMatrixRealComplex{T,<:Real,N}) where {T,N}
811+
F = eigen(_structured_value(A))
812+
λ = F.values
813+
Q = F.vectors
814+
Qt_∂A_Q = ntuple(j -> Q' * _structured_partials(A, j) * Q, N)
815+
λ_partials = map(real ∘ diag, Qt_∂A_Q)
816+
Q_partials = map(Qt_∂Aj_Q -> Q*_lyap_div_zero_diag!!(Qt_∂Aj_Q, λ), Qt_∂A_Q)
817+
return Eigen(_to_duals(Val(T), λ, λ_partials), _to_duals(Val(T), Q, Q_partials))
800818
end
801819

802820
# Functions in SpecialFunctions which return tuples #

test/JacobianTest.jl

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -238,22 +238,74 @@ end
238238
end
239239

240240
@testset "eigen" begin
241-
@test ForwardDiff.jacobian(x -> eigvals(SymTridiagonal(x, x[1:end-1])), [1.,2.]) ≈ [(1 - 3/sqrt(5))/2 (1 - 1/sqrt(5))/2 ; (1 + 3/sqrt(5))/2 (1 + 1/sqrt(5))/2]
242-
@test ForwardDiff.jacobian(x -> eigvals(Symmetric(x*x')), [1.,2.]) ≈ [0 0; 2 4]
243-
244-
x0 = [1.0, 2.0];
245-
ev1(x) = eigen(Symmetric(x*x')).vectors[:,1]
246-
@test ForwardDiff.jacobian(ev1, x0) ≈ Calculus.finite_difference_jacobian(ev1, x0)
247-
ev2(x) = eigen(SymTridiagonal(x, x[1:end-1])).vectors[:,1]
248-
@test ForwardDiff.jacobian(ev2, x0) ≈ Calculus.finite_difference_jacobian(ev2, x0)
249-
250-
x0_svector = SVector{2}(x0)
251-
@test ForwardDiff.jacobian(ev1, x0_svector) isa SMatrix{2, 2}
252-
@test ForwardDiff.jacobian(ev1, x0_svector) ≈ Calculus.finite_difference_jacobian(ev1, x0)
253-
254-
x0_mvector = MVector{2}(x0)
255-
@test ForwardDiff.jacobian(ev1, x0_mvector) isa MMatrix{2, 2}
256-
@test ForwardDiff.jacobian(ev1, x0_mvector) ≈ Calculus.finite_difference_jacobian(ev1, x0)
241+
eigvals_symreal(x) = eigvals(Symmetric(x*x'))
242+
eigvals_hermreal(x) = eigvals(Hermitian(x*x'))
243+
eigvals_hermcomplex(x) = map(abs2, eigvals(Hermitian(complex.(x*x', x'*x))))
244+
eigvals_symtridiag(x) = eigvals(SymTridiagonal(x, x[begin:(end - 1)]))
245+
246+
eigen_vals_symreal(x) = eigen(Symmetric(x*x')).values
247+
eigen_vals_hermreal(x) = eigen(Hermitian(x*x')).values
248+
eigen_vals_hermcomplex(x) = map(abs2, eigen(Hermitian(complex.(x*x', x'*x))).values)
249+
eigen_vals_symtridiag(x) = eigen(SymTridiagonal(x, x[begin:(end - 1)])).values
250+
251+
eigen_vec1_symreal(x) = eigen(Symmetric(x*x')).vectors[:,1]
252+
eigen_vec1_hermreal(x) = eigen(Hermitian(x*x')).vectors[:,1]
253+
eigen_vec1_hermcomplex(x) = map(abs2, eigen(Hermitian(complex.(x*x', x'*x))).vectors[:,1])
254+
eigen_vec1_symtridiag(x) = eigen(SymTridiagonal(x, x[begin:(end - 1)])).vectors[:,1]
255+
256+
# Note: SymTridiagonal is not supported for StaticArrays
257+
for T in (Int, Float32, Float64), A in (Array, SArray, MArray)
258+
if A <: StaticArrays.StaticArray
259+
x = A{Tuple{2},T}(x0)
260+
JT = A{Tuple{2,2},float(T)}
261+
else
262+
x = A{T,1}(x0)
263+
JT = A{float(T),2}
264+
end
265+
266+
# analytic solutions
267+
@test ForwardDiff.jacobian(eigvals_symreal, x) ≈ [0 0; 2 4]
268+
@test ForwardDiff.jacobian(eigvals_hermreal, x) ≈ [0 0; 2 4]
269+
if !(x isa StaticArrays.StaticArray)
270+
@test ForwardDiff.jacobian(eigvals_symtridiag, x) ≈ [(1 - 3/sqrt(5))/2 (1 - 1/sqrt(5))/2 ; (1 + 3/sqrt(5))/2 (1 + 1/sqrt(5))/2]
271+
end
272+
273+
# eigen + eigvals
274+
for ev in (
275+
eigvals_symreal, eigvals_hermreal, eigvals_hermcomplex, eigvals_symtridiag,
276+
eigen_vals_symreal, eigen_vals_hermreal, eigen_vals_hermcomplex, eigen_vals_symtridiag,
277+
eigen_vec1_symreal, eigen_vec1_hermreal, eigen_vec1_hermcomplex, eigen_vec1_symtridiag,
278+
)
279+
if x isa StaticArrays.StaticArray &&
280+
(ev === eigvals_symtridiag || ev === eigen_vals_symtridiag || ev === eigen_vec1_symtridiag)
281+
continue
282+
end
283+
284+
# Chunk size can only be inferred for static arrays
285+
if x isa StaticArrays.StaticArray
286+
@test @inferred(ForwardDiff.jacobian(ev, x)) isa JT
287+
else
288+
@test ForwardDiff.jacobian(ev, x) isa JT
289+
end
290+
cfg = ForwardDiff.JacobianConfig(ev, x)
291+
@test @inferred(ForwardDiff.jacobian(ev, x, cfg)) isa JT
292+
293+
@test ForwardDiff.jacobian(ev, x) ≈ Calculus.finite_difference_jacobian(ev, float.(x0))
294+
end
295+
296+
# consistency of eigen and eigvals
297+
for (eigvals, eigen_vals) in (
298+
(eigvals_symreal, eigen_vals_symreal),
299+
(eigvals_hermreal, eigen_vals_hermreal),
300+
(eigvals_hermcomplex, eigen_vals_hermcomplex),
301+
(eigvals_symtridiag, eigen_vals_symtridiag),
302+
)
303+
if x isa StaticArrays.StaticArray && eigvals === eigvals_symtridiag
304+
continue
305+
end
306+
@test ForwardDiff.jacobian(eigvals, x) ≈ ForwardDiff.jacobian(eigen_vals, x)
307+
end
308+
end
257309
end
258310

259311
@testset "type stability" begin

test/MiscTest.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,39 +182,39 @@ end
182182
# example from https://github.com/JuliaDiff/DiffRules.jl/pull/98#issuecomment-1574420052
183183
@test only(ForwardDiff.hessian(t -> abs(t[1])^2, [0.0])) == 2
184184

185-
@testset "_lyap_div!!" begin
185+
@testset "_lyap_div_zero_diag!!" begin
186186
# In-place version for `Matrix`
187187
A = rand(3, 3)
188188
Acopy = copy(A)
189189
λ = rand(3)
190-
B = @inferred(ForwardDiff._lyap_div!!(A, λ))
190+
B = @inferred(ForwardDiff._lyap_div_zero_diag!!(A, λ))
191191
@test B === A
192-
@test B[diagind(B)] == Acopy[diagind(Acopy)]
192+
@test iszero(B[diagind(B)])
193193
no_diag(X) = [X[i] for i in eachindex(X) if !(i in diagind(X))]
194194
@test no_diag(B) == no_diag(Acopy ./ (λ' .- λ))
195195

196196
# Immutable static arrays
197197
A_smatrix = SMatrix{3,3}(Acopy)
198198
λ_svector = SVector{3}(λ)
199-
B_smatrix = @inferred(ForwardDiff._lyap_div!!(A_smatrix, λ_svector))
199+
B_smatrix = @inferred(ForwardDiff._lyap_div_zero_diag!!(A_smatrix, λ_svector))
200200
@test B_smatrix !== A_smatrix
201201
@test B_smatrix isa SMatrix{3,3}
202202
@test B_smatrix == B
203203
λ_mvector = MVector{3}(λ)
204-
B_smatrix = @inferred(ForwardDiff._lyap_div!!(A_smatrix, λ_mvector))
204+
B_smatrix = @inferred(ForwardDiff._lyap_div_zero_diag!!(A_smatrix, λ_mvector))
205205
@test B_smatrix !== A_smatrix
206206
@test B_smatrix isa SMatrix{3,3}
207207
@test B_smatrix == B
208208

209209
# Mutable static arrays
210210
A_mmatrix = MMatrix{3,3}(Acopy)
211211
λ_svector = SVector{3}(λ)
212-
B_mmatrix = @inferred(ForwardDiff._lyap_div!!(A_mmatrix, λ_svector))
212+
B_mmatrix = @inferred(ForwardDiff._lyap_div_zero_diag!!(A_mmatrix, λ_svector))
213213
@test B_mmatrix === A_mmatrix
214214
@test B_mmatrix == B
215215
A_mmatrix = MMatrix{3,3}(Acopy)
216216
λ_mvector = MVector{3}(λ)
217-
B_mmatrix = @inferred(ForwardDiff._lyap_div!!(A_mmatrix, λ_mvector))
217+
B_mmatrix = @inferred(ForwardDiff._lyap_div_zero_diag!!(A_mmatrix, λ_mvector))
218218
@test B_mmatrix === A_mmatrix
219219
@test B_mmatrix == B
220220
end

0 commit comments

Comments
 (0)