 from pytensor import Variable
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.rewriting.basic import copy_stack_trace
-from pytensor.npy_2_compat import normalize_axis_tuple
+from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
 from pytensor.tensor.basic import (
     Alloc,
     SpecifyShape,
     specify_shape,
 )
+from pytensor.tensor.special import Softmax, softmax
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor1,
     Subtensor,
@@ -51,6 +52,20 @@ def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]
     return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))


+def _ndim_dropped_left_of_axis_by_basic_index(
+    idxs: Sequence[slice | int], axis: int
+) -> int:
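+    """Return how many dimensions the integer indices in idxs[:axis] drop."""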
+    return len(_dims_dropped_by_basic_index(idxs[:axis]))
+
+
+def _axis_is_indexed_by_basic_index(
+    idxs: Sequence[slice | int], axis: int | Sequence[int]
+) -> bool:
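+    """Return whether any of the given axes is affected by the basic index."""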
+    if isinstance(axis, int):
+        axis = (axis,)
+    return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
@@ -241,6 +256,84 @@ def local_subtensor_of_reduce(fgraph, node):
     return [out]


+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_softmax(fgraph, node):
+    """Lift a Subtensor through a Softmax.
+
+    softmax(x, axis=1)[0] -> softmax(x[0], axis=0)
+    softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1)
+
+    If part of the indexing acts on the axis of reduction, we split it:
+    softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[:, 0]
+
+    """
+    sm, *idx = node.inputs
+
+    if not (sm.owner and isinstance(sm.owner.op, Softmax)):
+        return None
+
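+    # Only lift when this Subtensor is the sole client of the softmax;
+    # otherwise the full softmax has to be computed anyway.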
+    if len(fgraph.clients[sm]) > 1:
+        return None
+
+    [x] = sm.owner.inputs
+    axis = sm.owner.op.axis
+
+    if axis is None:
+        if x.type.ndim == 1:
+            axis = 0
+        else:
+            # All dimensions are mixed, we can't lift the subtensor
+            return None
+    else:
+        # Softmax currently only allows None or a single integer axis
+        # Unlike CAReduce it does not normalize negative indices
+        axis = normalize_axis_index(axis, sm.ndim)
+
+    [old_out] = node.outputs
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        # If more dimensions are indexed, we can split the subtensor:
+        # lift the non-axis indices and keep only the axis index here
+        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+        if len(real_indices) > 1 and sm.type.ndim > 1:
+            # Split the subtensor
+            idx_to_keep = idx_tuple[axis]
+            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+            # Lift the non-axis indexes by calling the rewrite itself
+            opt_sm = sm[idxs_to_lift]
+            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
+            copy_stack_trace([old_out, sm], opt_sm)
+
+            # Then reintroduce the axis index
+            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
+                idx_tuple, axis
+            )
+            new_axis = axis - ndim_reduced_left
+            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+            new_out = opt_sm[idxs_to_keep]
+            copy_stack_trace(old_out, new_out)
+            return [new_out]
+
+        else:
+            return None
+
+    # Index input to softmax
+    x_sub = x[idx_tuple]
+
+    # Adjust the reduction axis for dimensions dropped by the indexing
+    # (integer indexing, as opposed to slice indexing, drops a dimension)
+    axis -= len(
+        [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
+    )
+
+    out = softmax(x_sub, axis=axis)
+    copy_stack_trace(old_out, out)
+    return [out]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
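
A minimal usage sketch illustrating the docstring examples (an assumption of this note rather than part of the diff: it relies on pytensor's generic rewrite_graph entry point with the standard "canonicalize"/"specialize" passes):

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor.special import softmax

x = pt.matrix("x")
out = softmax(x, axis=1)[0]

# The rewrite lifts the subtensor through the softmax, so the rewritten
# graph is equivalent to softmax(x[0], axis=0).
rewritten = rewrite_graph(out, include=("canonicalize", "specialize"))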