Skip to content

Commit aab601f

Browse files
Copilotjustinchuby
andauthored
Create a name fix pass to ensure unique names for all values and nodes (#124)
This PR implements a `NameFixPass` in `src/onnx_ir/passes/common/naming.py` to ensure that all values and nodes in an ONNX IR model have unique names according to the specified policy: 1. **Graph inputs and outputs have precedence** - Their names are preserved when possible, and only renamed if duplicates exist 1. Scoped unique names: Names in subgraphs are unique within their naming scopes and their parent scopes. 1. Subgraph inputs/outputs are processed 1. Support custom naming generators with generate_node_name and generate_value_name params. I create the `enter_graph` and `exit_graph` callbacks on the RecursiveIterator to allow users to handle getting into and out of graph scopes. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 5941e43 commit aab601f

File tree

4 files changed

+728
-0
lines changed

4 files changed

+728
-0
lines changed

src/onnx_ir/passes/common/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"InlinePass",
1212
"LiftConstantsToInitializersPass",
1313
"LiftSubgraphInitializersToMainGraphPass",
14+
"NameFixPass",
1415
"RemoveInitializersFromInputsPass",
1516
"RemoveUnusedFunctionsPass",
1617
"RemoveUnusedNodesPass",
@@ -38,6 +39,7 @@
3839
DeduplicateInitializersPass,
3940
)
4041
from onnx_ir.passes.common.inliner import InlinePass
42+
from onnx_ir.passes.common.naming import NameFixPass
4143
from onnx_ir.passes.common.onnx_checker import CheckerPass
4244
from onnx_ir.passes.common.shape_inference import ShapeInferencePass
4345
from onnx_ir.passes.common.topological_sort import TopologicalSortPass
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
# Copyright (c) ONNX Project Contributors
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Name fix pass for ensuring unique names for all values and nodes."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"NameFixPass",
9+
"NameGenerator",
10+
"SimpleNameGenerator",
11+
]
12+
13+
import collections
14+
import logging
15+
from typing import Protocol
16+
17+
import onnx_ir as ir
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
class NameGenerator(Protocol):
23+
def generate_node_name(self, node: ir.Node) -> str:
24+
"""Generate a preferred name for a node."""
25+
...
26+
27+
def generate_value_name(self, value: ir.Value) -> str:
28+
"""Generate a preferred name for a value."""
29+
...
30+
31+
32+
class SimpleNameGenerator(NameGenerator):
33+
"""Base class for name generation functions."""
34+
35+
def generate_node_name(self, node: ir.Node) -> str:
36+
"""Generate a preferred name for a node."""
37+
return node.name or "node"
38+
39+
def generate_value_name(self, value: ir.Value) -> str:
40+
"""Generate a preferred name for a value."""
41+
return value.name or "v"
42+
43+
44+
class NameFixPass(ir.passes.InPlacePass):
45+
"""Pass for fixing names to ensure all values and nodes have unique names.
46+
47+
This pass ensures that:
48+
1. Graph inputs and outputs have unique names (take precedence)
49+
2. All intermediate values have unique names (assign names to unnamed values)
50+
3. All values in subgraphs have unique names within their graph and parent graphs
51+
4. All nodes have unique names within their graph
52+
53+
The pass maintains global uniqueness across the entire model.
54+
55+
You can customize the name generation functions for nodes and values by passing
56+
a subclass of :class:`NameGenerator`.
57+
58+
For example, you can use a custom naming scheme like this::
59+
60+
class CustomNameGenerator:
61+
def custom_node_name(node: ir.Node) -> str:
62+
return f"custom_node_{node.op_type}"
63+
64+
def custom_value_name(value: ir.Value) -> str:
65+
return f"custom_value_{value.type}"
66+
67+
name_fix_pass = NameFixPass(nameGenerator=CustomNameGenerator())
68+
69+
.. versionadded:: 0.1.6
70+
"""
71+
72+
def __init__(
73+
self,
74+
name_generator: NameGenerator | None = None,
75+
) -> None:
76+
"""Initialize the NameFixPass with custom name generation functions.
77+
78+
Args:
79+
name_generator (NameGenerator, optional): An instance of a subclass of
80+
:class:`NameGenerator` to customize name generation for nodes and values.
81+
If not provided, defaults to a basic implementation that uses
82+
the node's or value's existing name or a generic name like "node" or "v".
83+
"""
84+
super().__init__()
85+
self._name_generator = name_generator or SimpleNameGenerator()
86+
87+
def call(self, model: ir.Model) -> ir.passes.PassResult:
88+
# Process the main graph
89+
modified = self._fix_graph_names(model.graph)
90+
91+
# Process functions
92+
for function in model.functions.values():
93+
modified = self._fix_graph_names(function) or modified
94+
95+
return ir.passes.PassResult(model, modified=modified)
96+
97+
def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool:
98+
"""Fix names in a graph and return whether modifications were made."""
99+
modified = False
100+
101+
# Set to track which values have been assigned names
102+
seen_values: set[ir.Value] = set()
103+
104+
# The first set is a dummy placeholder so that there is always a [-1] scope for access
105+
# (even though we don't write to it)
106+
scoped_used_value_names: list[set[str]] = [set()]
107+
scoped_used_node_names: list[set[str]] = [set()]
108+
109+
# Counters for generating unique names (using list to pass by reference)
110+
value_counter = collections.Counter()
111+
node_counter = collections.Counter()
112+
113+
def enter_graph(graph_like) -> None:
114+
"""Callback for entering a subgraph."""
115+
# Initialize new scopes with all names from the parent scope
116+
scoped_used_value_names.append(set(scoped_used_value_names[-1]))
117+
scoped_used_node_names.append(set())
118+
119+
nonlocal modified
120+
121+
# Step 1: Fix graph input names first (they have precedence)
122+
for input_value in graph_like.inputs:
123+
if self._process_value(
124+
input_value, scoped_used_value_names[-1], seen_values, value_counter
125+
):
126+
modified = True
127+
128+
# Step 2: Fix graph output names (they have precedence)
129+
for output_value in graph_like.outputs:
130+
if self._process_value(
131+
output_value, scoped_used_value_names[-1], seen_values, value_counter
132+
):
133+
modified = True
134+
135+
if isinstance(graph_like, ir.Graph):
136+
# For graphs, also fix initializers
137+
for initializer in graph_like.initializers.values():
138+
if self._process_value(
139+
initializer, scoped_used_value_names[-1], seen_values, value_counter
140+
):
141+
modified = True
142+
143+
def exit_graph(_) -> None:
144+
"""Callback for exiting a subgraph."""
145+
# Pop the current scope
146+
scoped_used_value_names.pop()
147+
scoped_used_node_names.pop()
148+
149+
# Step 3: Process all nodes and their values
150+
for node in ir.traversal.RecursiveGraphIterator(
151+
graph_like, enter_graph=enter_graph, exit_graph=exit_graph
152+
):
153+
# Fix node name
154+
if not node.name:
155+
if self._assign_node_name(node, scoped_used_node_names[-1], node_counter):
156+
modified = True
157+
else:
158+
if self._fix_duplicate_node_name(
159+
node, scoped_used_node_names[-1], node_counter
160+
):
161+
modified = True
162+
163+
# Fix input value names (only if not already processed)
164+
for input_value in node.inputs:
165+
if input_value is not None:
166+
if self._process_value(
167+
input_value, scoped_used_value_names[-1], seen_values, value_counter
168+
):
169+
modified = True
170+
171+
# Fix output value names (only if not already processed)
172+
for output_value in node.outputs:
173+
if self._process_value(
174+
output_value, scoped_used_value_names[-1], seen_values, value_counter
175+
):
176+
modified = True
177+
178+
return modified
179+
180+
def _process_value(
181+
self,
182+
value: ir.Value,
183+
used_value_names: set[str],
184+
seen_values: set[ir.Value],
185+
value_counter: collections.Counter,
186+
) -> bool:
187+
"""Process a value only if it hasn't been processed before."""
188+
if value in seen_values:
189+
return False
190+
191+
modified = False
192+
193+
if not value.name:
194+
modified = self._assign_value_name(value, used_value_names, value_counter)
195+
else:
196+
old_name = value.name
197+
modified = self._fix_duplicate_value_name(value, used_value_names, value_counter)
198+
if modified:
199+
assert value.graph is not None
200+
if value.is_initializer():
201+
value.graph.initializers.pop(old_name)
202+
# Add the initializer back with the new name
203+
value.graph.initializers.add(value)
204+
205+
# Record the final name for this value
206+
assert value.name is not None
207+
seen_values.add(value)
208+
return modified
209+
210+
def _assign_value_name(
211+
self, value: ir.Value, used_names: set[str], counter: collections.Counter
212+
) -> bool:
213+
"""Assign a name to an unnamed value. Returns True if modified."""
214+
assert not value.name, (
215+
"value should not have a name already if function is called correctly"
216+
)
217+
218+
preferred_name = self._name_generator.generate_value_name(value)
219+
value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
220+
logger.debug("Assigned name %s to unnamed value", value.name)
221+
return True
222+
223+
def _assign_node_name(
224+
self, node: ir.Node, used_names: set[str], counter: collections.Counter
225+
) -> bool:
226+
"""Assign a name to an unnamed node. Returns True if modified."""
227+
assert not node.name, (
228+
"node should not have a name already if function is called correctly"
229+
)
230+
231+
preferred_name = self._name_generator.generate_node_name(node)
232+
node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
233+
logger.debug("Assigned name %s to unnamed node", node.name)
234+
return True
235+
236+
def _fix_duplicate_value_name(
237+
self, value: ir.Value, used_names: set[str], counter: collections.Counter
238+
) -> bool:
239+
"""Fix a value's name if it conflicts with existing names. Returns True if modified."""
240+
original_name = value.name
241+
242+
assert original_name, (
243+
"value should have a name already if function is called correctly"
244+
)
245+
246+
if original_name not in used_names:
247+
# Name is unique, just record it
248+
used_names.add(original_name)
249+
return False
250+
251+
# If name is already used, make it unique
252+
base_name = self._name_generator.generate_value_name(value)
253+
value.name = _find_and_record_next_unique_name(base_name, used_names, counter)
254+
logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name)
255+
return True
256+
257+
def _fix_duplicate_node_name(
258+
self, node: ir.Node, used_names: set[str], counter: collections.Counter
259+
) -> bool:
260+
"""Fix a node's name if it conflicts with existing names. Returns True if modified."""
261+
original_name = node.name
262+
263+
assert original_name, "node should have a name already if function is called correctly"
264+
265+
if original_name not in used_names:
266+
# Name is unique, just record it
267+
used_names.add(original_name)
268+
return False
269+
270+
# If name is already used, make it unique
271+
base_name = self._name_generator.generate_node_name(node)
272+
node.name = _find_and_record_next_unique_name(base_name, used_names, counter)
273+
logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name)
274+
return True
275+
276+
277+
def _find_and_record_next_unique_name(
278+
preferred_name: str, used_names: set[str], counter: collections.Counter
279+
) -> str:
280+
"""Generate a unique name based on the preferred name and current counter."""
281+
new_name = preferred_name
282+
while new_name in used_names:
283+
counter[preferred_name] += 1
284+
new_name = f"{preferred_name}_{counter[preferred_name]}"
285+
used_names.add(new_name)
286+
return new_name

0 commit comments

Comments
 (0)