1616from pytensor .tensor .type_other import MakeSlice
1717
1818
19- def normalize_indices_for_mlx (ilist , idx_list ):
19+ def normalize_indices_for_mlx (indices ):
2020 """Convert numpy integers to Python integers for MLX indexing.
2121
2222 MLX requires index values to be Python int, not np.int64 or other NumPy types.
@@ -49,18 +49,19 @@ def normalize_element(element):
4949 else :
5050 return element
5151
52- indices = indices_from_subtensor (ilist , idx_list )
5352 return tuple (normalize_element (idx ) for idx in indices )
5453
5554
5655@mlx_funcify .register (Subtensor )
5756def mlx_funcify_Subtensor (op , node , ** kwargs ):
5857 """MLX implementation of Subtensor."""
59- idx_list = getattr ( op , " idx_list" , None )
58+ idx_list = op . idx_list
6059
6160 def subtensor (x , * ilists ):
61+ # Convert ilist to indices using idx_list (basic subtensor)
62+ indices = indices_from_subtensor (ilists , idx_list )
6263 # Normalize indices to handle np.int64 and other NumPy types
63- indices = normalize_indices_for_mlx (ilists , idx_list )
64+ indices = normalize_indices_for_mlx (indices )
6465 if len (indices ) == 1 :
6566 indices = indices [0 ]
6667
@@ -73,11 +74,11 @@ def subtensor(x, *ilists):
7374@mlx_funcify .register (AdvancedSubtensor1 )
7475def mlx_funcify_AdvancedSubtensor (op , node , ** kwargs ):
7576 """MLX implementation of AdvancedSubtensor."""
76- idx_list = getattr (op , "idx_list" , None )
7777
7878 def advanced_subtensor (x , * ilists ):
7979 # Normalize indices to handle np.int64 and other NumPy types
80- indices = normalize_indices_for_mlx (ilists , idx_list )
80+ # Advanced indexing doesn't use idx_list or indices_from_subtensor
81+ indices = normalize_indices_for_mlx (ilists )
8182 if len (indices ) == 1 :
8283 indices = indices [0 ]
8384
@@ -87,12 +88,11 @@ def advanced_subtensor(x, *ilists):
8788
8889
8990@mlx_funcify .register (IncSubtensor )
90- @mlx_funcify .register (AdvancedIncSubtensor1 )
9191def mlx_funcify_IncSubtensor (op , node , ** kwargs ):
9292 """MLX implementation of IncSubtensor."""
93- idx_list = getattr ( op , " idx_list" , None )
93+ idx_list = op . idx_list
9494
95- if getattr ( op , " set_instead_of_inc" , False ) :
95+ if op . set_instead_of_inc :
9696
9797 def mlx_fn (x , indices , y ):
9898 if not op .inplace :
@@ -109,8 +109,10 @@ def mlx_fn(x, indices, y):
109109 return x
110110
111111 def incsubtensor (x , y , * ilist , mlx_fn = mlx_fn , idx_list = idx_list ):
112+ # Convert ilist to indices using idx_list (basic inc_subtensor)
113+ indices = indices_from_subtensor (ilist , idx_list )
112114 # Normalize indices to handle np.int64 and other NumPy types
113- indices = normalize_indices_for_mlx (ilist , idx_list )
115+ indices = normalize_indices_for_mlx (indices )
114116
115117 if len (indices ) == 1 :
116118 indices = indices [0 ]
@@ -121,11 +123,11 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
121123
122124
123125@mlx_funcify .register (AdvancedIncSubtensor )
126+ @mlx_funcify .register (AdvancedIncSubtensor1 )
124127def mlx_funcify_AdvancedIncSubtensor (op , node , ** kwargs ):
125128 """MLX implementation of AdvancedIncSubtensor."""
126- idx_list = getattr (op , "idx_list" , None )
127129
128- if getattr ( op , " set_instead_of_inc" , False ) :
130+ if op . set_instead_of_inc :
129131
130132 def mlx_fn (x , indices , y ):
131133 if not op .inplace :
@@ -141,9 +143,10 @@ def mlx_fn(x, indices, y):
141143 x [indices ] += y
142144 return x
143145
144- def advancedincsubtensor (x , y , * ilist , mlx_fn = mlx_fn , idx_list = idx_list ):
146+ def advancedincsubtensor (x , y , * ilist , mlx_fn = mlx_fn ):
145147 # Normalize indices to handle np.int64 and other NumPy types
146- indices = normalize_indices_for_mlx (ilist , idx_list )
148+ # Advanced indexing doesn't use idx_list or indices_from_subtensor
149+ indices = normalize_indices_for_mlx (ilist )
147150
148151 # For advanced indexing, if we have a single tuple of indices, unwrap it
149152 if len (indices ) == 1 :
0 commit comments