Skip to content

Commit 0d0121a

Browse files
authored
feat: Add sparse.interp function for numba backend (#903)
* Add sparse.interp function for numba backend * Add suggestions from review * Add shortcut for dense data to spare.interp * Move test to test_coo.py * Adjust test to compare sparse.interp to numpy.interp * Add test case for complex data * Use random test data * Fix failing pipeline tests * Fix doctest
1 parent 3a2e703 commit 0d0121a

File tree

4 files changed

+129
-0
lines changed

4 files changed

+129
-0
lines changed

sparse/numba_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
full,
106106
full_like,
107107
imag,
108+
interp,
108109
isinf,
109110
isnan,
110111
matmul,
@@ -252,6 +253,7 @@
252253
"int32",
253254
"int64",
254255
"int8",
256+
"interp",
255257
"isfinite",
256258
"isinf",
257259
"isnan",

sparse/numba_backend/_common.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3250,3 +3250,88 @@ def diff(x, axis=-1, n=1, prepend=None, append=None):
32503250
for _ in range(n):
32513251
result = result[(slice(None),) * axis + (slice(1, None),)] - result[(slice(None),) * axis + (slice(None, -1),)]
32523252
return result
3253+
3254+
3255+
def interp(x, xp, fp, left=None, right=None, period=None):
3256+
"""
3257+
An implementation of ``numpy.interp`` for sparse arrays.
3258+
3259+
Thanks to the function dispatch of numpy, this enables interpolation on sparse arrays
3260+
using the numpy universal function. This function effectively wraps ``np.interp`` by
3261+
calling it on the array data and the fill value. See the numpy documentation for
3262+
details on the parameters.
3263+
3264+
Parameters
3265+
----------
3266+
x : SparseArray
3267+
The x-coordinates at which to evaluate the interpolated values.
3268+
xp : 1-D sequence or SparseArray
3269+
The x-coordinates of the data points.
3270+
fp : 1-D sequence or SparseArray
3271+
The y-coordinates of the data points, same length as ``xp``.
3272+
left : float or complex, optional
3273+
Value to return for ``x < xp[0]``, default is ``fp[0]``.
3274+
right : float or complex, optional
3275+
Value to return for ``x > xp[-1]``, default is ``fp[-1]``.
3276+
period : None or float, optional
3277+
A period for the x-coordinates.
3278+
3279+
Returns
3280+
-------
3281+
out : SparseArray
3282+
The interpolated values, same shape as x.
3283+
3284+
See Also
3285+
--------
3286+
https://numpy.org/doc/stable/reference/generated/numpy.interp.html
3287+
3288+
Examples
3289+
--------
3290+
When interpolating a sparse array, its data and the fill value are interpolated. The
3291+
returned array is pruned. Therefore, the fill value and the number of nonzero
3292+
elements might change.
3293+
3294+
>>> import numpy as np
3295+
>>> xp = [1, 2, 3]
3296+
>>> fp = [3, 2, 0]
3297+
>>> y = np.interp(sparse.COO.from_numpy(np.array([0, 1, 1.5, 2.72, 3.14])), xp, fp)
3298+
>>> y.todense()
3299+
array([3. , 3. , 2.5 , 0.56, 0. ])
3300+
>>> y.fill_value
3301+
np.float64(3.0)
3302+
>>> y.nnz
3303+
3
3304+
"""
3305+
from ._compressed import GCXS
3306+
from ._coo import COO
3307+
from ._dok import DOK
3308+
3309+
# Densify sparse interpolants
3310+
if isinstance(xp, SparseArray):
3311+
xp = xp.todense()
3312+
if isinstance(fp, SparseArray):
3313+
fp = fp.todense()
3314+
3315+
def interp_func(xx):
3316+
return np.interp(xx, xp, fp, left=left, right=right, period=period)
3317+
3318+
# Shortcut for dense arrays
3319+
if not isinstance(x, SparseArray):
3320+
return interp_func(x)
3321+
3322+
# Define output type
3323+
out_kwargs = {}
3324+
out_type = COO
3325+
if isinstance(x, GCXS):
3326+
out_type = GCXS
3327+
out_kwargs["compressed_axes"] = x.compressed_axes
3328+
elif isinstance(x, DOK):
3329+
out_type = DOK
3330+
3331+
# Perform interpolation on sparse object
3332+
arr = as_coo(x)
3333+
data = interp_func(arr.data)
3334+
fill_value = interp_func(arr.fill_value)
3335+
return COO(data=data, coords=arr.coords, shape=arr.shape, fill_value=fill_value, prune=True).asformat(
3336+
out_type, **out_kwargs
3337+
)

sparse/numba_backend/tests/test_coo.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2070,3 +2070,44 @@ def test_diff_invalid_type():
20702070
a = np.arange(6).reshape(2, 3)
20712071
with pytest.raises(TypeError, match="must be a SparseArray"):
20722072
sparse.diff(a)
2073+
2074+
2075+
class TestInterp:
2076+
xp = [-1, 0, 1]
2077+
fp = [3, 2, 0]
2078+
2079+
@pytest.fixture(params=["coo", "dok", "gcxs", "dense"])
2080+
def x(self, request):
2081+
arr = sparse.random((10, 10, 10), fill_value=0)
2082+
if request.param == "dense":
2083+
return arr.todense()
2084+
return arr.asformat(request.param)
2085+
2086+
@pytest.mark.parametrize(
2087+
"xp",
2088+
[
2089+
xp,
2090+
sparse.COO.from_numpy(np.array(xp)),
2091+
],
2092+
)
2093+
@pytest.mark.parametrize(
2094+
"fp",
2095+
[
2096+
fp,
2097+
sparse.COO.from_numpy(np.array(fp)),
2098+
sparse.COO.from_numpy(np.array(fp) + 1j),
2099+
],
2100+
)
2101+
def test_interp(self, x, xp, fp):
2102+
def to_dense(arr):
2103+
if isinstance(x, sparse.SparseArray):
2104+
return arr.todense()
2105+
return arr
2106+
2107+
actual = sparse.interp(x, xp, fp)
2108+
expected = np.interp(to_dense(x), xp, fp)
2109+
2110+
assert isinstance(actual, type(x))
2111+
if isinstance(x, sparse.SparseArray):
2112+
assert actual.fill_value == fp[1]
2113+
np.testing.assert_array_equal(to_dense(actual), expected)

sparse/numba_backend/tests/test_namespace.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def test_namespace():
8383
"int32",
8484
"int64",
8585
"int8",
86+
"interp",
8687
"isfinite",
8788
"isinf",
8889
"isnan",

0 commit comments

Comments
 (0)