|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +"""OR-patterns. |
| 5 | +
|
| 6 | +This script shows how to define a rewriting rule based on OR-patterns. |
| 7 | +""" |
| 8 | + |
| 9 | +import onnx |
| 10 | + |
| 11 | +import onnxscript |
| 12 | +from onnxscript import FLOAT, opset18, script |
| 13 | +from onnxscript.rewriter import pattern |
| 14 | + |
| 15 | +#################################### |
| 16 | +# The target pattern |
| 17 | +# ===================== |
| 18 | + |
| 19 | + |
| 20 | +def scaled_matmul(op, x, y, factor): |
| 21 | + xy = op.MatMul(x, y) |
| 22 | + choice1 = op.Mul(xy, factor) |
| 23 | + choice2 = op.Div(xy, factor) |
| 24 | + scaled_xy = pattern.OrValue( |
| 25 | + [choice1, choice2], tag_var="op_type", tag_values=["Mul", "Div"] |
| 26 | + ) |
| 27 | + return op.Relu(scaled_xy) |
| 28 | + |
| 29 | + |
| 30 | +#################################### |
| 31 | +# The replacement pattern |
| 32 | +# ===================== |
| 33 | + |
| 34 | + |
| 35 | +def scaled_matmul_replacement(op, x, y, factor, op_type): |
| 36 | + if op_type == "Mul": |
| 37 | + return op.MatMulMulRelu(x, y, factor, _domain="some.domain") |
| 38 | + elif op_type == "Div": |
| 39 | + return op.MatMulDivRelu(x, y, factor, _domain="some.domain") |
| 40 | + else: |
| 41 | + raise ValueError(f"Unknown operation type: {op_type}") |
| 42 | + |
| 43 | + |
| 44 | +#################################### |
| 45 | +# Rewrite Rule |
| 46 | +# ===================== |
| 47 | +def apply_rewrite(model): |
| 48 | + rule = pattern.RewriteRule( |
| 49 | + scaled_matmul, # target pattern |
| 50 | + scaled_matmul_replacement, # replacement pattern |
| 51 | + ) |
| 52 | + # Create a Rewrite Rule Set |
| 53 | + rewrite_rule_set = pattern.RewriteRuleSet([rule]) |
| 54 | + return onnxscript.rewriter.rewrite( |
| 55 | + model, |
| 56 | + pattern_rewrite_rules=rewrite_rule_set, |
| 57 | + ) |
| 58 | + |
| 59 | + |
| 60 | +@script() |
| 61 | +def original_model1(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]: |
| 62 | + t1 = opset18.MatMul(A, B) |
| 63 | + c = opset18.Constant(value_float=2.0) |
| 64 | + t2 = opset18.Mul(t1, c) |
| 65 | + t3 = opset18.Relu(t2) |
| 66 | + return t3 |
| 67 | + |
| 68 | + |
| 69 | +_model = original_model1.to_model_proto() |
| 70 | +onnx.checker.check_model(_model) |
| 71 | + |
| 72 | +_model_with_rewrite = apply_rewrite(_model) |
| 73 | +onnx.checker.check_model(_model_with_rewrite) |
| 74 | + |
| 75 | +assert [n.op_type for n in _model_with_rewrite.graph.node] == ["Constant", "MatMulMulRelu"] |
| 76 | + |
| 77 | + |
| 78 | +@script() |
| 79 | +def original_model2(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]: |
| 80 | + t1 = opset18.MatMul(A, B) |
| 81 | + c = opset18.Constant(value_float=2.0) |
| 82 | + t2 = opset18.Div(t1, c) |
| 83 | + t3 = opset18.Relu(t2) |
| 84 | + return t3 |
| 85 | + |
| 86 | + |
| 87 | +_model = original_model2.to_model_proto() |
| 88 | +onnx.checker.check_model(_model) |
| 89 | + |
| 90 | +_model_with_rewrite = apply_rewrite(_model) |
| 91 | +onnx.checker.check_model(_model_with_rewrite) |
| 92 | + |
| 93 | +assert [n.op_type for n in _model_with_rewrite.graph.node] == ["Constant", "MatMulDivRelu"] |
0 commit comments