Skip to content

Commit 57feb5d

Browse files
authored
findall support through AcceleratedKernels.jl (#533)
1 parent ad72a8a commit 57feb5d

File tree

5 files changed

+69
-1
lines changed

5 files changed

+69
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "2.3.0"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
8+
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
89
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
910
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
1011
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
@@ -31,6 +32,7 @@ oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"
3132

3233
[compat]
3334
AbstractFFTs = "1.5.0"
35+
AcceleratedKernels = "0.4.3"
3436
Adapt = "4"
3537
CEnum = "0.4, 0.5"
3638
ExprTools = "0.1"

src/accumulate.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Base.accumulate!(op, B::oneArray, A::oneArray; init = zero(eltype(A)), kwargs...) =
2+
AK.accumulate!(op, B, A, oneAPIBackend(); init, kwargs...)
3+
4+
Base.accumulate(op, A::oneArray; init = zero(eltype(A)), kwargs...) =
5+
AK.accumulate(op, A, oneAPIBackend(); init, kwargs...)
6+
7+
Base.cumsum(src::oneArray; kwargs...) = AK.cumsum(src, oneAPIBackend(); kwargs...)
8+
Base.cumprod(src::oneArray; kwargs...) = AK.cumprod(src, oneAPIBackend(); kwargs...)

src/indexing.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
Base.to_index(::oneArray, I::AbstractArray{Bool}) = findall(I)
2+
3+
if VERSION >= v"1.11.0-DEV.1157"
4+
Base.to_indices(x::oneArray, I::Tuple{AbstractArray{Bool}}) =
5+
(Base.to_index(x, I[1]),)
6+
end
7+
8+
function _ker!(ys, bools, indices)
9+
i = get_global_id()
10+
11+
@inbounds if i length(bools) && bools[i]
12+
ii = CartesianIndices(bools)[i]
13+
b = indices[i] # new position
14+
ys[b] = ii
15+
end
16+
return
17+
end
18+
19+
function Base.findall(bools::oneArray{Bool})
20+
I = keytype(bools)
21+
22+
indices = cumsum(reshape(bools, prod(size(bools))))
23+
oneL0.synchronize()
24+
25+
n = isempty(indices) ? 0 : @allowscalar indices[end]
26+
27+
ys = oneArray{I}(undef, n)
28+
29+
if n > 0
30+
@oneapi items = length(bools) _ker!(ys, bools, indices)
31+
end
32+
oneL0.synchronize()
33+
unsafe_free!(indices)
34+
35+
return ys
36+
end

src/oneAPI.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ export SYCL
5858
include("../lib/mkl/oneMKL.jl")
5959
export oneMKL
6060
end
61-
61+
import AcceleratedKernels as AK
6262
# integrations and specialized functionality
6363
include("broadcast.jl")
6464
include("mapreduce.jl")
@@ -68,6 +68,8 @@ include("utils.jl")
6868

6969
include("oneAPIKernels.jl")
7070
import .oneAPIKernels: oneAPIBackend
71+
include("accumulate.jl")
72+
include("indexing.jl")
7173
export oneAPIBackend
7274

7375
function __init__()

test/indexing.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using Test
2+
using oneAPI
3+
4+
@testset "findall" begin
5+
bools1d = oneArray([true, false, true, false, true])
6+
@test Array(findall(bools1d)) == findall(Bool[true, false, true, false, true])
7+
8+
bools2d = oneArray(Bool[true false; false true; true false])
9+
@test Array(findall(bools2d)) == findall(Bool[true false; false true; true false])
10+
11+
all_false = oneArray(fill(false, 4))
12+
@test Array(findall(all_false)) == Int[]
13+
14+
all_true = oneArray(fill(true, 3, 2))
15+
@test Array(findall(all_true)) == findall(fill(true, 3, 2))
16+
17+
data = oneArray(collect(1:6))
18+
mask = oneArray(Bool[true, false, true, false, false, true])
19+
@test Array(data[mask]) == collect(1:6)[findall(Bool[true, false, true, false, false, true])]
20+
end

0 commit comments

Comments
 (0)