Skip to content

Commit 5271d8b

Browse files
Merge pull request #504 from JoshuaLampert/fix-broadcasting-ragged-multi-d
Fix broadcasting and `view` for ragged `VectorOfArray`
2 parents 0c717ab + c062716 commit 5271d8b

File tree

2 files changed

+60
-18
lines changed

2 files changed

+60
-18
lines changed

src/vector_of_array.jl

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray
168168
VectorOfArray{T, N + 1, typeof(vec)}(vec)
169169
end
170170

171-
# allow multi-dimensional arrays as long as they're linearly indexed.
171+
# allow multi-dimensional arrays as long as they're linearly indexed.
172172
# currently restricted to arrays whose elements are all the same type
173173
function VectorOfArray(array::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}}
174174
@assert IndexStyle(typeof(array)) isa IndexLinear
@@ -675,13 +675,19 @@ function Base.view(A::AbstractVectorOfArray{T, N, <:AbstractVector{T}},
675675
end
676676
function Base.view(A::AbstractVectorOfArray, I::Vararg{Any, M}) where {M}
677677
@inline
678-
# Special handling for heterogeneous arrays when viewing a single column
679-
# The issue is that to_indices uses axes, which is based on the first element's size
680-
# For heterogeneous arrays, we need to use the actual size of the specific column
681-
if length(I) == 2 && I[1] == Colon() && I[2] isa Int
682-
@boundscheck checkbounds(A.u, I[2])
683-
# Use the actual size of the specific column instead of relying on axes/to_indices
684-
J = (Base.OneTo(length(A.u[I[2]])), I[2])
678+
# Generalized handling for heterogeneous arrays when the last index selects a column (Int)
679+
# The issue is that `to_indices` uses `axes(A)` which is based on the first element's size.
680+
# For heterogeneous arrays, use the actual axes of the specific selected inner array.
681+
if length(I) >= 1 && I[end] isa Int
682+
i = I[end]
683+
@boundscheck checkbounds(A.u, i)
684+
frontI = Base.front(I)
685+
# Normalize indices against the selected inner array's axes
686+
frontJ = to_indices(A.u[i], frontI)
687+
# Unalias indices and construct the full index tuple
688+
J = (map(j -> Base.unalias(A, j), frontJ)..., i)
689+
# Bounds check against the selected inner array to avoid relying on A's axes
690+
@boundscheck checkbounds(Bool, A.u[i], frontJ...) || throw(BoundsError(A, I))
685691
return SubArray(A, J)
686692
end
687693
J = map(i -> Base.unalias(A, i), to_indices(A, I))
@@ -711,10 +717,14 @@ function Base.checkbounds(
711717
end
712718
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
713719
checkbounds(Bool, VA.u, last(idx)) || return false
714-
for i in last(idx)
715-
checkbounds(Bool, VA.u[i], Base.front(idx)...) || return false
720+
if last(idx) isa Int
721+
return checkbounds(Bool, VA.u[last(idx)], Base.front(idx)...)
722+
else
723+
for i in last(idx)
724+
checkbounds(Bool, VA.u[i], Base.front(idx)...) || return false
725+
end
726+
return true
716727
end
717-
return true
718728
end
719729
function Base.checkbounds(VA::AbstractVectorOfArray, idx...)
720730
checkbounds(Bool, VA, idx...) || throw(BoundsError(VA, idx))
@@ -950,13 +960,13 @@ end
950960
# make vectorofarrays broadcastable so they aren't collected
951961
Broadcast.broadcastable(x::AbstractVectorOfArray) = x
952962

953-
# recurse through broadcast arguments and return a parent array for
963+
# recurse through broadcast arguments and return a parent array for
954964
# the first VoA or DiffEqArray in the bc arguments
955965
function find_VoA_parent(args)
956966
arg = Base.first(args)
957967
if arg isa AbstractDiffEqArray
958-
# if first(args) is a DiffEqArray, use the underlying
959-
# field `u` of DiffEqArray as a parent array.
968+
# if first(args) is a DiffEqArray, use the underlying
969+
# field `u` of DiffEqArray as a parent array.
960970
return arg.u
961971
elseif arg isa AbstractVectorOfArray
962972
return parent(arg)
@@ -975,7 +985,7 @@ end
975985
map(1:N) do i
976986
copy(unpack_voa(bc, i))
977987
end
978-
else # if parent isa AbstractArray
988+
else # if parent isa AbstractArray
979989
map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
980990
copy(unpack_voa(bc, i))
981991
end

test/basic_indexing.jl

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,38 @@ f2 = VectorOfArray([[1.0, 2.0], [3.0]])
162162
@test collect(view(f2, :, 1)) == f2[:, 1]
163163
@test collect(view(f2, :, 2)) == f2[:, 2]
164164

165+
# Broadcasting of heterogeneous arrays (issue #454)
166+
u = VectorOfArray([[1.0], [2.0, 3.0]])
167+
@test length(view(u, :, 1)) == 1
168+
@test length(view(u, :, 2)) == 2
169+
# broadcast assignment into selected column (last index Int)
170+
u[:, 2] .= [10.0, 11.0]
171+
@test u.u[2] == [10.0, 11.0]
172+
173+
# 2D inner arrays (matrices) with ragged second dimension
174+
u = VectorOfArray([zeros(1, n) for n in (2, 3)])
175+
@test length(view(u, 1, :, 1)) == 2
176+
@test length(view(u, 1, :, 2)) == 3
177+
u[1, :, 2] .= [1.0, 2.0, 3.0]
178+
@test u.u[2] == [1.0 2.0 3.0]
179+
# partial column selection by indices
180+
u[1, [1, 3], 2] .= [7.0, 9.0]
181+
@test u.u[2] == [7.0 2.0 9.0]
182+
183+
# 3D inner arrays (tensors) with ragged third dimension
184+
u = VectorOfArray([zeros(2, 1, n) for n in (2, 3)])
185+
@test size(view(u, :, :, :, 1)) == (2, 1, 2)
186+
@test size(view(u, :, :, :, 2)) == (2, 1, 3)
187+
# assign into a slice of the second inner array using last index Int
188+
u[2, 1, :, 2] .= [7.0, 8.0, 9.0]
189+
@test vec(u.u[2][2, 1, :]) == [7.0, 8.0, 9.0]
190+
# check mixed slicing with range on front dims
191+
u[1:2, 1, [1, 3], 2] .= [1.0 3.0; 2.0 4.0]
192+
@test u.u[2][1, 1, 1] == 1.0
193+
@test u.u[2][2, 1, 1] == 2.0
194+
@test u.u[2][1, 1, 3] == 3.0
195+
@test u.u[2][2, 1, 3] == 4.0
196+
165197
# Test that views can be modified
166198
f3 = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0]])
167199
v = view(f3, :, 2)
@@ -259,14 +291,14 @@ a[1:8]
259291
a[[1, 3, 8]]
260292

261293
####################################################################
262-
# test when VectorOfArray is constructed from a linearly indexed
294+
# test when VectorOfArray is constructed from a linearly indexed
263295
# multidimensional array of arrays
264296
####################################################################
265297

266298
u_matrix = VectorOfArray([[1, 2] for i in 1:2, j in 1:3])
267299
u_vector = VectorOfArray([[1, 2] for i in 1:6])
268300

269-
# test broadcasting
301+
# test broadcasting
270302
function foo!(u)
271303
@. u += 2 * u * abs(u)
272304
return u
@@ -281,7 +313,7 @@ foo!(u_vector)
281313
@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix))
282314
@test typeof(parent((x -> x).(u_matrix))) == typeof(parent(u_matrix))
283315

284-
# test efficiency
316+
# test efficiency
285317
num_allocs = @allocations foo!(u_matrix)
286318
@test num_allocs == 0
287319

0 commit comments

Comments
 (0)