Skip to content

Commit a18a032

Browse files
committed
[IR] introduce remove_nodes (#2294)
Convenience function to remove a set of nodes.
1 parent cb5942d commit a18a032

File tree

3 files changed

+168
-1
lines changed

3 files changed

+168
-1
lines changed

onnxscript/ir/_convenience/__init__.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"create_value_mapping",
1616
"replace_nodes_and_values",
1717
"insert_nodes_in_value",
18+
"remove_nodes",
1819
]
1920

2021
from typing import Mapping, Sequence, Union
@@ -487,3 +488,65 @@ def insert_nodes_in_value(
487488
# Insert new nodes if there is a graph
488489
graph.extend(new_nodes)
489490
graph.sort()
491+
492+
493+
def remove_nodes(nodes: Sequence[_core.Node]) -> None:
494+
"""Remove a sequence of nodes.
495+
496+
This allows to delete a list of LINKED nodes (over the same context).
497+
498+
For example, suppose we have the following graph::
499+
500+
input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output
501+
502+
We want to prune [node_B]::
503+
504+
>>> from onnxscript import ir
505+
>>> input = ir.Input("input")
506+
>>> node_A = ir.node("op_A", [input])
507+
>>> node_B = ir.node("op_B", node_A.outputs)
508+
>>> node_C = ir.node("op_C", node_B.outputs)
509+
>>> # Delete node_B
510+
>>> remove_nodes([node_B])
511+
>>> len(node_A.outputs[0].consumers())
512+
1
513+
>>> node_A.outputs[0].consumers()[0].op_type
514+
'op_C'
515+
>>> len(node_C.inputs)
516+
1
517+
>>> node_C.inputs[0].producer().op_type
518+
'op_A'
519+
>>> node_B.inputs
520+
(None,)
521+
>>> len(node_B.outputs)
522+
1
523+
>>> len(node_B.outputs[0].consumers())
524+
0
525+
526+
Args:
527+
nodes: The nodes to remove.
528+
"""
529+
# Search the unique inputs/outputs in new_nodes, keeping the order.
530+
inputs, outputs = _find_inputs_outputs(nodes)
531+
532+
# Sanity check.
533+
if len(inputs) != len(outputs):
534+
raise ValueError(
535+
f"The number of inputs ({inputs}) and outputs ({outputs}) in nodes must match."
536+
)
537+
538+
# Remove nodes, in several steps:
539+
# 1. Reconnect the users of outputs with inputs
540+
replace_all_uses_with(outputs, inputs)
541+
# 2. Detach nodes for their inputs
542+
for node in nodes:
543+
for i in range(len(node.inputs)):
544+
node.replace_input_with(i, None)
545+
546+
# Update graph if there is one:
547+
if (graph := inputs[-1].graph) is not None:
548+
# Update graph/function outputs if the node generates output
549+
_update_graph_or_function_outputs(graph, outputs, inputs)
550+
551+
# Drop nodes from graph
552+
graph.remove(nodes, safe=True)

onnxscript/ir/_convenience/_init_test.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import unittest
88

99
from onnxscript import ir
10-
from onnxscript.ir._convenience import insert_nodes_in_value
10+
from onnxscript.ir._convenience import insert_nodes_in_value, remove_nodes
1111

1212

1313
def _create_model(model_text: str) -> ir.Model:
@@ -141,6 +141,108 @@ def test_value_error_for_wrong_number_of_points(self):
141141
with self.assertRaisesRegex(ValueError, "The number of values and outputs"):
142142
insert_nodes_in_value(ir_model.graph[1].outputs, [node])
143143

144+
def test_remove_nodes(self):
145+
# Main graph
146+
input = ir.Input("input")
147+
node_A = ir.node("op_A", [input])
148+
node_B = ir.node("op_B", node_A.outputs)
149+
node_C = ir.node("op_C", node_B.outputs)
150+
151+
# Delete node_B
152+
remove_nodes([node_B])
153+
self.assertEqual(len(node_A.outputs[0].consumers()), 1)
154+
self.assertEqual(node_A.outputs[0].consumers()[0].op_type, "op_C")
155+
self.assertEqual(len(node_C.inputs), 1)
156+
self.assertEqual(node_C.inputs[0].producer().op_type, "op_A")
157+
158+
self.assertEqual((len(node_B.inputs), len(node_B.outputs)), (1, 1))
159+
self.assertEqual(node_B.inputs, (None,))
160+
self.assertEqual(len(node_B.outputs[0].consumers()), 0)
161+
162+
def test_remove_nodes_in_graph(self):
163+
ir_model = _create_model(
164+
"""
165+
<ir_version: 10, opset_import: [ "" : 17]>
166+
agraph (float[N] x) => (float[N] z) {
167+
two = Constant<value_float=2.0>()
168+
a, b = MergeAndSplit(x, two)
169+
z = MergeNode(a, b, two)
170+
}
171+
"""
172+
)
173+
# Sanity check previous to delete nodes
174+
x, two = ir_model.graph.inputs[0], ir_model.graph[0].outputs[0]
175+
self.assertEqual(len(x.consumers()), 1)
176+
self.assertEqual(len(two.consumers()), 2)
177+
178+
# Delete 'MergeAndSplit'
179+
target_node = ir_model.graph[1]
180+
remove_nodes([target_node])
181+
182+
# Check 'MergeNode' has new inputs
183+
a, b, _ = ir_model.graph[-1].inputs
184+
self.assertEqual(a.name, "x")
185+
self.assertEqual(b.name, "two")
186+
187+
# Check x/two consumers have been updated
188+
self.assertEqual(len(x.consumers()), 1)
189+
self.assertEqual(len(two.consumers()), 1)
190+
191+
# Check nodes have been deleted in the graph
192+
self.assertEqual(len(ir_model.graph), 2)
193+
194+
def test_remove_nodes_in_input(self):
195+
ir_model = _create_model(
196+
"""
197+
<ir_version: 10, opset_import: [ "" : 17]>
198+
agraph (float[N] x) => (float[N] z) {
199+
y = Sigmoid(x)
200+
z = Mul(y, y)
201+
}
202+
"""
203+
)
204+
# Remove the node linked to the input
205+
remove_nodes([ir_model.graph[0]])
206+
self.assertEqual(len(ir_model.graph), 1)
207+
self.assertEqual(ir_model.graph[0].op_type, "Mul")
208+
self.assertEqual(ir_model.graph[0].inputs[0].name, "x")
209+
self.assertEqual(ir_model.graph[0].inputs[1].name, "x")
210+
self.assertEqual(ir_model.graph[0].outputs[0].name, "z")
211+
212+
def test_remove_nodes_in_output(self):
213+
ir_model = _create_model(
214+
"""
215+
<ir_version: 10, opset_import: [ "" : 17]>
216+
agraph (float[N] x) => (float[N] z) {
217+
y = Mul(x, x)
218+
z = Sigmoid(y)
219+
}
220+
"""
221+
)
222+
# Remove the node linked to the input
223+
remove_nodes([ir_model.graph[-1]])
224+
self.assertEqual(len(ir_model.graph), 1)
225+
self.assertEqual(ir_model.graph[0].op_type, "Mul")
226+
self.assertEqual(ir_model.graph[0].outputs[0].name, "y")
227+
self.assertEqual(ir_model.graph.outputs[0].name, "y")
228+
229+
def test_remove_nodes_error_for_wrong_number_of_inputs_and_outputs(self):
230+
ir_model = _create_model(
231+
"""
232+
<ir_version: 10, opset_import: [ "" : 17]>
233+
agraph (float[N] x) => (float[N] z) {
234+
two = Constant<value_float=2.0>()
235+
a, b = SplitNode(x)
236+
z = MergeNode(a, b, two)
237+
}
238+
"""
239+
)
240+
with self.assertRaisesRegex(ValueError, "The number of inputs"):
241+
remove_nodes([ir_model.graph[0]])
242+
243+
with self.assertRaisesRegex(ValueError, "The number of inputs"):
244+
remove_nodes([ir_model.graph[1]])
245+
144246

145247
if __name__ == "__main__":
146248
unittest.main()

onnxscript/ir/convenience.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"replace_nodes_and_values",
1212
"create_value_mapping",
1313
"insert_nodes_in_value",
14+
"remove_nodes",
1415
]
1516

1617
from onnxscript.ir._convenience import (
@@ -20,6 +21,7 @@
2021
replace_all_uses_with,
2122
replace_nodes_and_values,
2223
insert_nodes_in_value,
24+
remove_nodes,
2325
)
2426

2527
# NOTE: Do not implement any other functions in this module.

0 commit comments

Comments
 (0)