Skip to content

Commit e72efe4

Browse files
committed
[symbol_structure]: add mir op_class, add mir symbol_pass
1 parent dea98a9 commit e72efe4

File tree

6 files changed

+452
-30
lines changed

6 files changed

+452
-30
lines changed

python/mrt/mir/opclass.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import typing
2+
from dataclasses import dataclass
3+
from . import opns
4+
from . import symbol
5+
6+
MRT_OP_MAP: typing.Dict[str, typing.Any] = {}
7+
8+
#def _register_op_map_(op_name: str, clss:typing.Any=None):
9+
# if len(op_name)>0 and clss!=None:
10+
# if op_name not in MRT_OP_MAP:
11+
# MRT_OP_MAP[op_name] = clss
12+
# return MRT_OP_MAP
13+
14+
def _register_op_map(op_name: str): #, clss:typing.Any=None):
15+
def _wrapper(clss: typing.Any=None):
16+
if len(op_name)>0 and clss!=None:
17+
if op_name not in MRT_OP_MAP:
18+
MRT_OP_MAP[op_name] = clss
19+
return clss
20+
return _wrapper
21+
22+
@_register_op_map(opns.CONV2D)
23+
@dataclass(init=False)
24+
class Conv2D(symbol.Symbol):
25+
26+
op_name = opns.CONV2D
27+
28+
@property
29+
def strides(self) -> typing.Tuple[int, int]:
30+
default_val = (1,1)
31+
return self._strides if self._strides else self.attrs['strides'] if 'strides' in self.attrs else default_val
32+
33+
@property
34+
def padding(self) -> typing.Tuple[int, int, int, int]:
35+
default_val = (0,0,0,0)
36+
return self._padding if self._padding else self.attrs['padding'] if 'padding' in self.attrs else default_val
37+
38+
@property
39+
def dilation(self) -> typing.Tuple[int, int]:
40+
default_val = (1,1)
41+
return self._ if self._ else self.attrs[''] if '' in self.attrs else default_val
42+
43+
@property
44+
def kernel_size(self) -> typing.Tuple[int, int]:
45+
default_val = (3,3)
46+
return self._kernel_size if self._kernel_size else self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val
47+
48+
def __init__(self, name:str, **kwargs):
49+
self.name = name
50+
self.args = kwargs.pop('args', [])
51+
self.attrs = kwargs.pop('attrs', {})
52+
self.extra_attrs = {}
53+
54+
# TODO: what if strides not in attrs?
55+
self._strides = self.attrs['strides']
56+
if 'padding' in self.attrs:
57+
self._padding = self.attrs['padding']
58+
if 'dilation' in self.attrs:
59+
self._dilation = self.attrs['dilation']
60+
if 'kernel_size' in self.attrs:
61+
self._kernel_size = self.attrs['kernel_size']
62+
63+
64+
@_register_op_map(opns.DROP_OUT)
65+
@dataclass(init=False)
66+
class Dropout(symbol.Symbol):
67+
68+
op_name = opns.DROP_OUT
69+
70+
@property
71+
def rate(self) -> float:
72+
default_val = 0.0
73+
return self._rate if self._rate else self.attrs['rate'] if 'rate' in self.attrs else default_val
74+
75+
def __init__(self, name:str, **kwargs):
76+
self.name = name
77+
self.args = kwargs.pop('args', [])
78+
self.attrs = kwargs.pop('attrs', {})
79+
self.extra_attrs = {}
80+
81+
self._rate = self.attrs['rate']
82+
83+
@_register_op_map(opns.CLIP)
84+
@dataclass(init=False)
85+
class Clip(symbol.Symbol):
86+
87+
op_name = opns.CLIP
88+
89+
@property
90+
def min(self) -> float:
91+
default_val = None
92+
return self._min if self._min else self.attrs['min'] if 'min' in self.attrs else default_val
93+
94+
@property
95+
def max(self) -> float:
96+
default_val = None
97+
return self._max if self._max else self.attrs['max'] if 'max' in self.attrs else default_val
98+
99+
def __init__(self, name:str, **kwargs):
100+
self.name = name
101+
self.args = kwargs.pop('args', [])
102+
self.attrs = kwargs.pop('attrs', {})
103+
self.extra_attrs = {}
104+
105+
self._min = self.attrs['min']
106+
self._max = self.attrs['max']
107+
108+
109+
@_register_op_map(opns.BATCH_NORM)
110+
@dataclass(init=False)
111+
class BatchNorm(symbol.Symbol):
112+
113+
op_name = opns.BATCH_NORM
114+
115+
@property
116+
def axis(self) -> float:
117+
default_val = 1
118+
return self._axis if self._axis else self.attrs['axis'] if 'axis' in self.attrs else default_val
119+
120+
@property
121+
def epsilon(self) -> float:
122+
default_val = 1e-5
123+
return self._epsilon if self._epsilon else self.attrs['epsilon'] if 'epsilon' in self.attrs else default_val
124+
125+
@property
126+
def center(self) -> float:
127+
default_val = True
128+
return self._center if self._center else self.attrs['center'] if 'center' in self.attrs else default_val
129+
130+
@property
131+
def scale(self) -> float:
132+
default_val = True
133+
return self._scale if self._scale else self.attrs['scale'] if 'scale' in self.attrs else default_val
134+
135+
def __init__(self, name:str, **kwargs):
136+
self.name = name
137+
self.args = kwargs.pop('args', [])
138+
self.attrs = kwargs.pop('attrs', {})
139+
self.extra_attrs = {}
140+
141+
self._axis = self.attrs['axis']
142+
self._epsilon = self.attrs['epsilon']
143+
self._center = self.attrs['center']
144+
self._scale = self.attrs['scale']
145+
146+
@_register_op_map(opns.DENSE)
147+
@dataclass(init=False)
148+
class Dense(symbol.Symbol):
149+
150+
op_name = opns.DENSE
151+
152+
def __init__(self, name:str, **kwargs):
153+
self.name = name
154+
self.args = kwargs.pop('args', [])
155+
self.attrs = kwargs.pop('attrs', {})
156+
self.extra_attrs = {}
157+
158+
@_register_op_map(opns.TUPLE_GET_ITEM)
159+
@dataclass(init=False)
160+
class TupleGetItem(symbol.Symbol):
161+
162+
op_name = opns.TUPLE_GET_ITEM
163+
164+
@property
165+
def index(self) -> float:
166+
default_val = 0
167+
return self._index if self._index else self.attrs['index'] if 'index' in self.attrs else default_val
168+
169+
def __init__(self, name:str, **kwargs):
170+
self.name = name
171+
self.args = kwargs.pop('args', [])
172+
self.attrs = kwargs.pop('attrs', {})
173+
self.extra_attrs = {}
174+
175+
self._index = self.attrs['index']
176+
177+
@_register_op_map(opns.MUL)
178+
@dataclass(init=False)
179+
class Multiply(symbol.Symbol):
180+
181+
op_name = opns.MUL
182+
183+
def __init__(self, name:str, **kwargs):
184+
self.name = name
185+
self.args = kwargs.pop('args', [])
186+
self.attrs = kwargs.pop('attrs', {})
187+
self.extra_attrs = {}
188+

python/mrt/mir/opns.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
""" MRT operator names """
2+
import typing
3+
4+
MRT_OP_SET = set()
5+
def _register_op_list(*op_names: typing.List[str]):
6+
for op_name in op_names:
7+
if op_name not in MRT_OP_SET:
8+
MRT_OP_SET.add(op_name)
29

310
VAR = "var"
11+
_register_op_list(VAR)
412

513
DROP_OUT = "nn.dropout"
614
CONV2D = "nn.conv2d"
@@ -14,22 +22,29 @@
1422
ADAPTIVE_AVG_POOL2D = "nn.adaptive_avg_pool2d"
1523
AVG_POOL2D = "nn.avg_pool2d"
1624
MAX_POOL2D = "nn.max_pool2d"
25+
_register_op_list(DROP_OUT, CONV2D, DENSE, BATCH_NORM, RELU,
26+
HARDTANH, SILU, LEAKY_RELU, ADAPTIVE_AVG_POOL2D,
27+
AVG_POOL2D, MAX_POOL2D)
1728

1829
SOFTMAX = "nn.softmax"
1930
LOG_SOFTMAX = "nn.log_softmax"
31+
_register_op_list(SOFTMAX, LOG_SOFTMAX)
2032

2133
EXP = "exp"
2234
SIGMOID = "sigmoid"
35+
_register_op_list(EXP, SIGMOID)
2336

2437
SUM = "sum"
2538
MEAN = "mean"
2639
MAX_AXIS = "max"
2740
MAXIMUM = "maximum"
2841
MINIMUM = "minimum"
42+
_register_op_list(SUM, MEAN, MAX_AXIS, MAXIMUM, MINIMUM)
2943

3044
# =========== NON-CALC ops ===============
3145
TUPLE = "Tuple"
3246
TUPLE_GET_ITEM = "TupleGetItem"
47+
_register_op_list(TUPLE, TUPLE_GET_ITEM)
3348

3449
REPEAT = "repeat"
3550
SQUEEZE = "squeeze"
@@ -40,16 +55,20 @@
4055
SPLIT = "split"
4156
TRANSPOSE = "transpose"
4257
BROADCAST_TO = "broadcast_to"
58+
_register_op_list(REPEAT, SQUEEZE, FLATTEN, BATCH_FLATTEN, RESHAPE,
59+
CONCAT, SPLIT, TRANSPOSE, BROADCAST_TO, )
4360

4461
EXPAND_DIMS = "expand_dims"
4562
TILE = "tile"
63+
_register_op_list(EXPAND_DIMS, TILE)
4664

4765
WHERE = "where"
4866
GREATER = "greater"
4967
STRIDED_SLICE = "strided_slice"
5068
SLICE_LIKE = "slice_like"
5169
GET_VALID_COUNT = "vision.get_valid_counts"
5270
NON_MAX_SUPRESSION = "vision.non_max_suppression"
71+
_register_op_list(WHERE, GREATER, STRIDED_SLICE, SLICE_LIKE, GET_VALID_COUNT, NON_MAX_SUPRESSION)
5372

5473
# relax clip attrs from a_min/a_max to min/max
5574
CLIP = "clip"
@@ -58,11 +77,14 @@
5877
# relax support astype instead of cast
5978
AS_TYPE = "astype"
6079
# CAST = "cast"
80+
_register_op_list(CLIP, CEIL, RIGHT_SHIFT, AS_TYPE)
6181

6282
ADV_INDEX = "adv_index"
83+
_register_op_list(ADV_INDEX)
6384

6485
CALL_TIR = "call_tir"
6586
CALL_DPS_PACKED = "call_dps_packed"
87+
_register_op_list(CALL_TIR, CALL_DPS_PACKED)
6688

6789
# ======= binary ops =============
6890

@@ -71,6 +93,7 @@
7193
MUL = "multiply"
7294
MATMUL = "matmul"
7395
DIV = "divide"
96+
_register_op_list(ADD, SUB, MUL, MATMUL, DIV)
7497

7598
# ======= unary ops ==============
7699

@@ -81,14 +104,17 @@
81104
POW = "pow"
82105

83106
PASS = "pass"
107+
_register_op_list(NEGATIVE, ABS, LOG, SQRT, POW, PASS)
84108
# ======= auto generate op =========
85109
ARANGE = "arange"
86110
ZEROS_LIKE = "zeros_like"
87111
ONES_LIKE = "ones_like"
112+
_register_op_list(ARANGE, ZEROS_LIKE, ONES_LIKE)
88113

89114
# ======= control flow op ===========
90115
IF = "if"
91116
ARGWHERE = "argwhere"
117+
_register_op_list(IF, ARGWHERE)
92118

93119
# ======= mrt requant op ==========
94120
REQUANT = "mrt.requant"
@@ -98,4 +124,9 @@
98124
""" right shift precision clip """
99125
LUT = "mrt.lut"
100126
""" look up table, equals adv_index in tvm """
127+
_register_op_list(REQUANT, PCLIP, RS_PCLIP, LUT)
128+
101129

130+
def Opname2Funcname(op_name: str):
131+
return op_name.replace('.', '_')
132+
#print('MRT_OP_SET:', MRT_OP_SET)

python/mrt/mir/symbol.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111

1212
# from . import config
1313
# from .utils import *
14-
# from .types import *
15-
from .opns import *
14+
from . import opns
1615

1716
__ALL__ = [
1817
"Symbol",
@@ -277,34 +276,6 @@ def __hash__(self) -> int:
277276
def hash(self) -> int:
278277
return hash(str(self))
279278

280-
# class Convolution2D(Symbol):
281-
# strides: typing.Tuple[int, int]
282-
283-
# class Dropout(Symbol):
284-
# eps: float = 1e-5
285-
286-
# class Pass:
287-
# symbol: Symbol
288-
289-
# def visit(self, op: Symbol):
290-
# env: typing.Dict[Symbol, Symbol] = {}
291-
# for sym in sym2list(self.symbol):
292-
# out = getattr(self, f"visit_{op.op_name}")(op) or op
293-
# assert isinstance(sym, Symbol)
294-
# env[sym] = out
295-
# return env[op]
296-
297-
# def _default_visit_op(op):
298-
# return op
299-
300-
# for op in op_list:
301-
# setattr(Pass, f"visit_{op.op_name}", _default_visit_op)
302-
303-
# class FuseDropoutPass(Pass):
304-
# def visit_dropout(self, op: Dropout):
305-
# op.eps
306-
# return op.args[0]
307-
308279
def _topo_sort(symbol: Symbol, sym_list: typing.List[Symbol]):
309280
assert isinstance(symbol, Symbol), \
310281
f"({type(symbol).__name__}){str(symbol)}"

0 commit comments

Comments
 (0)