Skip to content

Commit 5b9b22f

Browse files
committed
A few more updates for GPU compatibility for TensorKit
1 parent eceef30 commit 5b9b22f

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::Strid
127127
return nothing
128128
end
129129

130+
MatrixAlgebraKit.project_hermitian_native!(A::Diagonal{T, <:StridedROCVector}, B::Diagonal{T, <:StridedROCVector}, ::Val{anti}) where {T, anti} = MatrixAlgebraKit.project_hermitian_native!(ROCMatrix(A), ROCMatrix(B), Val(anti))
131+
130132
MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all(A .== adjoint(A))
131133
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
132134
all(A.diag .== adjoint(A.diag))

ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::Stride
151151
return nothing
152152
end
153153

154+
MatrixAlgebraKit.project_hermitian_native!(A::Diagonal{T, <:StridedCuVector}, B::Diagonal{T, <:StridedCuVector}, ::Val{anti}) where {T, anti} = MatrixAlgebraKit.project_hermitian_native!(CuMatrix(A), CuMatrix(B), Val(anti))
155+
154156
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) =
155157
all(A .== adjoint(A))
156158
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =

src/implementations/projections.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ copy_input(::typeof(project_isometric), A) = copy_input(left_polar, A)
99

1010
function check_input(::typeof(project_hermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
1111
LinearAlgebra.checksquare(A)
12-
n = Base.require_one_based_indexing(A)
12+
Base.require_one_based_indexing(A)
13+
n = size(A, 1)
1314
B === A || @check_size(B, (n, n))
1415
return nothing
1516
end
1617
function check_input(::typeof(project_antihermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
1718
LinearAlgebra.checksquare(A)
18-
n = Base.require_one_based_indexing(A)
19+
Base.require_one_based_indexing(A)
20+
n = size(A, 1)
1921
B === A || @check_size(B, (n, n))
2022
return nothing
2123
end

0 commit comments

Comments
 (0)