File tree Expand file tree Collapse file tree 1 file changed +20
-6
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 1 file changed +20
-6
lines changed Original file line number Diff line number Diff line change @@ -1295,12 +1295,26 @@ def local_inplace_setsubtensor(fgraph, node):
12951295
12961296@node_rewriter ([AdvancedIncSubtensor1 ], inplace = True )
12971297def local_inplace_AdvancedIncSubtensor1 (fgraph , node ):
1298- if isinstance (node .op , AdvancedIncSubtensor1 ) and not node .op .inplace :
1299- new_op = node .op .clone_inplace ()
1300- new_node = new_op (* node .inputs )
1301- copy_stack_trace (node .outputs , new_node )
1302- return [new_node ]
1303- return False
1298+ if node .op .inplace :
1299+ return
1300+
1301+ x , y , idx = node .inputs
1302+ if fgraph .has_destroyers ([x ]):
1303+ # In this case we can't operate inplace, but if x is just an alloc of zeros
1304+ # We're better off duplicating it and then acting on it inplace.
1305+ if (
1306+ x .owner is not None
1307+ and isinstance (x .owner .op , Alloc )
1308+ and x .owner .op .value_is_scalar_zero (x .owner .inputs [0 ])
1309+ ):
1310+ x = x .owner .clone ().outputs [0 ]
1311+ else :
1312+ return None # Inplace isn't valid
1313+
1314+ new_op = node .op .clone_inplace ()
1315+ new_node = new_op (x , y , idx )
1316+ copy_stack_trace (node .outputs , new_node )
1317+ return [new_node ]
13041318
13051319
13061320compile .optdb .register (
You can’t perform that action at this time.
0 commit comments