From db1f921912d25dc84972b9c89eff778bbf479c1a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 29 May 2025 19:25:29 -0700 Subject: [PATCH 1/6] Create the convenience methods i() and o() on Node Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 48 +++++++++++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 339d7dd4..a04aaee2 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -1552,6 +1552,40 @@ def inputs(self, _: Any) -> None: "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead." ) + @property + def outputs(self) -> Sequence[Value]: + """The output values of the node. + + The outputs are immutable. To change the outputs, create a new node and + replace the inputs of the using nodes of this node's outputs by calling + :meth:`replace_input_with` on the using nodes of this node's outputs. + """ + return self._outputs + + @outputs.setter + def outputs(self, _: Sequence[Value]) -> None: + raise AttributeError("outputs is immutable. Please create a new node instead.") + + def i(self, index: int = 0) -> Value | None: + """Get the input value at the given index. + + This is a convenience method that is equivalent to `self.inputs[index]`. + + Raises: + IndexError: If the index is out of range. + """ + return self.inputs[index] + + def o(self, index: int = 0) -> Value: + """Get the output value at the given index. + + This is a convenience method that is equivalent to `self.outputs[index]`. + + Raises: + IndexError: If the index is out of range. + """ + return self.outputs[index] + def predecessors(self) -> Sequence[Node]: """Return the predecessor nodes of the node, deduplicated, in a deterministic order.""" # Use the ordered nature of a dictionary to deduplicate the nodes @@ -1622,20 +1656,6 @@ def append(self, /, nodes: Node | Iterable[Node]) -> None: raise ValueError("The node to append to does not belong to any graph.") self._graph.insert_after(self, nodes) - @property - def outputs(self) -> Sequence[Value]: - """The output values of the node. - - The outputs are immutable. To change the outputs, create a new node and - replace the inputs of the using nodes of this node's outputs by calling - :meth:`replace_input_with` on the using nodes of this node's outputs. - """ - return self._outputs - - @outputs.setter - def outputs(self, _: Sequence[Value]) -> None: - raise AttributeError("outputs is immutable. Please create a new node instead.") - @property def attributes(self) -> OrderedDict[str, Attr]: """The attributes of the node.""" From 528130bb50d53c3cfe0ef193ff647dc5ec321597 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 29 May 2025 19:29:04 -0700 Subject: [PATCH 2/6] Use the method Signed-off-by: Justin Chu --- src/onnx_ir/external_data_test.py | 4 +-- .../clear_metadata_and_docstring_test.py | 4 +-- .../passes/common/constant_manipulation.py | 6 ++--- .../common/constant_manipulation_test.py | 26 +++++++++---------- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/onnx_ir/external_data_test.py b/src/onnx_ir/external_data_test.py index f778b513..ffcd5a7a 100644 --- a/src/onnx_ir/external_data_test.py +++ b/src/onnx_ir/external_data_test.py @@ -173,13 +173,13 @@ def _simple_model(self) -> ir.Model: node_1 = ir.Node( "", "Op_1", - inputs=[node_0.outputs[0]], + inputs=[node_0.o()], num_outputs=1, name="node_1", ) graph = ir.Graph( inputs=node_0.inputs, # type: ignore - outputs=[node_1.outputs[0]], + outputs=[node_1.o()], initializers=[ ir.Value(name="tensor1", const_value=tensor1), ir.Value(name="tensor2", const_value=tensor2), diff --git a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py index 5463cbad..ece0f313 100644 --- a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py +++ b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py @@ -30,7 +30,7 @@ def test_pass_with_clear_metadata_and_docstring(self): ) mul_node = ir.node( "Mul", - inputs=[add_node.outputs[0], inputs[1]], + inputs=[add_node.o(), inputs[1]], num_outputs=1, metadata_props={"mul_key": "mul_value"}, doc_string="This is a Mul node", @@ -69,7 +69,7 @@ def test_pass_with_clear_metadata_and_docstring(self): ) sub_node = ir.node( "Sub", - inputs=[function.outputs[0], const_node.outputs[0]], + inputs=[function.o(), const_node.o()], num_outputs=1, metadata_props={"sub_key": "sub_value"}, doc_string="This is a Sub node", diff --git a/src/onnx_ir/passes/common/constant_manipulation.py b/src/onnx_ir/passes/common/constant_manipulation.py index 685017c0..0f1bd710 100644 --- a/src/onnx_ir/passes/common/constant_manipulation.py +++ b/src/onnx_ir/passes/common/constant_manipulation.py @@ -41,7 +41,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: assert node.graph is not None if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"): continue - if node.outputs[0].is_graph_output(): + if node.o().is_graph_output(): logger.debug( "Constant node '%s' is used as output, so it can't be lifted.", node.name ) @@ -54,7 +54,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: continue attr_name, attr_value = next(iter(node.attributes.items())) - initializer_name = node.outputs[0].name + initializer_name = node.o().name assert initializer_name is not None assert isinstance(attr_value, ir.Attr) tensor = self._constant_node_attribute_to_tensor( @@ -73,7 +73,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: assert node.graph is not None node.graph.register_initializer(initializer) # Replace the constant node with the initializer - ir.convenience.replace_all_uses_with(node.outputs[0], initializer) + ir.convenience.replace_all_uses_with(node.o(), initializer) node.graph.remove(node, safe=True) count += 1 logger.debug( diff --git a/src/onnx_ir/passes/common/constant_manipulation_test.py b/src/onnx_ir/passes/common/constant_manipulation_test.py index afb528a8..d22d5266 100644 --- a/src/onnx_ir/passes/common/constant_manipulation_test.py +++ b/src/onnx_ir/passes/common/constant_manipulation_test.py @@ -36,8 +36,8 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers( const_node = ir.node( "Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1 ) - add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]]) - mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]]) + add_node = ir.node("Add", inputs=[inputs[0], const_node.o()]) + mul_node = ir.node("Mul", inputs=[add_node.o(), inputs[1]]) model = ir.Model( graph=ir.Graph( @@ -92,10 +92,10 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph( ) # then branch adds the constant to the input # else branch multiplies the input by the constant - add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]]) + add_node = ir.node("Add", inputs=[input_value, then_const_node.o()]) then_graph = ir.Graph( inputs=[], - outputs=[add_node.outputs[0]], + outputs=[add_node.o()], nodes=[then_const_node, add_node], opset_imports={"": 20}, ) @@ -103,10 +103,10 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph( else_const_node = ir.node( "Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1 ) - mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]]) + mul_node = ir.node("Mul", inputs=[input_value, else_const_node.o()]) else_graph = ir.Graph( inputs=[], - outputs=[mul_node.outputs[0]], + outputs=[mul_node.o()], nodes=[else_const_node, mul_node], opset_imports={"": 20}, ) @@ -179,14 +179,14 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings( num_outputs=1, ) identity_node_constant = ir.node( - "Identity", inputs=[const_node.outputs[0]], num_outputs=1 + "Identity", inputs=[const_node.o()], num_outputs=1 ) identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1) model = ir.Model( graph=ir.Graph( inputs=[input_value], - outputs=[identity_node_input.outputs[0], identity_node_constant.outputs[0]], + outputs=[identity_node_input.o(), identity_node_constant.o()], nodes=[identity_node_input, const_node, identity_node_constant], opset_imports={"": 20}, ), @@ -232,7 +232,7 @@ def test_not_lifting_constants_to_initializers_when_it_is_output(self): model = ir.Model( graph=ir.Graph( inputs=[input_value], - outputs=[identity_node_input.outputs[0], const_node.outputs[0]], + outputs=[identity_node_input.o(), const_node.o()], nodes=[identity_node_input, const_node], opset_imports={"": 20}, ), @@ -272,7 +272,7 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph( add_node = ir.node("Add", inputs=[input_value, then_initializer_value]) then_graph = ir.Graph( inputs=[], - outputs=[add_node.outputs[0]], + outputs=[add_node.o()], nodes=[add_node], opset_imports={"": 20}, initializers=[then_initializer_value], @@ -287,7 +287,7 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph( mul_node = ir.node("Mul", inputs=[input_value, else_initializer_value]) else_graph = ir.Graph( inputs=[], - outputs=[mul_node.outputs[0]], + outputs=[mul_node.o()], nodes=[mul_node], opset_imports={"": 20}, initializers=[else_initializer_value], @@ -358,7 +358,7 @@ def test_pass_does_not_lift_initialized_inputs_in_subgraph( # The initializer is also an input. We don't lift it to the main graph # to preserve the graph signature inputs=[then_initializer_value], - outputs=[add_node.outputs[0]], + outputs=[add_node.o()], nodes=[add_node], opset_imports={"": 20}, initializers=[then_initializer_value], @@ -373,7 +373,7 @@ def test_pass_does_not_lift_initialized_inputs_in_subgraph( mul_node = ir.node("Mul", inputs=[input_value, else_initializer_value]) else_graph = ir.Graph( inputs=[], - outputs=[mul_node.outputs[0]], + outputs=[mul_node.o()], nodes=[mul_node], opset_imports={"": 20}, initializers=[else_initializer_value], From 48b0815d19c96fe1270aac46d94c521b27f3bd1c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 29 May 2025 19:29:37 -0700 Subject: [PATCH 3/6] lint Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/constant_manipulation_test.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/onnx_ir/passes/common/constant_manipulation_test.py b/src/onnx_ir/passes/common/constant_manipulation_test.py index d22d5266..2a1dece5 100644 --- a/src/onnx_ir/passes/common/constant_manipulation_test.py +++ b/src/onnx_ir/passes/common/constant_manipulation_test.py @@ -178,9 +178,7 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings( attributes={constant_attribute: constant_value}, num_outputs=1, ) - identity_node_constant = ir.node( - "Identity", inputs=[const_node.o()], num_outputs=1 - ) + identity_node_constant = ir.node("Identity", inputs=[const_node.o()], num_outputs=1) identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1) model = ir.Model( From 644f67d7718972e25c882b8771c6ab440ed18577 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 29 May 2025 19:34:10 -0700 Subject: [PATCH 4/6] todo Signed-off-by: Justin Chu --- .../passes/common/clear_metadata_and_docstring_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py index ece0f313..5a8c1c9c 100644 --- a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py +++ b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py @@ -57,6 +57,14 @@ def test_pass_with_clear_metadata_and_docstring(self): domain="my_domain", attributes=[], ) + func_node = ir.node( + "my_function", + inputs=[add_node.o(), inputs[1]], + domain = "my_domain", + metadata_props={"mul_key": "mul_value"}, + doc_string="This is a Mul node", + ) + # TODO(justinchuby): This graph is broken. The output of the function cannot be a input to a node # Create a model with the graph and function constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir.DataType.FLOAT.numpy())) const_node = ir.node( From 86e70541b74b58f14e5911288a29abf8ca2e9e5a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 30 May 2025 10:32:29 -0700 Subject: [PATCH 5/6] Fix test Signed-off-by: Justin Chu --- .../clear_metadata_and_docstring_test.py | 44 +++++++++++++------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py index 5a8c1c9c..1fd5a1a4 100644 --- a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py +++ b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py @@ -35,20 +35,36 @@ def test_pass_with_clear_metadata_and_docstring(self): metadata_props={"mul_key": "mul_value"}, doc_string="This is a Mul node", ) - func_inputs = [ - ir.Value( - name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ), - ir.Value( - name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ), - ] function = ir.Function( graph=ir.Graph( name="my_function", - inputs=func_inputs, - outputs=mul_node.outputs, - nodes=[add_node, mul_node], + inputs=[ + input_a := ir.Value( + name="input_a", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape((2, 3)), + ), + input_b := ir.Value( + name="input_b", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape((2, 3)), + ), + ], + nodes=[ + add_node_func := ir.node( + "Add", + inputs=[input_a, input_b], + metadata_props={"add_key": "add_value"}, + doc_string="This is an Add node", + ), + mul_node_func := ir.node( + "Mul", + inputs=[add_node_func.o(), input_b], + metadata_props={"mul_key": "mul_value"}, + doc_string="This is a Mul node", + ), + ], + outputs=mul_node_func.outputs, opset_imports={"": 20}, doc_string="This is a function docstring", metadata_props={"function_key": "function_value"}, @@ -59,8 +75,8 @@ def test_pass_with_clear_metadata_and_docstring(self): ) func_node = ir.node( "my_function", - inputs=[add_node.o(), inputs[1]], - domain = "my_domain", + inputs=[inputs[0], mul_node.o()], + domain="my_domain", metadata_props={"mul_key": "mul_value"}, doc_string="This is a Mul node", ) @@ -77,7 +93,7 @@ def test_pass_with_clear_metadata_and_docstring(self): ) sub_node = ir.node( "Sub", - inputs=[function.o(), const_node.o()], + inputs=[func_node.o(), const_node.o()], num_outputs=1, metadata_props={"sub_key": "sub_value"}, doc_string="This is a Sub node", From cfd679bee1d282e47c8011e0dc15bf36efda6d5b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 3 Jun 2025 09:52:06 -0700 Subject: [PATCH 6/6] docs Signed-off-by: Justin Chu --- src/onnx_ir/_core.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 8c039b04..12f8be04 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -1563,7 +1563,15 @@ def outputs(self, _: Sequence[Value]) -> None: def i(self, index: int = 0) -> Value | None: """Get the input value at the given index. - This is a convenience method that is equivalent to `self.inputs[index]`. + This is a convenience method that is equivalent to ``node.inputs[index]``. + + The following is equivalent:: + + node.inputs[0] == node.i(0) == node.i() # Default index is 0 + node.inputs[index] == node.i(index) + + Returns: + The input value at the given index. Raises: IndexError: If the index is out of range. @@ -1573,7 +1581,15 @@ def i(self, index: int = 0) -> Value | None: def o(self, index: int = 0) -> Value: """Get the output value at the given index. - This is a convenience method that is equivalent to `self.outputs[index]`. + This is a convenience method that is equivalent to ``node.outputs[index]``. + + The following is equivalent:: + + node.outputs[0] == node.o(0) == node.o() # Default index is 0 + node.outputs[index] == node.o(index) + + Returns: + The output value at the given index. Raises: IndexError: If the index is out of range.