Skip to content

Commit 9a89a75

Browse files
Merge pull request #501 from JoshuaLampert/lastindex-ragged
Implement `lastindex` for ragged arrays
2 parents 42d1762 + 5cf0f66 commit 9a89a75

File tree

2 files changed

+316
-13
lines changed

2 files changed

+316
-13
lines changed

src/vector_of_array.jl

Lines changed: 274 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -388,35 +388,88 @@ end
388388
# Iteration traits and end-point access are forwarded to the wrapped vector `VA.u`.

# Report that the container has a known length (but no shape) when iterated.
@inline function Base.IteratorSize(::Type{<:AbstractVectorOfArray})
    return Base.HasLength()
end

# First stored element of the wrapped vector.
@inline function Base.first(VA::AbstractVectorOfArray)
    return first(VA.u)
end

# Last stored element of the wrapped vector.
@inline function Base.last(VA::AbstractVectorOfArray)
    return last(VA.u)
end
391-
# Linear `firstindex`: forwards to the wrapped vector `VA.u`. When the container has
# more than one dimension (`N > 1`) a deprecation warning is emitted first, because
# linear indexing of such containers is being phased out in favour of `A.u[i]`.
function Base.firstindex(VA::AbstractVectorOfArray{T, N, A}) where {T, N, A}
    if N > 1
        Base.depwarn(
            "Linear indexing of `AbstractVectorOfArray` is deprecated. Change `A[i]` to `A.u[i]` ",
            :firstindex)
    end
    return firstindex(VA.u)
end
397397

398-
# Linear `lastindex`: forwards to the wrapped vector `VA.u`. When the container has
# more than one dimension (`N > 1`) a deprecation warning is emitted first, because
# linear indexing of such containers is being phased out in favour of `A.u[i]`.
function Base.lastindex(VA::AbstractVectorOfArray{T, N, A}) where {T, N, A}
    if N > 1
        Base.depwarn(
            "Linear indexing of `AbstractVectorOfArray` is deprecated. Change `A[i]` to `A.u[i]` ",
            :lastindex)
    end
    return lastindex(VA.u)
end
404404

405+
# Per-dimension `lastindex`, i.e. what `end` lowers to inside `VA[..., end, ...]`.
#
# A `RaggedEnd` is returned in every case so that the return type does not depend on
# the value of `d`. `dim == 0` is a sentinel meaning "`offset` already holds a plain
# index"; `_resolve_ragged_index` and `_column_indices` unwrap it.
#
# - `d` beyond `ndims(VA)`: plain index 1.
# - `d == ndims(VA)` (the column dimension): plain index `lastindex(VA.u)`.
# - inner dimension of an empty container: plain index 0.
# - inner dimension otherwise: a deferred end for dimension `d`, to be resolved
#   against each inner array individually.
@inline function Base.lastindex(VA::AbstractVectorOfArray, d::Integer)
    nd = ndims(VA)
    d > nd && return RaggedEnd(0, 1)
    d == nd && return RaggedEnd(0, Int(lastindex(VA.u)))
    return isempty(VA.u) ? RaggedEnd(0, 0) : RaggedEnd(Int(d), 0)
end
417+
405418
# Single-argument (linear) indexing selects entries of the wrapped vector `A.u`.
function Base.getindex(A::AbstractVectorOfArray, I::Int)
    return A.u[I]
end
function Base.getindex(A::AbstractVectorOfArray, I::AbstractArray{Int})
    return A.u[I]
end
function Base.getindex(A::AbstractDiffEqArray, I::Int)
    return A.u[I]
end
function Base.getindex(A::AbstractDiffEqArray, I::AbstractArray{Int})
    return A.u[I]
end

# More specific methods for containers whose storage type `A` is itself an array or
# an `AbstractVectorOfArray`: same result, but routed through `@deprecate` so that a
# deprecation warning is attached. The trailing `false` is `@deprecate`'s export flag.
@deprecate Base.getindex(VA::AbstractVectorOfArray{T, N, A},
    I::Int) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false

@deprecate Base.getindex(VA::AbstractVectorOfArray{T, N, A},
    I::AbstractArray{Int}) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false

@deprecate Base.getindex(VA::AbstractDiffEqArray{T, N, A},
    I::AbstractArray{Int}) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} VA.u[I] false

@deprecate Base.getindex(VA::AbstractDiffEqArray{T, N, A},
    i::Int) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} VA.u[i] false
417434

418435
# Return the wrapper of `T`'s type name, i.e. the type with its parameters stripped
# (for example `Vector{Int}` gives `Array`).
function __parameterless_type(T)
    return Base.typename(T).wrapper
end
419436

437+
# `end` support for ragged inner arrays
#
# `RaggedEnd` is a deferred `end`: it records which inner dimension (`dim`) the `end`
# refers to plus an integer `offset` accumulated from `end + n` / `end - n`.
# `dim == 0` is a sentinel meaning that `offset` already is a plain index.
# Runtime fields are used instead of type parameters for type stability.
struct RaggedEnd
    dim::Int
    offset::Int
end

# Convenience constructor with zero offset. Accepts any `Integer`, like the arithmetic
# methods below; the two-argument constructor converts `dim` to `Int`.
RaggedEnd(dim::Integer) = RaggedEnd(dim, 0)

# `end + n`, `end - n` and `n + end` shift the stored offset and keep the dimension.
Base.:+(re::RaggedEnd, n::Integer) = RaggedEnd(re.dim, re.offset + Int(n))
Base.:-(re::RaggedEnd, n::Integer) = RaggedEnd(re.dim, re.offset - Int(n))
Base.:+(n::Integer, re::RaggedEnd) = re + n
448+
449+
# A deferred range whose stop is a `RaggedEnd`: `start:step:(end + offset)` along inner
# dimension `dim`. As with `RaggedEnd`, `dim == 0` marks `offset` as a plain stop index.
struct RaggedRange
    dim::Int
    start::Int
    step::Int
    offset::Int
end

# Colon methods that build a `RaggedRange` from a `RaggedEnd` stop.
# With no explicit start/step both default to 1.
function Base.:(:)(stop::RaggedEnd)
    return RaggedRange(stop.dim, 1, 1, stop.offset)
end
function Base.:(:)(start::Integer, stop::RaggedEnd)
    first_idx = Int(start)
    return RaggedRange(stop.dim, first_idx, 1, stop.offset)
end
function Base.:(:)(start::Integer, step::Integer, stop::RaggedEnd)
    first_idx = Int(start)
    stride = Int(step)
    return RaggedRange(stop.dim, first_idx, stride, stop.offset)
end
463+
464+
# Return `true` when the inner arrays of `VA` do not all have the same size along
# dimension `d`. Containers with zero or one inner array are never ragged.
#
# The inner arrays are visited by iteration rather than through hard-coded
# `1`/`2:length` indices under `@inbounds`, so the check does not depend on `VA.u`
# being 1-based.
@inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer)
    length(VA.u) <= 1 && return false
    first_size = size(first(VA.u), d)
    return any(u -> size(u, d) != first_size, VA.u)
end
472+
420473
Base.@propagate_inbounds function _getindex(
421474
A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int)
422475
A.u[I]
@@ -487,11 +540,206 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb
487540
return getindex(A, all_variable_symbols(A), args...)
488541
end
489542

543+
# Normalise a column selector (the index into `VA.u`) into something usable directly:
#
# - `Colon()`            -> all indices of `VA.u`
# - Boolean mask         -> positions of the `true` entries
# - `RaggedEnd`, dim 0   -> the plain index stored in `offset`
# - `RaggedRange`, dim 0 -> the plain range `start:step:offset`
# - anything else        -> returned unchanged
#
# `RaggedEnd`/`RaggedRange` with `dim != 0` are returned unchanged; the caller decides
# how to interpret them.
@inline _column_indices(VA::AbstractVectorOfArray, idx) = idx
@inline _column_indices(VA::AbstractVectorOfArray, ::Colon) = eachindex(VA.u)
@inline function _column_indices(VA::AbstractVectorOfArray, idx::AbstractArray{Bool})
    return findall(idx)
end
@inline function _column_indices(VA::AbstractVectorOfArray, idx::RaggedEnd)
    # RaggedEnd with dim=0 means it's just a plain index stored in offset
    return idx.dim == 0 ? idx.offset : idx
end
@inline function _column_indices(VA::AbstractVectorOfArray, idx::RaggedRange)
    # Same sentinel as for RaggedEnd: dim=0 means `offset` is the plain stop index, so
    # the range can be materialised here instead of leaking a non-iterable RaggedRange.
    idx.dim == 0 || return idx
    return Base.range(idx.start; step = idx.step, stop = idx.offset)
end
552+
553+
# `_resolve_ragged_index(idx, VA, col)` turns a possibly-deferred index into a concrete
# one for the inner array `VA.u[col]`.

# Fallback: indices that carry no ragged marker are returned unchanged.
@inline _resolve_ragged_index(idx, ::AbstractVectorOfArray, ::Any) = idx

# Deferred `end`: the last index of dimension `idx.dim` of column `col`, plus the offset.
@inline function _resolve_ragged_index(idx::RaggedEnd, VA::AbstractVectorOfArray, col)
    # Special case: dim=0 means the offset contains the actual index value
    idx.dim == 0 && return idx.offset
    return lastindex(VA.u[col], idx.dim) + idx.offset
end

# Deferred range: resolve the stop exactly like a `RaggedEnd`, keep start and step.
@inline function _resolve_ragged_index(idx::RaggedRange, VA::AbstractVectorOfArray, col)
    # dim == 0 is the sentinel for an already-resolved plain index stored in offset
    stop_val = idx.dim == 0 ? idx.offset : lastindex(VA.u[col], idx.dim) + idx.offset
    return Base.range(idx.start; step = idx.step, stop = stop_val)
end

# Range whose elements are `RaggedEnd`s: resolve both end points, reuse the step.
@inline function _resolve_ragged_index(
        idx::AbstractRange{<:RaggedEnd}, VA::AbstractVectorOfArray, col)
    lo = _resolve_ragged_index(first(idx), VA, col)
    hi = _resolve_ragged_index(last(idx), VA, col)
    return Base.range(lo; step = step(idx), stop = hi)
end

# `Base.Slice`: resolve the wrapped indices and re-wrap.
@inline function _resolve_ragged_index(idx::Base.Slice, VA::AbstractVectorOfArray, col)
    inner = _resolve_ragged_index(idx.indices, VA, col)
    return Base.Slice(inner)
end

# `CartesianIndex`: resolve each component and rebuild the index.
@inline function _resolve_ragged_index(idx::CartesianIndex, VA::AbstractVectorOfArray, col)
    parts = _resolve_ragged_indices(Tuple(idx), VA, col)
    return CartesianIndex(parts...)
end

# Arrays of deferred ends / deferred ranges: resolve element-wise.
@inline function _resolve_ragged_index(
        idx::AbstractArray{<:RaggedEnd}, VA::AbstractVectorOfArray, col)
    return map(idx) do i
        _resolve_ragged_index(i, VA, col)
    end
end
@inline function _resolve_ragged_index(
        idx::AbstractArray{<:RaggedRange}, VA::AbstractVectorOfArray, col)
    return map(idx) do i
        _resolve_ragged_index(i, VA, col)
    end
end

# Any other array: resolve element-wise only if `_has_ragged_end` reports a marker,
# otherwise hand the array back untouched.
@inline function _resolve_ragged_index(idx::AbstractArray, VA::AbstractVectorOfArray, col)
    _has_ragged_end(idx) || return idx
    return map(idx) do i
        _resolve_ragged_index(i, VA, col)
    end
end
593+
594+
# Resolve every entry of the index tuple `idxs` against column `col` of `VA`,
# returning a tuple of the same length.
@inline function _resolve_ragged_indices(idxs::Tuple, VA::AbstractVectorOfArray, col)
    return map(idxs) do idx
        _resolve_ragged_index(idx, VA, col)
    end
end
597+
598+
# Return `true` when the index `x` is, or contains, a deferred `RaggedEnd`/`RaggedRange`.
#
# Containers are inspected as follows:
# - `Base.Slice`: its wrapped `indices`
# - `CartesianIndex` and `Tuple`: every component
# - ranges: decided from the element type alone
# - other arrays: decided from the element type when it is concrete; arrays with a
#   non-concrete element type (`Any`, a `Union`, an abstract type) are scanned, since
#   the element type alone cannot rule out a ragged marker.
@inline function _has_ragged_end(x)
    x isa RaggedEnd && return true
    x isa RaggedRange && return true
    x isa Base.Slice && return _has_ragged_end(x.indices)
    x isa CartesianIndex && return _has_ragged_end(Tuple(x))
    x isa AbstractRange && return eltype(x) <: Union{RaggedEnd, RaggedRange}
    if x isa AbstractArray
        el = eltype(x)
        el <: Union{RaggedEnd, RaggedRange} && return true
        return !isconcretetype(el) && any(_has_ragged_end, x)
    end
    x isa Tuple && return any(_has_ragged_end, x)
    return false
end
# No indices at all: nothing can be ragged. Needed so that callers may splat an empty
# index tuple (`_has_ragged_end(idxs...)`) without hitting a MethodError.
@inline _has_ragged_end() = false
# Several indices: true as soon as any of them contains a ragged marker.
@inline _has_ragged_end(x, xs...) = _has_ragged_end(x) || _has_ragged_end(xs)
613+
614+
# Index the inner array `A.u[col]` with `prefix` after resolving ragged markers against
# that column. Inner dimensions not covered by `prefix` are filled with `Colon()` when
# every resolved index is a `Colon()`, and with that dimension's `lastindex` otherwise.
@inline function _ragged_getindex_endpad(A::AbstractVectorOfArray, prefix::Tuple, col)
    resolved = _resolve_ragged_indices(prefix, A, col)
    inner = A.u[col]
    n_missing = ndims(inner) - length(resolved)
    n_missing > 0 || return inner[resolved...]
    tail = if all(idx -> idx === Colon(), resolved)
        ntuple(_ -> Colon(), n_missing)
    else
        ntuple(i -> lastindex(inner, length(resolved) + i), n_missing)
    end
    return inner[resolved..., tail...]
end

# Index the inner array `A.u[col]` with `prefix` after resolving ragged markers against
# that column. Inner dimensions not covered by `prefix` are filled with `Colon()`.
@inline function _ragged_getindex_colonpad(A::AbstractVectorOfArray, prefix::Tuple, col)
    resolved = _resolve_ragged_indices(prefix, A, col)
    inner = A.u[col]
    n_pad = max(ndims(inner) - length(resolved), 0)
    return inner[resolved..., ntuple(_ -> Colon(), n_pad)...]
end

# `getindex` path used when at least one index contains a `RaggedEnd`/`RaggedRange`.
#
# The last index selects columns (entries of `A.u`); the leading indices are resolved
# separately against every selected inner array, because `end` may differ per column.
#
# - `length(I) == ndims(A) - 1`: the last supplied index is treated as the column
#   selector. A single column returns the indexed inner array; several columns return
#   a `VectorOfArray` of the per-column results.
# - `length(I) == ndims(A)`: the last index is the column selector, the rest index the
#   inner arrays.
# - any other length: all columns are selected and `I` indexes the inner arrays.
#   In the latter two cases, selecting whole inner arrays (all leading indices `:`)
#   returns `A.u[col]` or a `VectorOfArray`; otherwise per-column results are combined
#   with `stack`.
#
# Each local that a closure captures is assigned exactly once, which keeps the closures
# free of boxed captures.
@inline function _ragged_getindex(A::AbstractVectorOfArray, I...)
    n = ndims(A)
    # Special-case when user provided one fewer index than ndims(A): last index is column selector.
    if length(I) == n - 1
        raw_cols = last(I)
        # A RaggedEnd/RaggedRange that refers to an inner dim is reinterpreted relative to
        # the last column; the dim == 0 sentinel already stores a plain index in `offset`.
        # The sentinel RaggedRange is resolved here too, so it is never iterated as-is.
        sel = if raw_cols isa RaggedEnd
            raw_cols.dim == 0 ? raw_cols.offset : lastindex(A.u) + raw_cols.offset
        elseif raw_cols isa RaggedRange
            stop_val = raw_cols.dim == 0 ? raw_cols.offset :
                       lastindex(A.u) + raw_cols.offset
            Base.range(raw_cols.start; step = raw_cols.step, stop = stop_val)
        else
            _column_indices(A, raw_cols)
        end
        lead = Base.front(I)
        if sel isa Int
            return _ragged_getindex_endpad(A, lead, sel)
        else
            return VectorOfArray([_ragged_getindex_endpad(A, lead, col) for col in sel])
        end
    end

    # Otherwise, use the full-length interpretation (last index is column selector;
    # missing columns default to Colon()).
    cols, prefix = length(I) == n ? (last(I), Base.front(I)) : (Colon(), I)
    whole_columns = all(idx -> idx === Colon(), prefix)
    if cols isa Int
        whole_columns && return A.u[cols]
        return _ragged_getindex_colonpad(A, prefix, cols)
    end

    col_idxs = _column_indices(A, cols)
    # Resolve sentinel RaggedEnd/RaggedRange (dim==0) for column selection
    if col_idxs isa RaggedEnd || col_idxs isa RaggedRange
        col_idxs = _resolve_ragged_index(col_idxs, A, 1)
    end
    # If we're selecting whole inner arrays (all leading indices are Colons),
    # keep the result as a VectorOfArray to match non-ragged behavior.
    if whole_columns
        return col_idxs isa Int ? A.u[col_idxs] : VectorOfArray(A.u[col_idxs])
    end
    # If col_idxs resolved to a single Int, handle it directly
    if col_idxs isa Int
        return _ragged_getindex_colonpad(A, prefix, col_idxs)
    end
    vals = map(col -> _ragged_getindex_colonpad(A, prefix, col), col_idxs)
    return stack(vals)
end
718+
719+
# Bounds check for index tuples that contain `RaggedEnd`/`RaggedRange` markers.
#
# The last index selects columns (entries of `VA.u`); the leading indices are resolved
# against each selected inner array and checked there. Returns `true` only if every
# selected column exists and the resolved indices are in bounds for it.
#
# A column's own bounds are checked *before* its prefix is resolved, because resolving
# reads `VA.u[col]` and would otherwise throw for an out-of-range column instead of
# letting this function return `false`.
@inline function _checkbounds_ragged(::Type{Bool}, VA::AbstractVectorOfArray, idxs...)
    raw_cols = last(idxs)
    # Turn a ragged column selector into plain indices. dim == 0 is the sentinel for
    # "offset is already a plain index"; any other dim is taken relative to the last
    # column, mirroring what `_ragged_getindex` does for a trailing `end`.
    cols = if raw_cols isa RaggedEnd
        raw_cols.dim == 0 ? raw_cols.offset : lastindex(VA.u) + raw_cols.offset
    elseif raw_cols isa RaggedRange
        stop_val = raw_cols.dim == 0 ? raw_cols.offset :
                   lastindex(VA.u) + raw_cols.offset
        Base.range(raw_cols.start; step = raw_cols.step, stop = stop_val)
    else
        _column_indices(VA, raw_cols)
    end
    prefix = Base.front(idxs)
    if cols isa Int
        checkbounds(Bool, VA.u, cols) || return false
        resolved = _resolve_ragged_indices(prefix, VA, cols)
        return checkbounds(Bool, VA.u[cols], resolved...)
    end
    for col in cols
        checkbounds(Bool, VA.u, col) || return false
        resolved = _resolve_ragged_indices(prefix, VA, col)
        checkbounds(Bool, VA.u[col], resolved...) || return false
    end
    return true
end
734+
490735
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, args...)
491736
symtype = symbolic_type(_arg)
492737
elsymtype = symbolic_type(eltype(_arg))
493738

494739
if symtype == NotSymbolic() && elsymtype == NotSymbolic()
740+
if _has_ragged_end(_arg, args...)
741+
return _ragged_getindex(A, _arg, args...)
742+
end
495743
if _arg isa Union{Tuple, AbstractArray} &&
496744
any(x -> symbolic_type(x) != NotSymbolic(), _arg)
497745
_getindex(A, symtype, elsymtype, _arg, args...)
@@ -523,25 +771,32 @@ Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}
523771
VA.u[I] = v
524772
end
525773

526-
# Linear assignment with an integer index writes into the wrapped vector `VA.u`.
Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray, v, I::Int)
    return Base.setindex!(VA.u, v, I)
end
# Same operation for array-like storage types, routed through `@deprecate`
# (the trailing `false` is `@deprecate`'s export flag).
@deprecate Base.setindex!(VA::AbstractVectorOfArray{T, N, A}, v,
    I::Int) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!(
    VA.u, v, I) false

# `VA[:, :] = v` assigns into every entry of `VA.u` and evaluates to `v`.
Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, v,
        ::Colon, I::Colon) where {T, N}
    Base.setindex!(VA.u, v, I)
    return v
end

# Linear assignment with `:` writes into the wrapped vector `VA.u`.
Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray, v, I::Colon)
    return Base.setindex!(VA.u, v, I)
end
@deprecate Base.setindex!(VA::AbstractVectorOfArray{T, N, A}, v,
    I::Colon) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!(
    VA.u, v, I) false

# `VA[:, I] = v` with an integer array `I` assigns into the selected entries of `VA.u`
# and evaluates to `v`.
Base.@propagate_inbounds function Base.setindex!(VA::AbstractVectorOfArray{T, N}, v,
        ::Colon, I::AbstractArray{Int}) where {T, N}
    Base.setindex!(VA.u, v, I)
    return v
end

# Linear assignment with an integer array writes into the wrapped vector `VA.u`.
Base.@propagate_inbounds function Base.setindex!(
        VA::AbstractVectorOfArray, v, I::AbstractArray{Int})
    return Base.setindex!(VA.u, v, I)
end
# Note: unlike the deprecations above, this one forwards to the two-index form
# `setindex!(VA, v, :, I)` rather than to `VA.u`.
@deprecate Base.setindex!(VA::AbstractVectorOfArray{T, N, A}, v,
    I::AbstractArray{Int}) where {T, N, A <: Union{AbstractArray, AbstractVectorOfArray}} Base.setindex!(
    VA, v, :, I) false
546801

547802
Base.@propagate_inbounds function Base.setindex!(
@@ -710,12 +965,18 @@ Base.ndims(::Type{<:AbstractVectorOfArray{T, N}}) where {T, N} = N
710965
# `checkbounds(Bool, VA, idxs...)` for containers whose storage is a vector of scalars
# (`VA.u isa AbstractVector{T}`).
#
# - Indices containing `RaggedEnd`/`RaggedRange` markers go to `_checkbounds_ragged`.
# - Two indices whose first is `:` or `1`: only the second is checked against `VA.u`.
# - Otherwise all indices are checked against `VA.u` directly.
function Base.checkbounds(
        ::Type{Bool}, VA::AbstractVectorOfArray{T, N, <:AbstractVector{T}},
        idxs...) where {T, N}
    # The index tuple is passed as a single argument (tuples are inspected element-wise)
    # so that a call without any indices does not depend on a zero-argument
    # `_has_ragged_end` method and still reaches the plain `checkbounds` below.
    if _has_ragged_end(idxs)
        return _checkbounds_ragged(Bool, VA, idxs...)
    end
    if length(idxs) == 2 && (idxs[1] == Colon() || idxs[1] == 1)
        return checkbounds(Bool, VA.u, idxs[2])
    end
    return checkbounds(Bool, VA.u, idxs...)
end
718976
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
977+
if _has_ragged_end(idx...)
978+
return _checkbounds_ragged(Bool, VA, idx...)
979+
end
719980
checkbounds(Bool, VA.u, last(idx)) || return false
720981
if last(idx) isa Int
721982
return checkbounds(Bool, VA.u[last(idx)], Base.front(idx)...)

0 commit comments

Comments
 (0)