|
7 | 7 | import unittest |
8 | 8 |
|
9 | 9 | from onnxscript import ir |
10 | | -from onnxscript.ir._convenience import insert_nodes_in_value |
| 10 | +from onnxscript.ir._convenience import insert_nodes_in_value, remove_nodes |
11 | 11 |
|
12 | 12 |
|
13 | 13 | def _create_model(model_text: str) -> ir.Model: |
@@ -141,6 +141,108 @@ def test_value_error_for_wrong_number_of_points(self): |
141 | 141 | with self.assertRaisesRegex(ValueError, "The number of values and outputs"): |
142 | 142 | insert_nodes_in_value(ir_model.graph[1].outputs, [node]) |
143 | 143 |
|
| 144 | + def test_remove_nodes(self): |
| 145 | + # Main graph |
| 146 | + input = ir.Input("input") |
| 147 | + node_A = ir.node("op_A", [input]) |
| 148 | + node_B = ir.node("op_B", node_A.outputs) |
| 149 | + node_C = ir.node("op_C", node_B.outputs) |
| 150 | + |
| 151 | + # Delete node_B |
| 152 | + remove_nodes([node_B]) |
| 153 | + self.assertEqual(len(node_A.outputs[0].consumers()), 1) |
| 154 | + self.assertEqual(node_A.outputs[0].consumers()[0].op_type, "op_C") |
| 155 | + self.assertEqual(len(node_C.inputs), 1) |
| 156 | + self.assertEqual(node_C.inputs[0].producer().op_type, "op_A") |
| 157 | + |
| 158 | + self.assertEqual((len(node_B.inputs), len(node_B.outputs)), (1, 1)) |
| 159 | + self.assertEqual(node_B.inputs, (None,)) |
| 160 | + self.assertEqual(len(node_B.outputs[0].consumers()), 0) |
| 161 | + |
| 162 | + def test_remove_nodes_in_graph(self): |
| 163 | + ir_model = _create_model( |
| 164 | + """ |
| 165 | + <ir_version: 10, opset_import: [ "" : 17]> |
| 166 | + agraph (float[N] x) => (float[N] z) { |
| 167 | + two = Constant<value_float=2.0>() |
| 168 | + a, b = MergeAndSplit(x, two) |
| 169 | + z = MergeNode(a, b, two) |
| 170 | + } |
| 171 | + """ |
| 172 | + ) |
| 173 | + # Sanity check previous to delete nodes |
| 174 | + x, two = ir_model.graph.inputs[0], ir_model.graph[0].outputs[0] |
| 175 | + self.assertEqual(len(x.consumers()), 1) |
| 176 | + self.assertEqual(len(two.consumers()), 2) |
| 177 | + |
| 178 | + # Delete 'MergeAndSplit' |
| 179 | + target_node = ir_model.graph[1] |
| 180 | + remove_nodes([target_node]) |
| 181 | + |
| 182 | + # Check 'MergeNode' has new inputs |
| 183 | + a, b, _ = ir_model.graph[-1].inputs |
| 184 | + self.assertEqual(a.name, "x") |
| 185 | + self.assertEqual(b.name, "two") |
| 186 | + |
| 187 | + # Check x/two consumers have been updated |
| 188 | + self.assertEqual(len(x.consumers()), 1) |
| 189 | + self.assertEqual(len(two.consumers()), 1) |
| 190 | + |
| 191 | + # Check nodes have been deleted in the graph |
| 192 | + self.assertEqual(len(ir_model.graph), 2) |
| 193 | + |
| 194 | + def test_remove_nodes_in_input(self): |
| 195 | + ir_model = _create_model( |
| 196 | + """ |
| 197 | + <ir_version: 10, opset_import: [ "" : 17]> |
| 198 | + agraph (float[N] x) => (float[N] z) { |
| 199 | + y = Sigmoid(x) |
| 200 | + z = Mul(y, y) |
| 201 | + } |
| 202 | + """ |
| 203 | + ) |
| 204 | + # Remove the node linked to the input |
| 205 | + remove_nodes([ir_model.graph[0]]) |
| 206 | + self.assertEqual(len(ir_model.graph), 1) |
| 207 | + self.assertEqual(ir_model.graph[0].op_type, "Mul") |
| 208 | + self.assertEqual(ir_model.graph[0].inputs[0].name, "x") |
| 209 | + self.assertEqual(ir_model.graph[0].inputs[1].name, "x") |
| 210 | + self.assertEqual(ir_model.graph[0].outputs[0].name, "z") |
| 211 | + |
| 212 | + def test_remove_nodes_in_output(self): |
| 213 | + ir_model = _create_model( |
| 214 | + """ |
| 215 | + <ir_version: 10, opset_import: [ "" : 17]> |
| 216 | + agraph (float[N] x) => (float[N] z) { |
| 217 | + y = Mul(x, x) |
| 218 | + z = Sigmoid(y) |
| 219 | + } |
| 220 | + """ |
| 221 | + ) |
| 222 | + # Remove the node linked to the input |
| 223 | + remove_nodes([ir_model.graph[-1]]) |
| 224 | + self.assertEqual(len(ir_model.graph), 1) |
| 225 | + self.assertEqual(ir_model.graph[0].op_type, "Mul") |
| 226 | + self.assertEqual(ir_model.graph[0].outputs[0].name, "y") |
| 227 | + self.assertEqual(ir_model.graph.outputs[0].name, "y") |
| 228 | + |
| 229 | + def test_remove_nodes_error_for_wrong_number_of_inputs_and_outputs(self): |
| 230 | + ir_model = _create_model( |
| 231 | + """ |
| 232 | + <ir_version: 10, opset_import: [ "" : 17]> |
| 233 | + agraph (float[N] x) => (float[N] z) { |
| 234 | + two = Constant<value_float=2.0>() |
| 235 | + a, b = SplitNode(x) |
| 236 | + z = MergeNode(a, b, two) |
| 237 | + } |
| 238 | + """ |
| 239 | + ) |
| 240 | + with self.assertRaisesRegex(ValueError, "The number of inputs"): |
| 241 | + remove_nodes([ir_model.graph[0]]) |
| 242 | + |
| 243 | + with self.assertRaisesRegex(ValueError, "The number of inputs"): |
| 244 | + remove_nodes([ir_model.graph[1]]) |
| 245 | + |
144 | 246 |
|
145 | 247 | if __name__ == "__main__": |
146 | 248 | unittest.main() |
0 commit comments