From 671d723fb254bc8534a844b48789c9dcb5d8aa1d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 5 Nov 2025 16:37:42 -0500 Subject: [PATCH 1/2] add `reindexdims` --- src/SparseArrayKit.jl | 1 + src/base.jl | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/SparseArrayKit.jl b/src/SparseArrayKit.jl index 5493fa3..c287dba 100644 --- a/src/SparseArrayKit.jl +++ b/src/SparseArrayKit.jl @@ -7,6 +7,7 @@ using TupleTools export SparseArray export nonzero_pairs, nonzero_keys, nonzero_values, nonzero_length +export reindexdims, reindexdims! include("sparsearray.jl") include("base.jl") diff --git a/src/base.jl b/src/base.jl index c25e613..c613ca9 100644 --- a/src/base.jl +++ b/src/base.jl @@ -69,3 +69,27 @@ function Base.reshape(parent::SparseArray{T}, dims::Dims) where {T} end return child end + +@doc """ + reindexdims(A, p) + reindexdims!(C, A, p) + +Reindex the dimensions (axes) of array `A`. `p` is a tuple of integers specifying which indices are selected. +This is similar to `permutedims(!)`, but also allows both repeated and omitted integers. +The former boils down to a broadcasting along the diagonal, i.e. `C[i, i, j, k, ...] = A[i, j, k, ...]`, +while the latter signifies a reduction over the omitted index, i.e. `C[j, k, ...] = ∑_i A[i, j, k, ...]`. +""" reindexdims, reindexdims! + +function reindexdims(A::SparseArray, p::IndexTuple) + C = similar(A, TupleTools.getindices(size(A), p)) + return reindexdims!(C, A, p) +end +function reindexdims!(C::SparseArray{T, N}, A::SparseArray, p::IndexTuple{N}) where {T, N} + _zero!(C) + _sizehint!(C, nonzero_length(A)) + for (IA, vA) in nonzero_pairs(A) + IC = CartesianIndex(TupleTools.getindices(IA.I, p)) + increaseindex!(C, vA, IC) + end + return C +end From 0435d908f134c77ae9eac28c19dc1e36e6067dc9 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 5 Nov 2025 17:02:24 -0500 Subject: [PATCH 2/2] add some small tests --- test/basic.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/basic.jl b/test/basic.jl index 6a0bce7..5811757 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1,6 +1,8 @@ module BasicTests + using SparseArrayKit using Test, TestExtras, LinearAlgebra, Random +using TupleTools #= generate a whole bunch of random contractions, compare with the dense result @@ -96,4 +98,19 @@ end @test size(SparseArray(I, 3, 8)) == (3, 8) end +@timedtestset "Index manipulations" begin + dims = (2, 3, 4) + A = randn_sparse(Float64, dims) + @test @constinferred(reindexdims(A, (2, 1, 3))) == permutedims(A, (2, 1, 3)) + + A_expanded = @constinferred reindexdims(A, (1, 1, 2, 3)) + @test size(A_expanded) == TupleTools.getindices(size(A), (1, 1, 2, 3)) + @test norm(A_expanded) ≈ norm(A) + @test reindexdims(A_expanded, (1, 3, 4)) == A + + A_reduced = @constinferred reindexdims(A, (1, 2)) + @test size(A_reduced) == TupleTools.getindices(size(A), (1, 2)) + @test Array(A_reduced) ≈ sum(Array(A); dims = 3) +end + end