Skip to content

Commit ac87a1c

Browse files
authored
Returning choice values in patterns (#2284)
A fix for the use of Choice values (i.e., `OrValue([choice1, choice2])` as a return-value in the pattern (that is, as a root of the pattern). This is a bit complicated to support with the current implementation, which is oriented towards iterating over the nodes in the graph, and matching them against pattern-nodes. This PR provides a limited extension (which handles the case in PR 2277): when the returned choice-values are already covered by the other output values, they can be supported by the existing matcher. Also, on a related note, fix how value-bindings are handled in the pattern-matcher to make it easier to return these values as the outputs of the pattern.
1 parent f04720d commit ac87a1c

File tree

12 files changed

+420
-251
lines changed

12 files changed

+420
-251
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ test-output.xml
4545

4646
# Sphinx documentation
4747
docs/_build/
48+
docs/sg_execution_times.rst
4849

4950
# Jupyter Notebook
5051
.ipynb_checkpoints

docs/ir/tensors.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ To fully support arrays from other frameworks, it is usually a good idea to crea
188188

189189
```{eval-rst}
190190
.. exec_code::
191-
191+
from __future__ import annotations
192192
import ctypes
193193
from typing import Any
194194
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Specifying attributes in the pattern
2+
3+
This section demonstrates the use of attribute values in pattern-based rewriting.
4+
First, write a target pattern and replacement pattern in a similar way to the previous examples.
5+
The example pattern below will match successfully only against Dropout nodes with the
6+
attribute value `training_mode` set to `False`.
7+
The `_allow_other_attributes` option allows the pattern to match nodes that have additional attributes
8+
not specified in the pattern. If it is set to `False`, then the node must have only the specified
9+
attribute values, and no other attributes, for a successful match. The default value for this
10+
option is `True`.
11+
12+
```{literalinclude} examples/allow_other_attributes.py
13+
:pyobject: add_pattern
14+
```
15+
16+
```{literalinclude} examples/allow_other_attributes.py
17+
:pyobject: add_replacement
18+
```
19+
20+
```{literalinclude} examples/allow_other_attributes.py
21+
:pyobject: apply_rewrite
22+
```

docs/tutorial/rewriter/commute.md

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
(heading-target-commute)=
2+
# Utilizing `commute` parameter for pattern-matching
3+
Extending the previous [simple example](heading-target-simple), assumming a scenario where we have a graph with the following structure.
4+
5+
![commute](examples/img/erfgelu_03_commute.png){align=center width=500px}
6+
7+
In this graph, there exist two node pattern that constitute a `GELU` op. However, there is a subtle difference between the two. Focusing on the parent `Mul` nodes in either patterns, the order of the input values being multiplied is switched.
8+
9+
![gelu_pattern_1](examples/img/erfgelu_04_commute.png){width=330px align=left} ![gelu_pattern_2](examples/img/erfgelu_05_commute.png){width=330px align=center}
10+
11+
12+
If we utilize the same `target_pattern` created for the earlier [simple example](heading-target-simple) (shown below), only one of two `GELU` pattern will be matched.
13+
14+
```{literalinclude} examples/erfgelu.py
15+
:pyobject: erf_gelu_pattern
16+
```
17+
18+
```{image} examples/img/erfgelu_06_commute.png
19+
:alt: The resulting graph after matching.
20+
:width: 400px
21+
:align: center
22+
```
23+
24+
Only one of the patterns has been successfully matched and replaced by a `GELU` node. In order to rewrite both the existing patterns in the graph, there are two methods.
25+
26+
(heading-target-commute-ruleset)=
27+
28+
## 1. Creating a rule-set with different patterns.
29+
30+
This method requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. Creating a `RewriteRuleSet` is the preferable option but either can be used. In order to create a `RewriteRuleSet` with multiple rules `rule1` and `rule2` for example:
31+
32+
```python
33+
from onnxscript.rewriter import pattern
34+
rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2])
35+
```
36+
37+
In order to apply this method to the example above, first create the two separate target patterns as follows:
38+
39+
```{literalinclude} examples/erfgelu.py
40+
:pyobject: erf_gelu_pattern
41+
```
42+
```{literalinclude} examples/erfgelu.py
43+
:pyobject: erf_gelu_pattern_2
44+
```
45+
46+
:::{note}
47+
:name: rule-application-order-matters
48+
49+
When you pass multiple rules in `pattern_rewrite_rules`, the **order in which they appear is important**.
50+
This is because some rules may depend on patterns created or modified by earlier rules. For example, if `rule2` can only match after `rule1` has made a specific change in the model, then `rule1` must come **before** `rule2` in the list.
51+
If you're not seeing expected results, try adjusting the order or applying the rule set in a loop until no more changes occur.
52+
:::
53+
54+
55+
Then, create two separate `PatternRewriteRule`s, one for each target pattern. Pack these rules into a `RewriteRuleSet` object and apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter.
56+
57+
```{literalinclude} examples/erfgelu.py
58+
:pyobject: apply_rewrite_with_ruleset
59+
```
60+
61+
## 2. Using the `commute` parameter while creating a rule.
62+
63+
Creating multiple target patterns for similar patterns can be tedious. In order to avoid this, the `commute` parameter can be utilized while creating the `RewriteRuleSet`. Simply set `commute=True` in order to avoid creating multiple target pattern for cases where patterns are different due to commutativity. Multiple rules with the different patterns emerging due to satisfying the commutativity property are automatically packed into a `RewriteRuleSet` object. Then apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter.
64+
65+
```{literalinclude} examples/erfgelu.py
66+
:pyobject: apply_rewrite_with_commute
67+
```
68+
69+
For the both of the aforementioned methods, the final graph with both rewrites applied should look as follows:
70+
71+
![commute](examples/img/erfgelu_07_commute.png){align=center width=300px}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Using the `match_condition` parameter for pattern-matching
2+
3+
This section talks about how to utilize the `match_condition` parameter. The `match_condition` parameter checks if the pattern matches the target pattern with certain constraints in consideration.
4+
5+
Let us consider a model which consists of the following pattern.
6+
7+
![target_pattern](examples/img/broadcast_01.png){align=center}
8+
9+
Based on the [ONNX Matmul spec](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul), onnx `Matmul` behaves like `numpy.matmul` and also follows numpy broadcasting. So in this particular pattern if matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following:
10+
11+
1. Input shapes check: `input_a` and `input_b` should be broadcastable
12+
2. Output shape check: `shape_c` should be the same as the output shape from the `matmul(input_a, input_b)`
13+
14+
If the above are true, then we don't need the reshapes and we can eliminate them using a pattern based rewrite.
15+
16+
First, write a target pattern and replacement pattern in a similar way to the first example.
17+
18+
```{literalinclude} examples/broadcast_matmul.py
19+
:pyobject: two_reshapes_matmul_reshape_pattern
20+
```
21+
22+
```{literalinclude} examples/broadcast_matmul.py
23+
:pyobject: matmul_pattern
24+
```
25+
26+
:::{note}
27+
:name: omitting inputs in signature
28+
29+
The target pattern in this case has 5 inputs `input_a`, `input_b`, `shape_a`, `shape_b`, `shape_c`. However, the replacement pattern only utilizes `input_a` and `input_b`. To avoid referencing all the unused parameters in the replacement pattern signature, pass only `input_a` and `input_b` and use `**_` to represent all the unused parameters.
30+
31+
Similarly for writing the condition checking function, we require only `input_a`, `input_b` and `shape_c`. Use `**_` to represent all the unused parameters in the condition matching function signature.
32+
:::
33+
34+
In order to validate whether matmul broadcast is sufficient, we write a condition checking function as follows:
35+
36+
```{literalinclude} examples/broadcast_matmul.py
37+
:pyobject: check_if_not_need_reshape
38+
```
39+
40+
With all the necessary components in place, the pattern rewrite rule with the `match_condition` function is created and then the `rewriter.rewrite` is called to apply the rewrite.
41+
42+
```{literalinclude} examples/broadcast_matmul.py
43+
:pyobject: apply_rewrite
44+
```
45+
46+
The final graph with the applied rewrite looks as follows:
47+
48+
![broadcast_rewrite](examples/img/broadcast_02.png){align=center}
49+
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""OR-patterns.
5+
6+
This script shows how to define a rewriting rule based on OR-patterns.
7+
"""
8+
9+
import onnx
10+
11+
import onnxscript
12+
from onnxscript import FLOAT, opset18, script
13+
from onnxscript.rewriter import pattern
14+
15+
####################################
16+
# The target pattern
17+
# =====================
18+
19+
20+
def scaled_matmul(op, x, y, factor):
21+
xy = op.MatMul(x, y)
22+
choice1 = op.Mul(xy, factor)
23+
choice2 = op.Div(xy, factor)
24+
scaled_xy = pattern.OrValue(
25+
[choice1, choice2], tag_var="op_type", tag_values=["Mul", "Div"]
26+
)
27+
return op.Relu(scaled_xy)
28+
29+
30+
####################################
31+
# The replacement pattern
32+
# =====================
33+
34+
35+
def scaled_matmul_replacement(op, x, y, factor, op_type):
36+
if op_type == "Mul":
37+
return op.MatMulMulRelu(x, y, factor, _domain="some.domain")
38+
elif op_type == "Div":
39+
return op.MatMulDivRelu(x, y, factor, _domain="some.domain")
40+
else:
41+
raise ValueError(f"Unknown operation type: {op_type}")
42+
43+
44+
####################################
45+
# Rewrite Rule
46+
# =====================
47+
def apply_rewrite(model):
48+
rule = pattern.RewriteRule(
49+
scaled_matmul, # target pattern
50+
scaled_matmul_replacement, # replacement pattern
51+
)
52+
# Create a Rewrite Rule Set
53+
rewrite_rule_set = pattern.RewriteRuleSet([rule])
54+
return onnxscript.rewriter.rewrite(
55+
model,
56+
pattern_rewrite_rules=rewrite_rule_set,
57+
)
58+
59+
60+
@script()
61+
def original_model1(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]:
62+
t1 = opset18.MatMul(A, B)
63+
c = opset18.Constant(value_float=2.0)
64+
t2 = opset18.Mul(t1, c)
65+
t3 = opset18.Relu(t2)
66+
return t3
67+
68+
69+
_model = original_model1.to_model_proto()
70+
onnx.checker.check_model(_model)
71+
72+
_model_with_rewrite = apply_rewrite(_model)
73+
onnx.checker.check_model(_model_with_rewrite)
74+
75+
assert [n.op_type for n in _model_with_rewrite.graph.node] == ["Constant", "MatMulMulRelu"]
76+
77+
78+
@script()
79+
def original_model2(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]:
80+
t1 = opset18.MatMul(A, B)
81+
c = opset18.Constant(value_float=2.0)
82+
t2 = opset18.Div(t1, c)
83+
t3 = opset18.Relu(t2)
84+
return t3
85+
86+
87+
_model = original_model2.to_model_proto()
88+
onnx.checker.check_model(_model)
89+
90+
_model_with_rewrite = apply_rewrite(_model)
91+
onnx.checker.check_model(_model_with_rewrite)
92+
93+
assert [n.op_type for n in _model_with_rewrite.graph.node] == ["Constant", "MatMulDivRelu"]

docs/tutorial/rewriter/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Rewriter Tutorials
1+
# Rewriter Tutorial
22

33
```{toctree}
44
rewrite_patterns
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# OR Patterns
2+
3+
*Note* : This feature is work-in-progress.
4+
5+
Consider the following pattern:
6+
7+
```{literalinclude} examples/or_pattern.py
8+
:pyobject: scaled_matmul
9+
```
10+
11+
This pattern will successfully match against the sequence "MatMul => Mul => Relu" as
12+
well as the sequence "MatMul => Div => Relu". The matcher will bind the variable
13+
specified in `tag_var` (`op_type` in the above example) to a value from those
14+
listed in `tag_values` to indicate which of the alternatives was used for a
15+
successful match. We can use this in the rewrite function to determine how
16+
we want to rewrite the matched sub-graph, as illustrated by the following code:
17+
18+
```{literalinclude} examples/or_pattern.py
19+
:pyobject: scaled_matmul_replacement
20+
```

0 commit comments

Comments
 (0)