Skip to content

Commit 449c2df

Browse files
committed
Refactor MLX subtensor dispatch and update test
Simplifies index normalization logic in MLX subtensor dispatch functions by separating basic and advanced indexing cases. Updates the advanced incsubtensor test to use vector array indices and a matching value shape for improved coverage.
1 parent cd7a2d0 commit 449c2df

File tree

2 files changed

+24
-18
lines changed

2 files changed

+24
-18
lines changed

pytensor/link/mlx/dispatch/subtensor.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pytensor.tensor.type_other import MakeSlice
1717

1818

19-
def normalize_indices_for_mlx(ilist, idx_list):
19+
def normalize_indices_for_mlx(indices):
2020
"""Convert numpy integers to Python integers for MLX indexing.
2121
2222
MLX requires index values to be Python int, not np.int64 or other NumPy types.
@@ -49,18 +49,19 @@ def normalize_element(element):
4949
else:
5050
return element
5151

52-
indices = indices_from_subtensor(ilist, idx_list)
5352
return tuple(normalize_element(idx) for idx in indices)
5453

5554

5655
@mlx_funcify.register(Subtensor)
5756
def mlx_funcify_Subtensor(op, node, **kwargs):
5857
"""MLX implementation of Subtensor."""
59-
idx_list = getattr(op, "idx_list", None)
58+
idx_list = op.idx_list
6059

6160
def subtensor(x, *ilists):
61+
# Convert ilist to indices using idx_list (basic subtensor)
62+
indices = indices_from_subtensor(ilists, idx_list)
6263
# Normalize indices to handle np.int64 and other NumPy types
63-
indices = normalize_indices_for_mlx(ilists, idx_list)
64+
indices = normalize_indices_for_mlx(indices)
6465
if len(indices) == 1:
6566
indices = indices[0]
6667

@@ -73,11 +74,11 @@ def subtensor(x, *ilists):
7374
@mlx_funcify.register(AdvancedSubtensor1)
7475
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
7576
"""MLX implementation of AdvancedSubtensor."""
76-
idx_list = getattr(op, "idx_list", None)
7777

7878
def advanced_subtensor(x, *ilists):
7979
# Normalize indices to handle np.int64 and other NumPy types
80-
indices = normalize_indices_for_mlx(ilists, idx_list)
80+
# Advanced indexing doesn't use idx_list or indices_from_subtensor
81+
indices = normalize_indices_for_mlx(ilists)
8182
if len(indices) == 1:
8283
indices = indices[0]
8384

@@ -87,12 +88,11 @@ def advanced_subtensor(x, *ilists):
8788

8889

8990
@mlx_funcify.register(IncSubtensor)
90-
@mlx_funcify.register(AdvancedIncSubtensor1)
9191
def mlx_funcify_IncSubtensor(op, node, **kwargs):
9292
"""MLX implementation of IncSubtensor."""
93-
idx_list = getattr(op, "idx_list", None)
93+
idx_list = op.idx_list
9494

95-
if getattr(op, "set_instead_of_inc", False):
95+
if op.set_instead_of_inc:
9696

9797
def mlx_fn(x, indices, y):
9898
if not op.inplace:
@@ -109,8 +109,10 @@ def mlx_fn(x, indices, y):
109109
return x
110110

111111
def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
112+
# Convert ilist to indices using idx_list (basic inc_subtensor)
113+
indices = indices_from_subtensor(ilist, idx_list)
112114
# Normalize indices to handle np.int64 and other NumPy types
113-
indices = normalize_indices_for_mlx(ilist, idx_list)
115+
indices = normalize_indices_for_mlx(indices)
114116

115117
if len(indices) == 1:
116118
indices = indices[0]
@@ -121,11 +123,11 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
121123

122124

123125
@mlx_funcify.register(AdvancedIncSubtensor)
126+
@mlx_funcify.register(AdvancedIncSubtensor1)
124127
def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
125128
"""MLX implementation of AdvancedIncSubtensor."""
126-
idx_list = getattr(op, "idx_list", None)
127129

128-
if getattr(op, "set_instead_of_inc", False):
130+
if op.set_instead_of_inc:
129131

130132
def mlx_fn(x, indices, y):
131133
if not op.inplace:
@@ -141,9 +143,10 @@ def mlx_fn(x, indices, y):
141143
x[indices] += y
142144
return x
143145

144-
def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
146+
def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn):
145147
# Normalize indices to handle np.int64 and other NumPy types
146-
indices = normalize_indices_for_mlx(ilist, idx_list)
148+
# Advanced indexing doesn't use idx_list or indices_from_subtensor
149+
indices = normalize_indices_for_mlx(ilist)
147150

148151
# For advanced indexing, if we have a single tuple of indices, unwrap it
149152
if len(indices) == 1:

tests/link/mlx/test_subtensor.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,13 +395,16 @@ def test_mlx_advanced_incsubtensor_with_numpy_int64():
395395
x_np = np.arange(15, dtype=np.float32).reshape((5, 3))
396396
x_pt = pt.constant(x_np)
397397

398-
# Value to set/increment
398+
# Value to set/increment - using 4 rows now for vector indexing
399399
y_pt = pt.as_tensor_variable(
400-
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
400+
np.array(
401+
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
402+
dtype=np.float32,
403+
)
401404
)
402405

403-
# Advanced indexing set with array indices
404-
indices = [np.int64(0), np.int64(2)]
406+
# Advanced indexing set with vector array indices
407+
indices = np.array([0, 1, 2, 3], dtype=np.int64)
405408
out_pt = pt_subtensor.set_subtensor(x_pt[indices], y_pt)
406409
compare_mlx_and_py([], [out_pt], [])
407410

0 commit comments

Comments
 (0)