Skip to content

Commit f04720d

Browse files
authored
[IR] Record owning graph for input/output/initializers (#2282)
Fix #1440 by pointing graph input, output and initializers back to the Graph using a tracked list. Users can now check if a value is a graph input/output/initializer, and find the owning graph of a value with `.graph`.
1 parent 8d98094 commit f04720d

File tree

9 files changed

+833
-69
lines changed

9 files changed

+833
-69
lines changed

onnxscript/ir/_convenience/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
321321
Returns:
322322
A dictionary mapping names to values.
323323
"""
324-
values = {}
324+
values: dict[str, _core.Value] = {}
325325
values.update(graph.initializers)
326326
# The names of the values can be None or "", which we need to exclude
327327
for input in graph.inputs:

onnxscript/ir/_core.py

Lines changed: 62 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
Generic,
3232
Iterable,
3333
Iterator,
34+
MutableMapping,
35+
MutableSequence,
3436
NamedTuple,
3537
OrderedDict,
3638
Sequence,
@@ -46,6 +48,7 @@
4648
from onnxscript.ir import (
4749
_display,
4850
_enums,
51+
_graph_containers,
4952
_linked_list,
5053
_metadata,
5154
_name_authority,
@@ -1746,18 +1749,19 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
17461749
17471750
To find all the nodes that use this value as an input, call :meth:`uses`.
17481751
1749-
To check if the value is an output of a graph, call :meth:`is_graph_output`.
1752+
To check if the value is an is an input, output or initializer of a graph,
1753+
use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.
17501754
1751-
Attributes:
1752-
name: The name of the value. A value is always named when it is part of a graph.
1753-
shape: The shape of the value.
1754-
type: The type of the value.
1755-
metadata_props: Metadata.
1755+
Use :meth:`graph` to get the graph that owns the value.
17561756
"""
17571757

17581758
__slots__ = (
17591759
"_const_value",
1760+
"_graph",
17601761
"_index",
1762+
"_is_graph_input",
1763+
"_is_graph_output",
1764+
"_is_initializer",
17611765
"_metadata",
17621766
"_metadata_props",
17631767
"_name",
@@ -1808,6 +1812,14 @@ def __init__(
18081812
self._uses: dict[Usage, None] = {}
18091813
self.doc_string = doc_string
18101814

1815+
# The graph this value belongs to. It is set *only* when the value is added as
1816+
# a graph input, output or initializer.
1817+
# The four properties can only be set by the Graph class (_GraphIO and GraphInitializers).
1818+
self._graph: Graph | None = None
1819+
self._is_graph_input: bool = False
1820+
self._is_graph_output: bool = False
1821+
self._is_initializer: bool = False
1822+
18111823
def __repr__(self) -> str:
18121824
value_name = self.name if self.name else "anonymous:" + str(id(self))
18131825
type_text = f", type={self.type!r}" if self.type is not None else ""
@@ -1846,11 +1858,35 @@ def _constant_tensor_part(self) -> str:
18461858
return f"{{{self.const_value.__class__.__name__}(...)}}"
18471859
return ""
18481860

1861+
@property
1862+
def graph(self) -> Graph | None:
1863+
"""Return the graph that defines this value.
1864+
1865+
When the value is an input/output/initializer of a graph, the owning graph
1866+
is that graph. When the value is an output of a node, the owning graph is the
1867+
graph that the node belongs to. When the value is not owned by any graph,
1868+
it returns ``None``.
1869+
"""
1870+
if self._graph is not None:
1871+
return self._graph
1872+
if self._producer is not None:
1873+
return self._producer.graph
1874+
return None
1875+
1876+
def _owned_by_graph(self) -> bool:
1877+
"""Return True if the value is owned by a graph."""
1878+
result = self._is_graph_input or self._is_graph_output or self._is_initializer
1879+
if result:
1880+
assert self._graph is not None
1881+
return result
1882+
18491883
def producer(self) -> Node | None:
18501884
"""The node that produces this value.
18511885
18521886
When producer is ``None``, the value does not belong to a node, and is
1853-
typically a graph input or an initializer.
1887+
typically a graph input or an initializer. You can use :meth:`graph``
1888+
to find the graph that owns this value. Use :meth:`is_graph_input`, :meth:`is_graph_output`
1889+
or :meth:`is_initializer` to check if the value is an input, output or initializer of a graph.
18541890
"""
18551891
return self._producer
18561892

@@ -1986,15 +2022,17 @@ def metadata_props(self) -> dict[str, str]:
19862022
self._metadata_props = {}
19872023
return self._metadata_props
19882024

2025+
def is_graph_input(self) -> bool:
2026+
"""Whether the value is an input of a graph."""
2027+
return self._is_graph_input
2028+
19892029
def is_graph_output(self) -> bool:
19902030
"""Whether the value is an output of a graph."""
1991-
if (producer := self.producer()) is None:
1992-
return False
1993-
if (graph := producer.graph) is None:
1994-
return False
1995-
# Cannot use `in` because __eq__ may be defined by subclasses, even though
1996-
# it is not recommended
1997-
return any(output is self for output in graph.outputs)
2031+
return self._is_graph_output
2032+
2033+
def is_initializer(self) -> bool:
2034+
"""Whether the value is an initializer of a graph."""
2035+
return self._is_initializer
19982036

19992037

20002038
def Input(
@@ -2104,9 +2142,9 @@ def __init__(
21042142
self.name = name
21052143

21062144
# Private fields that are not to be accessed by any other classes
2107-
self._inputs = list(inputs)
2108-
self._outputs = list(outputs)
2109-
self._initializers = {}
2145+
self._inputs = _graph_containers.GraphInputs(self, inputs)
2146+
self._outputs = _graph_containers.GraphOutputs(self, outputs)
2147+
self._initializers = _graph_containers.GraphInitializers(self)
21102148
for initializer in initializers:
21112149
if isinstance(initializer, str):
21122150
raise TypeError(
@@ -2131,15 +2169,15 @@ def __init__(
21312169
self.extend(nodes)
21322170

21332171
@property
2134-
def inputs(self) -> list[Value]:
2172+
def inputs(self) -> MutableSequence[Value]:
21352173
return self._inputs
21362174

21372175
@property
2138-
def outputs(self) -> list[Value]:
2176+
def outputs(self) -> MutableSequence[Value]:
21392177
return self._outputs
21402178

21412179
@property
2142-
def initializers(self) -> dict[str, Value]:
2180+
def initializers(self) -> MutableMapping[str, Value]:
21432181
return self._initializers
21442182

21452183
def register_initializer(self, value: Value) -> None:
@@ -2159,15 +2197,15 @@ def register_initializer(self, value: Value) -> None:
21592197
ValueError: If the initializer is produced by a node.
21602198
ValueError: If the value does not have its ``.const_value`` set.
21612199
"""
2200+
if not value.name:
2201+
raise ValueError(f"Initializer must have a name: {value!r}")
21622202
if value.name in self._initializers:
21632203
if self._initializers[value.name] is not value:
21642204
raise ValueError(
21652205
f"Initializer '{value.name}' is already registered, but"
21662206
" it is not the same object: existing={self._initializers[value.name]!r},"
21672207
f" new={value!r}"
21682208
)
2169-
if not value.name:
2170-
raise ValueError(f"Initializer must have a name: {value!r}")
21712209
if value.producer() is not None:
21722210
raise ValueError(
21732211
f"Value '{value!r}' is produced by a node and cannot be an initializer."
@@ -2858,11 +2896,11 @@ def overload(self, value: str) -> None:
28582896
self._overload = value
28592897

28602898
@property
2861-
def inputs(self) -> list[Value]:
2899+
def inputs(self) -> MutableSequence[Value]:
28622900
return self._graph.inputs
28632901

28642902
@property
2865-
def outputs(self) -> list[Value]:
2903+
def outputs(self) -> MutableSequence[Value]:
28662904
return self._graph.outputs
28672905

28682906
@property

0 commit comments

Comments
 (0)