Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ storage_type(op::AbstractLinearOperator) = error("please implement storage_type
storage_type(op::LinearOperator) = typeof(op.Mv5)
storage_type(M::AbstractMatrix{T}) where {T} = Vector{T}

# Lazy wrappers
storage_type(op::Adjoint) = storage_type(parent(op))
storage_type(op::Transpose) = storage_type(parent(op))
storage_type(op::Diagonal) = typeof(parent(op))

"""
reset!(op)

Expand Down
6 changes: 6 additions & 0 deletions test/gpu/amdgpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,11 @@ using LinearOperators, AMDGPU
y = M * v
@test y isa ROCArray{Float32}

@test LinearOperators.storage_type(A) == LinearOperators.storage_type(adjoint(A))
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(transpose(A))
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(adjoint(A))
@test LinearOperators.storage_type(Diagonal(v)) == typeof(v)


@testset "AMDGPU S kwarg" test_S_kwarg(arrayType = ROCArray)
end
8 changes: 7 additions & 1 deletion test/gpu/nvidia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,16 @@ using LinearOperators, CUDA, CUDA.CUSPARSE, CUDA.CUSOLVER
A = CUDA.rand(5, 5)
B = CUDA.rand(10, 10)
C = CUDA.rand(20, 20)
M = BlockDiagonalOperator(A, B, C)
M = BlockDiagonalOperator(A, B, C)

v = CUDA.rand(35)
y = M * v
@test y isa CuVector{Float32}

@test LinearOperators.storage_type(A) == LinearOperators.storage_type(adjoint(A))
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(transpose(A))
@test LinearOperators.storage_type(A) == LinearOperators.storage_type(adjoint(A))
@test LinearOperators.storage_type(Diagonal(v)) == typeof(v)

@testset "Nvidia S kwarg" test_S_kwarg(arrayType = CuArray)
end