Skip to content
48 changes: 36 additions & 12 deletions src/KernelAbstractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -532,40 +532,64 @@ get_backend(::Array) = CPU()
Adapt.adapt_storage(::CPU, a::Array) = a

"""
allocate(::Backend, Type, dims...)::AbstractArray
allocate(::Backend, Type, dims...; unified=false)::AbstractArray

Allocate a storage array appropriate for the computational backend.
Allocate a storage array appropriate for the computational backend. `unified`
allocates an array using unified memory if the backend supports it. Use
[`supports_unified`](@ref) to determine whether it is supported by a backend.

!!! note
Backend implementations **must** implement `allocate(::NewBackend, T, dims::Tuple)`
"""
allocate(backend::Backend, T::Type, dims...) = allocate(backend, T, dims)
allocate(backend::Backend, T::Type, dims::Tuple) = throw(MethodError(allocate, (backend, T, dims)))
allocate(backend::Backend, T::Type, dims...; kwargs...) = allocate(backend, T, dims; kwargs...)
function allocate(backend::Backend, T::Type, dims::Tuple; unified::Union{Nothing, Bool} = nothing)
return if isnothing(unified)
throw(MethodError(allocate, (backend, T, dims)))
elseif unified
throw(ArgumentError("`$(typeof(backend))` either does not support unified memory or it has not yet defined `allocate(backend::$backend, T::Type, dims::Tuple; unified::Bool)`"))
else
allocate(backend, T, dims)
end
end


"""
zeros(::Backend, Type, dims...)::AbstractArray
zeros(::Backend, Type, dims...; unified=false)::AbstractArray

Allocate a storage array appropriate for the computational backend filled with zeros.
`unified` allocates an array using unified memory if the backend supports it.
"""
zeros(backend::Backend, T::Type, dims...) = zeros(backend, T, dims)
function zeros(backend::Backend, ::Type{T}, dims::Tuple) where {T}
data = allocate(backend, T, dims...)
zeros(backend::Backend, T::Type, dims...; kwargs...) = zeros(backend, T, dims; kwargs...)
function zeros(backend::Backend, ::Type{T}, dims::Tuple; kwargs...) where {T}
data = allocate(backend, T, dims...; kwargs...)
fill!(data, zero(T))
return data
end

"""
ones(::Backend, Type, dims...)::AbstractArray
ones(::Backend, Type, dims...; unified=false)::AbstractArray

Allocate a storage array appropriate for the computational backend filled with ones.
`unified` allocates an array using unified memory if the backend supports it.
"""
ones(backend::Backend, T::Type, dims...) = ones(backend, T, dims)
function ones(backend::Backend, ::Type{T}, dims::Tuple) where {T}
data = allocate(backend, T, dims)
ones(backend::Backend, T::Type, dims...; kwargs...) = ones(backend, T, dims; kwargs...)
function ones(backend::Backend, ::Type{T}, dims::Tuple; kwargs...) where {T}
data = allocate(backend, T, dims; kwargs...)
fill!(data, one(T))
return data
end

"""
supports_unified(::Backend)::Bool

Returns whether unified memory arrays are supported by the backend.

!!! note
Backend implementations **must** implement this function
only if they **do** support unified memory.
"""
supports_unified(::Backend) = false

"""
supports_atomics(::Backend)::Bool

Expand Down
7 changes: 7 additions & 0 deletions test/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk
backendT = typeof(backend).name.wrapper # To look through CUDABackend{true, false}
@test backend isa backendT

unified = KernelAbstractions.supports_unified(backend)
@test unified isa Bool
U = allocate(backend, Float32, 5; unified)
if unified
@test U[3] isa Float32
end

x = allocate(backend, Float32, 5)
A = allocate(backend, Float32, 5, 5)
@test @inferred(KernelAbstractions.get_backend(A)) isa backendT
Expand Down
Loading