Skip to content

Commit fb9fc6a

Browse files
committed
Create derivative sparse arrays
1 parent 46f88e8 commit fb9fc6a

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

test/runtests.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ include("testsuite.jl")
66
const init_code = quote
77
using Test, JLArrays, SparseArrays
88

9+
sparse_types(::Type{<:JLArray}) = (JLSparseVector, JLSparseMatrixCSC, JLSparseMatrixCSR)
10+
sparse_types(::Type{<:Array}) = (SparseVector, SparseMatrixCSC)
11+
912
include("testsuite.jl")
1013

1114
# Disable Float16-related tests until JuliaGPU/KernelAbstractions#600 is resolved
@@ -16,11 +19,7 @@ const init_code = quote
1619
end
1720

1821
custom_tests = Dict{String, Expr}()
19-
for AT in (:JLArray, :Array), name in filter(n->n != "sparse", keys(TestSuite.tests))
20-
custom_tests["$(AT)/$name"] = :(TestSuite.tests[$name]($AT))
21-
end
22-
23-
for AT in (:JLSparseMatrixCSR, :JLSparseMatrixCSC, :JLSparseVector, :SparseMatrixCSC, :SparseVector), name in ["sparse"]
22+
for AT in (:JLArray, :Array), name in keys(TestSuite.tests)
2423
custom_tests["$(AT)/$name"] = :(TestSuite.tests[$name]($AT))
2524
end
2625

test/testsuite.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ supported_eltypes() = (Int16, Int32, Int64,
6363
ComplexF16, ComplexF32, ComplexF64,
6464
Complex{Int16}, Complex{Int32}, Complex{Int64})
6565

66+
# derived sparse types that are supported by the array type
67+
68+
sparse_types(::Type{AT}) where {AT} = ()
69+
6670
# some convenience predicates for filtering test eltypes
6771
isrealtype(T) = T <: Real
6872
iscomplextype(T) = T <: Complex

test/testsuite/sparse.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
@testsuite "sparse" (AT, eltypes)->begin
2-
if AT <: AbstractSparseVector
3-
broadcasting_vector(AT, eltypes)
4-
elseif AT <: AbstractSparseMatrix
5-
broadcasting_matrix(AT, eltypes)
6-
mapreduce_matrix(AT, eltypes)
2+
sparse_ATs = sparse_types(AT)
3+
for sparse_AT in sparse_ATs
4+
if sparse_AT <: AbstractSparseVector
5+
broadcasting_vector(sparse_AT, eltypes)
6+
elseif sparse_AT <: AbstractSparseMatrix
7+
broadcasting_matrix(sparse_AT, eltypes)
8+
mapreduce_matrix(sparse_AT, eltypes)
9+
end
710
end
811
end
912

0 commit comments

Comments
 (0)