Skip to content

Commit eb2e255

Browse files
Merge pull request #506 from JoshuaLampert/JET-tests
Add JET tests
2 parents 9a89a75 + 4828dc7 commit eb2e255

File tree

7 files changed

+34
-8
lines changed

7 files changed

+34
-8
lines changed

.github/workflows/Tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ jobs:
2929
group:
3030
- "Core"
3131
- "Downstream"
32+
- "JET"
3233
uses: "SciML/.github/.github/workflows/tests.yml@v1"
3334
with:
3435
group: "${{ matrix.group }}"

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
3-
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
43
version = "3.41.0"
4+
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -48,6 +48,7 @@ DocStringExtensions = "0.9.3"
4848
FastBroadcast = "0.3.5"
4949
ForwardDiff = "0.10.38, 1"
5050
GPUArraysCore = "0.2"
51+
JET = "0.9, 0.11"
5152
KernelAbstractions = "0.9.36"
5253
LinearAlgebra = "1.10"
5354
Measurements = "2.11"
@@ -76,6 +77,7 @@ julia = "1.10"
7677
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
7778
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
7879
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
80+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
7981
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
8082
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
8183
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
@@ -93,4 +95,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
9395
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
9496

9597
[targets]
96-
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SciMLBase", "SparseArrays", "StaticArrays", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
98+
test = ["Aqua", "FastBroadcast", "ForwardDiff", "JET", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SciMLBase", "SparseArrays", "StaticArrays", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]

src/named_array_partition.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
NamedArrayPartition(; kwargs...)
3-
NamedArrayPartition(x::NamedTuple)
3+
NamedArrayPartition(x::NamedTuple)
44
55
Similar to an `ArrayPartition` but the individual arrays can be accessed via the
66
constructor-specified names. However, unlike `ArrayPartition`, each individual array
@@ -22,7 +22,7 @@ function NamedArrayPartition(x::NamedTuple)
2222
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices)
2323
end
2424

25-
# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
25+
# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
2626
# fields except through `getfield` and accessor functions.
2727
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)
2828

@@ -53,7 +53,7 @@ end
5353
function Base.similar(
5454
A::NamedArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S}
5555
NamedArrayPartition(
56-
similar(getfield(A, :array_partition), T, S, R), getfield(A, :names_to_indices))
56+
similar(getfield(A, :array_partition), T, S, R...), getfield(A, :names_to_indices))
5757
end
5858

5959
Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))
@@ -68,7 +68,7 @@ function Base.getproperty(x::NamedArrayPartition, s::Symbol)
6868
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s))
6969
end
7070

71-
# this enables x.s = some_array.
71+
# this enables x.s = some_array.
7272
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v)
7373
index = getproperty(getfield(x, :names_to_indices), s)
7474
ArrayPartition(x).x[index] .= v

src/utils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ function recursivefill!(b::AbstractArray{T, N},
119119
a::T2) where {T <: StaticArraysCore.SArray,
120120
T2 <: Union{Number, Bool}, N}
121121
@inbounds for i in eachindex(b)
122-
b[i] = fill(a, typeof(b[i]))
122+
# Preserve static array shape while replacing all entries with the scalar
123+
b[i] = map(_ -> a, b[i])
123124
end
124125
end
125126

@@ -128,7 +129,8 @@ function recursivefill!(bs::AbstractVectorOfArray{T, N},
128129
T2 <: Union{Number, Bool}, N}
129130
@inbounds for b in bs, i in eachindex(b)
130131

131-
b[i] = fill(a, typeof(b[i]))
132+
# Preserve static array shape while replacing all entries with the scalar
133+
b[i] = map(_ -> a, b[i])
132134
end
133135
end
134136

src/vector_of_array.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,8 @@ function Base.view(A::AbstractVectorOfArray{T, N, <:AbstractVector{T}},
924924
J = map(i -> Base.unalias(A, i), to_indices(A, I))
925925
elseif length(I) == 2 && (I[1] == Colon() || I[1] == 1)
926926
J = map(i -> Base.unalias(A, i), to_indices(A, Base.tail(I)))
927+
else
928+
J = map(i -> Base.unalias(A, i), to_indices(A, I))
927929
end
928930
@boundscheck checkbounds(A, J...)
929931
SubArray(A, J)
@@ -1200,6 +1202,7 @@ end
12001202

12011203
struct VectorOfArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end # N is only used when voa sees other abstract arrays
12021204
VectorOfArrayStyle{N}(::Val{N}) where {N} = VectorOfArrayStyle{N}()
1205+
VectorOfArrayStyle(::Val{N}) where {N} = VectorOfArrayStyle{N}()
12031206

12041207
# The order is important here. We want to override Base.Broadcast.DefaultArrayStyle to return another Base.Broadcast.DefaultArrayStyle.
12051208
Broadcast.BroadcastStyle(a::VectorOfArrayStyle, ::Base.Broadcast.DefaultArrayStyle{0}) = a

test/jet_tests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using JET, Test, RecursiveArrayTools
2+
3+
# Get all reports first
4+
result = JET.report_package(RecursiveArrayTools; target_modules = (RecursiveArrayTools,))
5+
reports = JET.get_reports(result)
6+
7+
# Filter out similar_type inference errors from StaticArraysCore
8+
filtered_reports = filter(reports) do report
9+
s = string(report)
10+
!(occursin("similar_type", s) && occursin("StaticArraysCore", s))
11+
end
12+
13+
# Check if there are any non-filtered errors
14+
@test isempty(filtered_reports)

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,8 @@ end
5656
@time @safetestset "VectorOfArray GPU" include("gpu/vectorofarray_gpu.jl")
5757
@time @safetestset "ArrayPartition GPU" include("gpu/arraypartition_gpu.jl")
5858
end
59+
60+
if GROUP == "JET" || GROUP == "All"
61+
@time @safetestset "JET Tests" include("jet_tests.jl")
62+
end
5963
end

0 commit comments

Comments
 (0)