Skip to content

Commit 7d9e008

Browse files
committed
make lastindex type stable by always returning RaggedEnd
1 parent d1c9ae7 commit 7d9e008

File tree

2 files changed

+122
-10
lines changed

2 files changed

+122
-10
lines changed

src/vector_of_array.jl

Lines changed: 110 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -402,14 +402,16 @@ function Base.lastindex(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A}
402402
return lastindex(VA.u)
403403
end
404404

405+
# Always return RaggedEnd for type stability. Use dim=0 to indicate a plain index stored in offset.
406+
# _resolve_ragged_index and _column_indices handle the dim=0 case to extract the actual index value.
405407
@inline function Base.lastindex(VA::AbstractVectorOfArray, d::Integer)
406408
if d == ndims(VA)
407-
return lastindex(VA.u)
409+
return RaggedEnd(0, Int(lastindex(VA.u)))
408410
elseif d < ndims(VA)
409-
isempty(VA.u) && return 0
410-
return RaggedEnd(Int(d))
411+
isempty(VA.u) && return RaggedEnd(0, 0)
412+
return RaggedEnd(Int(d), 0)
411413
else
412-
return 1
414+
return RaggedEnd(0, 1)
413415
end
414416
end
415417

@@ -534,13 +536,27 @@ end
534536
@inline function _column_indices(VA::AbstractVectorOfArray, idx::AbstractArray{Bool})
535537
findall(idx)
536538
end
539+
@inline function _column_indices(VA::AbstractVectorOfArray, idx::RaggedEnd)
540+
# RaggedEnd with dim=0 means it's just a plain index stored in offset
541+
idx.dim == 0 ? idx.offset : idx
542+
end
537543

538544
@inline _resolve_ragged_index(idx, ::AbstractVectorOfArray, ::Any) = idx
539545
@inline function _resolve_ragged_index(idx::RaggedEnd, VA::AbstractVectorOfArray, col)
540-
return lastindex(VA.u[col], idx.dim) + idx.offset
546+
if idx.dim == 0
547+
# Special case: dim=0 means the offset contains the actual index value
548+
return idx.offset
549+
else
550+
return lastindex(VA.u[col], idx.dim) + idx.offset
551+
end
541552
end
542553
@inline function _resolve_ragged_index(idx::RaggedRange, VA::AbstractVectorOfArray, col)
543-
stop_val = lastindex(VA.u[col], idx.dim) + idx.offset
554+
stop_val = if idx.dim == 0
555+
# dim == 0 is the sentinel for an already-resolved plain index stored in offset
556+
idx.offset
557+
else
558+
lastindex(VA.u[col], idx.dim) + idx.offset
559+
end
544560
return Base.range(idx.start; step = idx.step, stop = stop_val)
545561
end
546562
@inline function _resolve_ragged_index(idx::AbstractRange{<:RaggedEnd}, VA::AbstractVectorOfArray, col)
@@ -583,16 +599,100 @@ end
583599
@inline _has_ragged_end(x, xs...) = _has_ragged_end(x) || _has_ragged_end(xs)
584600

585601
@inline function _ragged_getindex(A::AbstractVectorOfArray, I...)
586-
cols = last(I)
587-
prefix = Base.front(I)
602+
n = ndims(A)
603+
# Special-case when user provided one fewer index than ndims(A): last index is column selector.
604+
if length(I) == n - 1
605+
raw_cols = last(I)
606+
# If the raw selector is a RaggedEnd/RaggedRange referring to inner dims, reinterpret as column selector.
607+
cols = if raw_cols isa RaggedEnd && raw_cols.dim != 0
608+
lastindex(A.u) + raw_cols.offset
609+
elseif raw_cols isa RaggedRange && raw_cols.dim != 0
610+
stop_val = lastindex(A.u) + raw_cols.offset
611+
Base.range(raw_cols.start; step = raw_cols.step, stop = stop_val)
612+
else
613+
_column_indices(A, raw_cols)
614+
end
615+
prefix = Base.front(I)
616+
if cols isa Int
617+
resolved_prefix = _resolve_ragged_indices(prefix, A, cols)
618+
inner_nd = ndims(A.u[cols])
619+
n_missing = inner_nd - length(resolved_prefix)
620+
padded = if n_missing > 0
621+
if all(idx -> idx === Colon(), resolved_prefix)
622+
(resolved_prefix..., ntuple(_ -> Colon(), n_missing)...)
623+
else
624+
(resolved_prefix..., (lastindex(A.u[cols], length(resolved_prefix) + i) for i in 1:n_missing)...,)
625+
end
626+
else
627+
resolved_prefix
628+
end
629+
return A.u[cols][padded...]
630+
else
631+
return VectorOfArray([
632+
begin
633+
resolved_prefix = _resolve_ragged_indices(prefix, A, col)
634+
inner_nd = ndims(A.u[col])
635+
n_missing = inner_nd - length(resolved_prefix)
636+
padded = if n_missing > 0
637+
if all(idx -> idx === Colon(), resolved_prefix)
638+
(resolved_prefix..., ntuple(_ -> Colon(), n_missing)...)
639+
else
640+
(resolved_prefix..., (lastindex(A.u[col], length(resolved_prefix) + i) for i in 1:n_missing)...,)
641+
end
642+
else
643+
resolved_prefix
644+
end
645+
A.u[col][padded...]
646+
end for col in cols
647+
])
648+
end
649+
end
650+
651+
# Otherwise, use the full-length interpretation (last index is column selector; missing columns default to Colon()).
652+
if length(I) == n
653+
cols = last(I)
654+
prefix = Base.front(I)
655+
else
656+
cols = Colon()
657+
prefix = I
658+
end
588659
if cols isa Int
660+
if all(idx -> idx === Colon(), prefix)
661+
return A.u[cols]
662+
end
589663
resolved = _resolve_ragged_indices(prefix, A, cols)
590-
return A.u[cols][resolved...]
664+
inner_nd = ndims(A.u[cols])
665+
padded = (resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...)
666+
return A.u[cols][padded...]
591667
else
592668
col_idxs = _column_indices(A, cols)
669+
# Resolve sentinel RaggedEnd/RaggedRange (dim==0) for column selection
670+
if col_idxs isa RaggedEnd
671+
col_idxs = _resolve_ragged_index(col_idxs, A, 1)
672+
elseif col_idxs isa RaggedRange
673+
col_idxs = _resolve_ragged_index(col_idxs, A, 1)
674+
end
675+
# If we're selecting whole inner arrays (all leading indices are Colons),
676+
# keep the result as a VectorOfArray to match non-ragged behavior.
677+
if all(idx -> idx === Colon(), prefix)
678+
if col_idxs isa Int
679+
return A.u[col_idxs]
680+
else
681+
return VectorOfArray(A.u[col_idxs])
682+
end
683+
end
684+
# If col_idxs resolved to a single Int, handle it directly
685+
if col_idxs isa Int
686+
resolved = _resolve_ragged_indices(prefix, A, col_idxs)
687+
inner_nd = ndims(A.u[col_idxs])
688+
padded = (resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...)
689+
return A.u[col_idxs][padded...]
690+
end
593691
vals = map(col_idxs) do col
594692
resolved = _resolve_ragged_indices(prefix, A, col)
595-
A.u[col][resolved...]
693+
inner_nd = ndims(A.u[col])
694+
padded = (resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...)
695+
A.u[col][padded...]
596696
end
597697
return stack(vals)
598698
end

test/basic_indexing.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ ragged = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]])
173173
@test ragged[1:end, 1] == [1.0, 2.0]
174174
@test ragged[1:end, 2] == [3.0, 4.0, 5.0]
175175
@test ragged[1:end, 3] == [6.0, 7.0, 8.0, 9.0]
176+
@test ragged[:, end] == [6.0, 7.0, 8.0, 9.0]
177+
@test ragged[:, 2:end] == VectorOfArray(ragged.u[2:end])
176178

177179
ragged2 = VectorOfArray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0, 8.0, 9.0]])
178180
@test ragged2[end, 1] == 4.0
@@ -188,6 +190,8 @@ ragged2 = VectorOfArray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0, 8.0, 9.0]])
188190
@test ragged2[2:end, 1] == [2.0, 3.0, 4.0]
189191
@test ragged2[2:end, 2] == [6.0]
190192
@test ragged2[2:end, 3] == [8.0, 9.0]
193+
@test ragged2[:, end] == [7.0, 8.0, 9.0]
194+
@test ragged2[:, 2:end] == VectorOfArray(ragged2.u[2:end])
191195
@test ragged2[1:(end - 1), 1] == [1.0, 2.0, 3.0]
192196
@test ragged2[1:(end - 1), 2] == [5.0]
193197
@test ragged2[1:(end - 1), 3] == [7.0, 8.0]
@@ -209,6 +213,10 @@ u[1, :, 2] .= [1.0, 2.0, 3.0]
209213
# partial column selection by indices
210214
u[1, [1, 3], 2] .= [7.0, 9.0]
211215
@test u.u[2] == [7.0 2.0 9.0]
216+
# test scalar indexing with end
217+
@test u[1, 1, end] == u.u[end][1, 1]
218+
@test u[1, end, end] == u.u[end][1, end]
219+
@test u[1, 2:end, end] == vec(u.u[end][1, 2:end])
212220

213221
# 3D inner arrays (tensors) with ragged third dimension
214222
u = VectorOfArray([zeros(2, 1, n) for n in (2, 3)])
@@ -223,6 +231,10 @@ u[1:2, 1, [1, 3], 2] .= [1.0 3.0; 2.0 4.0]
223231
@test u.u[2][2, 1, 1] == 2.0
224232
@test u.u[2][1, 1, 3] == 3.0
225233
@test u.u[2][2, 1, 3] == 4.0
234+
@test u[:, :, end] == u.u[end]
235+
@test u[:, :, 2:end] == VectorOfArray(u.u[2:end])
236+
@test u[1, 1, end] == u.u[end][1, 1, end]
237+
@test u[end, 1, end] == u.u[end][end, 1, end]
226238

227239
# Test that views can be modified
228240
f3 = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0]])

0 commit comments

Comments
 (0)