Commit a55a95c

Respond to comments

1 parent 1ba51e9 commit a55a95c
File tree

3 files changed: +114 −39 lines


README.md

Lines changed: 58 additions & 0 deletions

@@ -31,3 +31,61 @@

types: It provides functionality and tooling to speed-up development of new GPU array types.
**This package is not intended for end users!** Instead, you should use one of the packages
that builds on GPUArrays.jl, such as [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl), [oneAPI.jl](https://github.com/JuliaGPU/oneAPI.jl), [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl), or [Metal.jl](https://github.com/JuliaGPU/Metal.jl).
## Interface methods

To support a new GPU backend, you will need to implement various interface methods for your backend's array types.
Some (CPU-based) examples can be seen in the testing library `JLArrays` (located in the `lib` directory of this package).

### Dense array support

### Sparse array support (optional)

`GPUArrays.jl` provides **device-side** array types for `CSC`, `CSR`, `COO`, and `BSR` matrices, as well as sparse vectors.
It also provides abstract types for these layouts that you can create concrete child types of in order to benefit from the
backend-agnostic wrappers. In particular, `GPUArrays.jl` provides out-of-the-box support for broadcasting and `mapreduce` over
GPU sparse arrays.
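
In practice this means a conforming backend gets broadcasting and `mapreduce` on its sparse types without backend-specific code. A hypothetical usage sketch with the CPU-backed `JLArrays` test backend (the `JLSparseMatrixCSC(::SparseMatrixCSC)` constructor is an assumption here; consult `lib/JLArrays` for the exact API):

```julia
using SparseArrays
using JLArrays  # CPU-backed reference backend from this repository's `lib` directory

# Assumed constructor: wrap a host SparseMatrixCSC into the backend's sparse type.
A = JLSparseMatrixCSC(sprand(Float32, 4, 4, 0.5))

B = 2 .* A                # broadcasting via the backend-agnostic wrappers
s = mapreduce(abs, +, A)  # likewise mapreduce, with no backend-specific code
```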

For **host-side** types, your custom sparse types should implement:

- `dense_array_type` - the corresponding dense array type. For example, for a `CuSparseVector` or `CuSparseMatrixCXX`, the `dense_array_type` is `CuArray`
- `sparse_array_type` - the **untyped** sparse array type corresponding to a given parametrized type. A `CuSparseVector{Tv, Ti}` would have a `sparse_array_type` of `CuSparseVector` -- note the lack of type parameters!
- `csc_type(::Type{T})` - the compressed sparse column type for your backend. A `CuSparseMatrixCSR` would have a `csc_type` of `CuSparseMatrixCSC`.
- `csr_type(::Type{T})` - the compressed sparse row type for your backend. A `CuSparseMatrixCSC` would have a `csr_type` of `CuSparseMatrixCSR`.
- `coo_type(::Type{T})` - the coordinate sparse matrix type for your backend. A `CuSparseMatrixCSC` would have a `coo_type` of `CuSparseMatrixCOO`.
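
For a hypothetical backend with types `MyArray`, `MySparseVector`, and `MySparseMatrixCSC`/`CSR`/`COO` (illustrative names, not part of this diff), these methods might look like:

```julia
import GPUArrays

# Dense counterpart of each sparse type.
GPUArrays.dense_array_type(::Type{<:MySparseVector}) = MyArray
GPUArrays.dense_array_type(::Type{<:MySparseMatrixCSC}) = MyArray

# Untyped sparse type for a parametrized one -- note: no type parameters.
GPUArrays.sparse_array_type(::Type{<:MySparseVector}) = MySparseVector
GPUArrays.sparse_array_type(::Type{<:MySparseMatrixCSC}) = MySparseMatrixCSC

# Conversions between storage layouts.
GPUArrays.csc_type(::Type{<:MySparseMatrixCSR}) = MySparseMatrixCSC
GPUArrays.csr_type(::Type{<:MySparseMatrixCSC}) = MySparseMatrixCSR
GPUArrays.coo_type(::Type{<:MySparseMatrixCSC}) = MySparseMatrixCOO
```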

Additionally, you need to teach `GPUArrays.jl` how to translate your backend's specific types onto the device. `GPUArrays.jl` provides the device-side types:

- `GPUSparseDeviceVector`
- `GPUSparseDeviceMatrixCSC`
- `GPUSparseDeviceMatrixCSR`
- `GPUSparseDeviceMatrixBSR`
- `GPUSparseDeviceMatrixCOO`

You will need to create a method of `Adapt.adapt_structure` for each format your backend supports. **Note** that if your backend supports separate address spaces,
as CUDA and ROCm do, you need to provide a parameter to these device-side arrays to indicate in which address space the underlying pointers live. An example of adapting
an array to the device-side struct:
```julia
function GPUArrays.GPUSparseDeviceVector(iPtr::MyDeviceVector{Ti, A},
                                         nzVal::MyDeviceVector{Tv, A},
                                         len::Int,
                                         nnz::Ti) where {Ti, Tv, A}
    GPUArrays.GPUSparseDeviceVector{Tv, Ti, MyDeviceVector{Ti, A}, MyDeviceVector{Tv, A}, A}(iPtr, nzVal, len, nnz)
end

function Adapt.adapt_structure(to::MyAdaptor, x::MySparseVector)
    return GPUArrays.GPUSparseDeviceVector(
        adapt(to, x.iPtr),
        adapt(to, x.nzVal),
        length(x), x.nnz
    )
end
```

You'll also need to inform `GPUArrays.jl` and `GPUCompiler.jl` how to adapt your sparse arrays by extending `KernelAbstractions.jl`'s `get_backend()`:

```julia
KA.get_backend(::MySparseVector) = MyBackend()
```

lib/JLArrays/src/JLArrays.jl

Lines changed: 39 additions & 22 deletions

```diff
@@ -13,7 +13,7 @@ using GPUArrays
 using Adapt
 using SparseArrays, LinearAlgebra
 
-import GPUArrays: _dense_array_type
+import GPUArrays: dense_array_type
 
 import KernelAbstractions
 import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
```
```diff
@@ -198,27 +198,27 @@ function Base.getindex(A::JLSparseMatrixCSR{Tv, Ti}, i0::Integer, i1::Integer) w
 end
 
 GPUArrays.storage(a::JLArray) = a.data
-GPUArrays._dense_array_type(a::JLArray{T, N}) where {T, N} = JLArray{T, N}
-GPUArrays._dense_array_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, N}
-GPUArrays._dense_vector_type(a::JLArray{T, N}) where {T, N} = JLArray{T, 1}
-GPUArrays._dense_vector_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, 1}
-
-GPUArrays._sparse_array_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSC
-GPUArrays._sparse_array_type(::Type{<:JLSparseMatrixCSC}) = JLSparseMatrixCSC
-GPUArrays._sparse_array_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSR
-GPUArrays._sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR
-GPUArrays._sparse_array_type(sa::JLSparseVector) = JLSparseVector
-GPUArrays._sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector
-
-GPUArrays._dense_array_type(sa::JLSparseVector) = JLArray
-GPUArrays._dense_array_type(::Type{<:JLSparseVector}) = JLArray
-GPUArrays._dense_array_type(sa::JLSparseMatrixCSC) = JLArray
-GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
-GPUArrays._dense_array_type(sa::JLSparseMatrixCSR) = JLArray
-GPUArrays._dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray
-
-GPUArrays._csc_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSC
-GPUArrays._csr_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSR
+GPUArrays.dense_array_type(a::JLArray{T, N}) where {T, N} = JLArray{T, N}
+GPUArrays.dense_array_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, N}
+GPUArrays.dense_vector_type(a::JLArray{T, N}) where {T, N} = JLArray{T, 1}
+GPUArrays.dense_vector_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, 1}
+
+GPUArrays.sparse_array_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSC
+GPUArrays.sparse_array_type(::Type{<:JLSparseMatrixCSC}) = JLSparseMatrixCSC
+GPUArrays.sparse_array_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSR
+GPUArrays.sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR
+GPUArrays.sparse_array_type(sa::JLSparseVector) = JLSparseVector
+GPUArrays.sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector
+
+GPUArrays.dense_array_type(sa::JLSparseVector) = JLArray
+GPUArrays.dense_array_type(::Type{<:JLSparseVector}) = JLArray
+GPUArrays.dense_array_type(sa::JLSparseMatrixCSC) = JLArray
+GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
+GPUArrays.dense_array_type(sa::JLSparseMatrixCSR) = JLArray
+GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray
+
+GPUArrays.csc_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSC
+GPUArrays.csr_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSR
 
 # conversion of untyped data to a typed Array
 function typed_data(x::JLArray{T}) where {T}
```
```diff
@@ -361,6 +361,23 @@
 Base.length(x::JLSparseMatrixCSR) = prod(x.dims)
 Base.size(x::JLSparseMatrixCSR) = x.dims
 
+function GPUArrays._spadjoint(A::JLSparseMatrixCSR)
+    Aᴴ = JLSparseMatrixCSC(A.rowPtr, A.colVal, conj(A.nzVal), reverse(size(A)))
+    JLSparseMatrixCSR(Aᴴ)
+end
+function GPUArrays._sptranspose(A::JLSparseMatrixCSR)
+    Aᵀ = JLSparseMatrixCSC(A.rowPtr, A.colVal, A.nzVal, reverse(size(A)))
+    JLSparseMatrixCSR(Aᵀ)
+end
+function _spadjoint(A::JLSparseMatrixCSC)
+    Aᴴ = JLSparseMatrixCSR(A.colPtr, A.rowVal, conj(A.nzVal), reverse(size(A)))
+    JLSparseMatrixCSC(Aᴴ)
+end
+function _sptranspose(A::JLSparseMatrixCSC)
+    Aᵀ = JLSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A)))
+    JLSparseMatrixCSC(Aᵀ)
+end
+
 # idempotency
 JLArray{T,N}(xs::JLArray{T,N}) where {T,N} = xs
```
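
The `_sptranspose`/`_spadjoint` methods added above rely on a standard identity: the raw buffers of a CSC matrix, reread as CSR buffers (`colptr` as `rowPtr`, `rowval` as `colVal`), describe the transpose, so no data movement is needed. A minimal check of that identity using only the `SparseArrays` stdlib (backend types omitted):

```julia
using SparseArrays, LinearAlgebra

A = sprand(4, 5, 0.5)

# Read A's CSC buffers (colptr, rowval, nzval) as if they were the CSR
# buffers (rowPtr, colVal, nzVal) of a matrix B of size reverse(size(A)).
m, n = reverse(size(A))
B = zeros(eltype(A), m, n)
for i in 1:m                          # CSR "row" i of B is CSC column i of A
    for k in A.colptr[i]:(A.colptr[i+1] - 1)
        B[i, A.rowval[k]] = A.nzval[k]
    end
end
@assert B == Matrix(transpose(A))     # the reread buffers describe Aᵀ
```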

src/host/sparse.jl

Lines changed: 17 additions & 17 deletions

```diff
@@ -23,21 +23,21 @@ SparseArrays.getcolptr(S::AbstractGPUSparseMatrixCSC) = S.colPtr
 
 Base.convert(T::Type{<:AbstractGPUSparseArray}, m::AbstractArray) = m isa T ? m : T(m)
 
-_dense_array_type(sa::SparseVector) = SparseVector
-_dense_array_type(::Type{SparseVector}) = SparseVector
-_sparse_array_type(sa::SparseVector) = SparseVector
-_dense_vector_type(sa::AbstractSparseArray) = Vector
-_dense_vector_type(sa::AbstractArray) = Vector
-_dense_vector_type(::Type{<:AbstractSparseArray}) = Vector
-_dense_vector_type(::Type{<:AbstractArray}) = Vector
-_dense_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
-_dense_array_type(::Type{SparseMatrixCSC}) = SparseMatrixCSC
-_sparse_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
-
-function _sparse_array_type(sa::AbstractGPUSparseArray) end
-function _dense_array_type(sa::AbstractGPUSparseArray) end
-function _coo_type(sa::AbstractGPUSparseArray) end
-_coo_type(::SA) where {SA<:AbstractGPUSparseMatrixCSC} = SA
+dense_array_type(sa::SparseVector) = SparseVector
+dense_array_type(::Type{SparseVector}) = SparseVector
+sparse_array_type(sa::SparseVector) = SparseVector
+dense_vector_type(sa::AbstractSparseArray) = Vector
+dense_vector_type(sa::AbstractArray) = Vector
+dense_vector_type(::Type{<:AbstractSparseArray}) = Vector
+dense_vector_type(::Type{<:AbstractArray}) = Vector
+dense_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
+dense_array_type(::Type{SparseMatrixCSC}) = SparseMatrixCSC
+sparse_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
+
+function sparse_array_type(sa::AbstractGPUSparseArray) end
+function dense_array_type(sa::AbstractGPUSparseArray) end
+function coo_type(sa::AbstractGPUSparseArray) end
+coo_type(::SA) where {SA<:AbstractGPUSparseMatrixCSC} = SA
 
 function _spadjoint end
 function _sptranspose end
```
```diff
@@ -908,8 +908,8 @@
 end
 ## COV_EXCL_STOP
 
-function _csc_type end
-function _csr_type end
+function csc_type end
+function csr_type end
 
 # TODO: implement mapreducedim!
 function Base.mapreduce(f, op, A::AbstractGPUSparseMatrix; dims=:, init=nothing)
```
