Skip to content

Commit 8d569fc

Browse files
authored
Graceful exit for transpose/adjoint on size 0 arrays (#613)
* Graceful exit for transpose/adjoint on size 0 arrays * Fix for CUDA * Use length rather than axes * use isempty
1 parent c3cd545 commit 8d569fc

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

src/host/linalg.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ function LinearAlgebra.transpose!(B::AbstractGPUMatrix, A::AbstractGPUVector)
1414
end
1515
function LinearAlgebra.adjoint!(B::AbstractGPUVector, A::AbstractGPUMatrix)
1616
axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("adjoint"))
17+
isempty(A) && return B
1718
@kernel function adjoint_kernel!(B, A)
1819
idx = @index(Global, Linear)
1920
@inbounds B[idx] = adjoint(A[1, idx])
@@ -23,6 +24,7 @@ function LinearAlgebra.adjoint!(B::AbstractGPUVector, A::AbstractGPUMatrix)
2324
end
2425
function LinearAlgebra.adjoint!(B::AbstractGPUMatrix, A::AbstractGPUVector)
2526
axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("adjoint"))
27+
isempty(A) && return B
2628
@kernel function adjoint_kernel!(B, A)
2729
idx = @index(Global, Linear)
2830
@inbounds B[1, idx] = adjoint(A[idx])
@@ -35,6 +37,8 @@ LinearAlgebra.transpose!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(transpos
3537
LinearAlgebra.adjoint!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(adjoint, B, A)
3638
function transpose_f!(f, B::AnyGPUMatrix{T}, A::AnyGPUMatrix{T}) where T
3739
axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) || throw(DimensionMismatch(string(f)))
40+
# array with size zero dimension
41+
isempty(A) && return B
3842
@kernel function transpose_kernel!(B, A)
3943
idx = @index(Global, Cartesian)
4044
@inbounds B[idx[2], idx[1]] = f(A[idx[1], idx[2]])

test/testsuite/linalg.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
@test compare(adjoint!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
55
@test compare(adjoint!, AT, rand(Float32, 1, 32), rand(Float32, 32))
66
@test compare(adjoint!, AT, rand(Float32, 32), rand(Float32, 1, 32))
7+
@test compare(adjoint!, AT, rand(Float32, 32, 0), rand(Float32, 0, 32))
8+
@test compare(adjoint!, AT, rand(Float32, 0, 32), rand(Float32, 32, 0))
79
@test compare(transpose, AT, rand(Float32, 32, 32))
810
@test compare(transpose!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
911
@test compare(transpose!, AT, rand(Float32, 1, 32), rand(Float32, 32))
1012
@test compare(transpose!, AT, rand(Float32, 32), rand(Float32, 1, 32))
13+
@test compare(transpose!, AT, rand(Float32, 32, 0), rand(Float32, 0, 32))
14+
@test compare(transpose!, AT, rand(Float32, 0, 32), rand(Float32, 32, 0))
1115
@test compare((x,y)->copyto!(x, adjoint(y)), AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
1216
@test compare((x,y)->copyto!(x, transpose(y)), AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
1317
@test compare(transpose!, AT, Array{Float32}(undef, 32, 32), rand(Float32, 32, 32))

0 commit comments

Comments
 (0)