Skip to content

Commit eea38f3

Browse files
committed
[IR] extract update_graph_outputs in a helper (#2294)
1 parent ac87a1c commit eea38f3

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

onnxscript/ir/_convenience/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,18 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
336336
return values
337337

338338

339+
def _update_graph_or_function_outputs(
340+
graph_or_function: _core.Graph | _core.Function,
341+
old_values: Sequence[_core.Value],
342+
new_values: Sequence[_core.Value],
343+
):
344+
"""Update graph/function outputs"""
345+
replacement_mapping = dict(zip(old_values, new_values))
346+
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
347+
if graph_or_function_output in replacement_mapping:
348+
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
349+
350+
339351
def replace_nodes_and_values(
340352
graph_or_function: _core.Graph | _core.Function,
341353
/,
@@ -368,10 +380,7 @@ def replace_nodes_and_values(
368380
# Reconnect the users of the deleted values to use the new values
369381
replace_all_uses_with(old_values, new_values)
370382
# Update graph/function outputs if the node generates output
371-
replacement_mapping = dict(zip(old_values, new_values))
372-
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
373-
if graph_or_function_output in replacement_mapping:
374-
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
383+
_update_graph_or_function_outputs(graph_or_function, old_values, new_values)
375384

376385
# insert new nodes after the index node
377386
graph_or_function.insert_after(insertion_point, new_nodes)

0 commit comments

Comments
 (0)