Skip to content

Commit a764fb1

Browse files
authored
Implement replace_all_uses_with on Value (#235)
Given it is a fundamental operation in graph transformation, I added it as a method on `Value` so that it is more convenient and discoverable. The ir.convenience.replace_all_uses_with function is preserved to handle batch processing cases on multiple values. Also added a `replace_graph_outputs` option to allow users to replace graph outputs as well. Users are still responsible to assign the output name to the replacement value if they want to preserve the signature. ### BC breaking An error is raised when `replace_graph_outputs` is False && when the value to replace is a graph output. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 74d9692 commit a764fb1

File tree

7 files changed

+86
-23
lines changed

7 files changed

+86
-23
lines changed

src/onnx_ir/_convenience/__init__.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def convert_attributes(
280280
def replace_all_uses_with(
281281
values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
282282
replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
283+
replace_graph_outputs: bool = False,
283284
) -> None:
284285
"""Replace all uses of the given values with the replacements.
285286
@@ -318,9 +319,22 @@ def replace_all_uses_with(
318319
replaced are part of the graph outputs. Be sure to remove the old nodes
319320
from the graph using ``graph.remove()`` if they are no longer needed.
320321
322+
.. versionadded:: 0.1.12
323+
The ``replace_graph_outputs`` parameter is added.
324+
325+
.. versionadded:: 0.1.12
326+
ValueError is raised when ``replace_graph_outputs`` is False && when the value to
327+
replace is a graph output.
328+
321329
Args:
322330
values: The value or values to be replaced.
323331
replacements: The new value or values to use as inputs.
332+
replace_graph_outputs: If True, graph outputs that reference the values
333+
being replaced will also be updated to reference the replacements.
334+
335+
Raises:
336+
ValueError: When ``replace_graph_outputs`` is False && when the value to
337+
replace is a graph output.
324338
"""
325339
if not isinstance(values, Sequence):
326340
values = (values,)
@@ -329,8 +343,7 @@ def replace_all_uses_with(
329343
if len(values) != len(replacements):
330344
raise ValueError("The number of values and replacements must match.")
331345
for value, replacement in zip(values, replacements):
332-
for user_node, index in tuple(value.uses()):
333-
user_node.replace_input_with(index, replacement)
346+
value.replace_all_uses_with(replacement, replace_graph_outputs=replace_graph_outputs)
334347

335348

336349
def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
@@ -408,12 +421,7 @@ def replace_nodes_and_values(
408421
new_value.name = old_value.name if old_value.name is not None else new_value.name
409422

410423
# Reconnect the users of the deleted values to use the new values
411-
replace_all_uses_with(old_values, new_values)
412-
# Update graph/function outputs if the node generates output
413-
replacement_mapping = dict(zip(old_values, new_values))
414-
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
415-
if graph_or_function_output in replacement_mapping:
416-
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
424+
replace_all_uses_with(old_values, new_values, replace_graph_outputs=True)
417425

418426
# insert new nodes after the index node
419427
graph_or_function.insert_after(insertion_point, new_nodes)

src/onnx_ir/_core.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2448,6 +2448,55 @@ def is_initializer(self) -> bool:
24482448
"""Whether the value is an initializer of a graph."""
24492449
return self._is_initializer
24502450

2451+
def replace_all_uses_with(
2452+
self, replacement: Value, /, replace_graph_outputs: bool = False
2453+
) -> None:
2454+
"""Replace all uses of this value with another value.
2455+
2456+
If the value is an output of a graph and ``replace_graph_outputs`` is ``True``,
2457+
the graph output will also be replaced. Be careful when a value appears multiple times
2458+
in the graph outputs - this is invalid. An identity node will need to be added on each
2459+
duplicated outputs to ensure a valid ONNX graph.
2460+
2461+
You may also want to assign the name of this value to the replacement value
2462+
to maintain the name when it is a graph output.
2463+
2464+
To replace usage of a sequence of values with another sequence of values, consider using
2465+
:func:`onnx_ir.convenience.replace_all_uses_with`.
2466+
2467+
.. versionadded:: 0.1.12
2468+
2469+
Args:
2470+
replacement: The value to replace all uses with.
2471+
replace_graph_outputs: If True, graph outputs that reference this value
2472+
will also be updated to reference the replacement.
2473+
2474+
Raises:
2475+
ValueError: When ``replace_graph_outputs`` is False && when the value to
2476+
replace is a graph output.
2477+
"""
2478+
# NOTE: Why we don't replace the value name when the value is an output:
2479+
# When the replacement value is already an output of the graph, renaming it
2480+
# to the name of this value will cause name conflicts. It is better to let
2481+
# the user handle the renaming explicitly and insert identity nodes if needed.
2482+
if self.is_graph_output():
2483+
graph = self.graph
2484+
assert graph is not None
2485+
2486+
if not replace_graph_outputs:
2487+
raise ValueError(
2488+
f"{self!r} is an output of graph {graph.name!r}. "
2489+
"Set replace_graph_outputs=True or replace the graph output frist before "
2490+
"calling replace_all_uses_with."
2491+
)
2492+
2493+
for i, output in enumerate(graph.outputs):
2494+
if output is self:
2495+
graph.outputs[i] = replacement
2496+
2497+
for user_node, index in self.uses():
2498+
user_node.replace_input_with(index, replacement)
2499+
24512500

24522501
@deprecated("Input is deprecated since 0.1.9. Use ir.val(...) instead.")
24532502
def Input(

src/onnx_ir/_protocols.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,17 @@ def is_graph_output(self) -> bool:
203203
"""Whether this value is an output of a graph."""
204204
...
205205

206+
def replace_all_uses_with(
207+
self, new_value: ValueProtocol | None, replace_graph_outputs: bool = False
208+
) -> None:
209+
"""Replace all uses of this value with the given new value.
210+
211+
Args:
212+
new_value: The new value to replace this value with.
213+
replace_graph_outputs: Whether to replace graph outputs that use this value.
214+
"""
215+
...
216+
206217

207218
@typing.runtime_checkable
208219
class NodeProtocol(Protocol):

src/onnx_ir/passes/common/common_subexpression_elimination.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,6 @@ def _remove_node_and_replace_values(
150150
remove_values: The values to replace.
151151
new_values: The values to replace with.
152152
"""
153-
# Reconnect the users of the deleted values to use the new values
154-
ir.convenience.replace_all_uses_with(remove_values, new_values)
155153
# Update graph/function outputs if the node generates output
156154
if any(remove_value.is_graph_output() for remove_value in remove_values):
157155
replacement_mapping = dict(zip(remove_values, new_values))
@@ -185,6 +183,9 @@ def _remove_node_and_replace_values(
185183
new_value.name = graph_output.name
186184
graph.outputs[idx] = new_value
187185

186+
# Reconnect the users of the deleted values to use the new values
187+
ir.convenience.replace_all_uses_with(remove_values, new_values)
188+
188189
graph.remove(remove_node, safe=True)
189190

190191

src/onnx_ir/passes/common/constant_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
7878
assert node.graph is not None
7979
node.graph.register_initializer(initializer)
8080
# Replace the constant node with the initializer
81-
ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
81+
node.outputs[0].replace_all_uses_with(initializer)
8282
node.graph.remove(node, safe=True)
8383
count += 1
8484
logger.debug(

src/onnx_ir/passes/common/identity_elimination.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,20 +105,14 @@ def _try_eliminate_identity_node(self, node: ir.Node) -> bool:
105105

106106
# Case 1 & 2 (merged): Eliminate the identity node
107107
# Replace all uses of output with input
108-
ir.convenience.replace_all_uses_with(output_value, input_value)
108+
ir.convenience.replace_all_uses_with(
109+
output_value, input_value, replace_graph_outputs=True
110+
)
109111

110112
# If output is a graph output, we need to rename input and update graph outputs
111113
if output_is_graph_output:
112-
# Store the original output name
113-
original_output_name = output_value.name
114-
115114
# Update the input value to have the output's name
116-
input_value.name = original_output_name
117-
118-
# Update graph outputs to point to the input value
119-
for idx, graph_output in enumerate(graph_like.outputs):
120-
if graph_output is output_value:
121-
graph_like.outputs[idx] = input_value
115+
input_value.name = output_value.name
122116

123117
# Remove the identity node
124118
graph_like.remove(node, safe=True)

src/onnx_ir/passes/common/initializer_deduplication.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
100100
if key in initializers:
101101
modified = True
102102
initializer_to_keep = initializers[key] # type: ignore[index]
103-
ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
103+
initializer.replace_all_uses_with(initializer_to_keep)
104104
assert initializer.name is not None
105105
graph.initializers.pop(initializer.name)
106106
logger.info(
@@ -165,7 +165,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
165165
continue
166166
modified = True
167167
initializer_to_keep = initializers[key] # type: ignore[index]
168-
ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
168+
initializer.replace_all_uses_with(initializer_to_keep)
169169
assert initializer.name is not None
170170
graph.initializers.pop(initializer.name)
171171
logger.info(

0 commit comments

Comments
 (0)