@@ -336,6 +336,18 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
336336 return values
337337
338338
339+ def _update_graph_or_function_outputs (
340+ graph_or_function : _core .Graph | _core .Function ,
341+ old_values : Sequence [_core .Value ],
342+ new_values : Sequence [_core .Value ],
343+ ):
344+ """Update graph/function outputs"""
345+ replacement_mapping = dict (zip (old_values , new_values ))
346+ for idx , graph_or_function_output in enumerate (graph_or_function .outputs ):
347+ if graph_or_function_output in replacement_mapping :
348+ graph_or_function .outputs [idx ] = replacement_mapping [graph_or_function_output ]
349+
350+
339351def replace_nodes_and_values (
340352 graph_or_function : _core .Graph | _core .Function ,
341353 / ,
@@ -368,10 +380,7 @@ def replace_nodes_and_values(
368380 # Reconnect the users of the deleted values to use the new values
369381 replace_all_uses_with (old_values , new_values )
370382 # Update graph/function outputs if the node generates output
371- replacement_mapping = dict (zip (old_values , new_values ))
372- for idx , graph_or_function_output in enumerate (graph_or_function .outputs ):
373- if graph_or_function_output in replacement_mapping :
374- graph_or_function .outputs [idx ] = replacement_mapping [graph_or_function_output ]
383+ _update_graph_or_function_outputs (graph_or_function , old_values , new_values )
375384
376385 # insert new nodes after the index node
377386 graph_or_function .insert_after (insertion_point , new_nodes )
0 commit comments