Skip to content

Commit 3521a40

Browse files
committed
Test numba slice boxing and fix representation of None stop with negative step
1 parent 57856ce commit 3521a40

File tree

2 files changed

+63
-12
lines changed

2 files changed

+63
-12
lines changed

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,33 +46,43 @@ def box_slice(typ, val, c):
4646
"""
4747
start = c.builder.extract_value(val, 0)
4848
stop = c.builder.extract_value(val, 1)
49+
step = c.builder.extract_value(val, 2) if typ.has_step else None
4950

51+
# Numba uses sys.maxsize and -sys.maxsize-1 to represent None
52+
# We want to use None in the Python representation
5053
none_val = ir.Constant(ir.IntType(64), sys.maxsize)
54+
neg_none_val = ir.Constant(ir.IntType(64), -sys.maxsize - 1)
55+
none_obj = c.pyapi.get_null_object()
5156

52-
start_is_none = c.builder.icmp_signed("==", start, none_val)
5357
start = c.builder.select(
54-
start_is_none,
55-
c.pyapi.get_null_object(),
58+
c.builder.icmp_signed("==", start, none_val),
59+
none_obj,
5660
c.box(types.int64, start),
5761
)
5862

59-
stop_is_none = c.builder.icmp_signed("==", stop, none_val)
63+
# None stop is represented as neg_none_val when step is negative
64+
if step is not None:
65+
stop_none_val = c.builder.select(
66+
c.builder.icmp_signed(">", step, ir.Constant(ir.IntType(64), 0)),
67+
none_val,
68+
neg_none_val,
69+
)
70+
else:
71+
stop_none_val = none_val
6072
stop = c.builder.select(
61-
stop_is_none,
62-
c.pyapi.get_null_object(),
73+
c.builder.icmp_signed("==", stop, stop_none_val),
74+
none_obj,
6375
c.box(types.int64, stop),
6476
)
6577

66-
if typ.has_step:
67-
step = c.builder.extract_value(val, 2)
68-
step_is_none = c.builder.icmp_signed("==", step, none_val)
78+
if step is not None:
6979
step = c.builder.select(
70-
step_is_none,
71-
c.pyapi.get_null_object(),
80+
c.builder.icmp_signed("==", step, none_val),
81+
none_obj,
7282
c.box(types.int64, step),
7383
)
7484
else:
75-
step = c.pyapi.get_null_object()
85+
step = none_obj
7686

7787
slice_val = slice_new(c.pyapi, start, stop, step)
7888

tests/link/numba/test_subtensor.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import numpy as np
44
import pytest
55

6+
import pytensor.scalar as ps
67
import pytensor.tensor as pt
8+
from pytensor import Mode, as_symbolic
79
from pytensor.tensor import as_tensor
810
from pytensor.tensor.subtensor import (
911
AdvancedIncSubtensor,
@@ -24,6 +26,45 @@
2426
rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
2527

2628

29+
@pytest.mark.parametrize("step", [None, 1, 2, -2, "x"], ids=lambda x: f"step={x}")
30+
@pytest.mark.parametrize("stop", [None, 10, "x"], ids=lambda x: f"stop={x}")
31+
@pytest.mark.parametrize("start", [None, 0, 3, "x"], ids=lambda x: f"start={x}")
32+
def test_slice(start, stop, step):
33+
x = ps.int64("x")
34+
35+
sym_slice = as_symbolic(
36+
slice(
37+
x if start == "x" else start,
38+
x if stop == "x" else stop,
39+
x if step == "x" else step,
40+
)
41+
)
42+
43+
no_opt_mode = Mode(linker="numba", optimizer=None)
44+
evaled_slice = sym_slice.eval({x: -5}, on_unused_input="ignore", mode=no_opt_mode)
45+
assert isinstance(evaled_slice, slice)
46+
if start == "x":
47+
assert evaled_slice.start == -5
48+
elif start is None and (evaled_slice.step is None or evaled_slice.step > 0):
49+
# Numba can convert to 0 (and sometimes does) in this case
50+
assert evaled_slice.start in (None, 0)
51+
else:
52+
assert evaled_slice.start == start
53+
54+
if stop == "x":
55+
assert evaled_slice.stop == -5
56+
else:
57+
assert evaled_slice.stop == stop
58+
59+
if step == "x":
60+
assert evaled_slice.step == -5
61+
elif step is None:
62+
# Numba can convert to 1 (and sometimes does) in this case
63+
assert evaled_slice.step in (None, 1)
64+
else:
65+
assert evaled_slice.step == step
66+
67+
2768
@pytest.mark.parametrize(
2869
"x, indices",
2970
[

0 commit comments

Comments
 (0)