Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ function LinearAlgebra.transpose!(B::AbstractGPUMatrix, A::AbstractGPUVector)
end
function LinearAlgebra.adjoint!(B::AbstractGPUVector, A::AbstractGPUMatrix)
axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("adjoint"))
isempty(A) && return B
@kernel function adjoint_kernel!(B, A)
idx = @index(Global, Linear)
@inbounds B[idx] = adjoint(A[1, idx])
Expand All @@ -23,6 +24,7 @@ function LinearAlgebra.adjoint!(B::AbstractGPUVector, A::AbstractGPUMatrix)
end
function LinearAlgebra.adjoint!(B::AbstractGPUMatrix, A::AbstractGPUVector)
axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("adjoint"))
isempty(A) && return B
@kernel function adjoint_kernel!(B, A)
idx = @index(Global, Linear)
@inbounds B[1, idx] = adjoint(A[idx])
Expand All @@ -35,6 +37,8 @@ LinearAlgebra.transpose!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(transpos
LinearAlgebra.adjoint!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(adjoint, B, A)
function transpose_f!(f, B::AnyGPUMatrix{T}, A::AnyGPUMatrix{T}) where T
axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) || throw(DimensionMismatch(string(f)))
# array with size zero dimension
isempty(A) && return B
@kernel function transpose_kernel!(B, A)
idx = @index(Global, Cartesian)
@inbounds B[idx[2], idx[1]] = f(A[idx[1], idx[2]])
Expand Down
4 changes: 4 additions & 0 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
@test compare(adjoint!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
@test compare(adjoint!, AT, rand(Float32, 1, 32), rand(Float32, 32))
@test compare(adjoint!, AT, rand(Float32, 32), rand(Float32, 1, 32))
@test compare(adjoint!, AT, rand(Float32, 32, 0), rand(Float32, 0, 32))
@test compare(adjoint!, AT, rand(Float32, 0, 32), rand(Float32, 32, 0))
@test compare(transpose, AT, rand(Float32, 32, 32))
@test compare(transpose!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
@test compare(transpose!, AT, rand(Float32, 1, 32), rand(Float32, 32))
@test compare(transpose!, AT, rand(Float32, 32), rand(Float32, 1, 32))
@test compare(transpose!, AT, rand(Float32, 32, 0), rand(Float32, 0, 32))
@test compare(transpose!, AT, rand(Float32, 0, 32), rand(Float32, 32, 0))
@test compare((x,y)->copyto!(x, adjoint(y)), AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
@test compare((x,y)->copyto!(x, transpose(y)), AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
@test compare(transpose!, AT, Array{Float32}(undef, 32, 32), rand(Float32, 32, 32))
Expand Down
Loading