From e72efe4694b942f02f66f37a3e1267d26d65429d Mon Sep 17 00:00:00 2001 From: corlfj Date: Thu, 11 Sep 2025 19:27:04 +0800 Subject: [PATCH 1/7] [symbol_structure]: add mir op_class, add mir symbol_pass --- python/mrt/mir/opclass.py | 188 ++++++++++++++++++++++++++++++++++ python/mrt/mir/opns.py | 31 ++++++ python/mrt/mir/symbol.py | 31 +----- python/mrt/mir/symbolpass.py | 84 +++++++++++++++ tests/mir/test.op_create.py | 60 +++++++++++ tests/mir/test.symbol_pass.py | 88 ++++++++++++++++ 6 files changed, 452 insertions(+), 30 deletions(-) create mode 100644 python/mrt/mir/opclass.py create mode 100644 python/mrt/mir/symbolpass.py create mode 100644 tests/mir/test.op_create.py create mode 100644 tests/mir/test.symbol_pass.py diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py new file mode 100644 index 0000000..3d5433a --- /dev/null +++ b/python/mrt/mir/opclass.py @@ -0,0 +1,188 @@ +import typing +from dataclasses import dataclass +from . import opns +from . import symbol + +MRT_OP_MAP: typing.Dict[str, typing.Any] = {} + +#def _register_op_map_(op_name: str, clss:typing.Any=None): +# if len(op_name)>0 and clss!=None: +# if op_name not in MRT_OP_MAP: +# MRT_OP_MAP[op_name] = clss +# return MRT_OP_MAP + +def _register_op_map(op_name: str): #, clss:typing.Any=None): + def _wrapper(clss: typing.Any=None): + if len(op_name)>0 and clss!=None: + if op_name not in MRT_OP_MAP: + MRT_OP_MAP[op_name] = clss + return clss + return _wrapper + +@_register_op_map(opns.CONV2D) +@dataclass(init=False) +class Conv2D(symbol.Symbol): + + op_name = opns.CONV2D + + @property + def strides(self) -> typing.Tuple[int, int]: + default_val = (1,1) + return self._strides if self._strides else self.attrs['strides'] if 'strides' in self.attrs else default_val + + @property + def padding(self) -> typing.Tuple[int, int, int, int]: + default_val = (0,0,0,0) + return self._padding if self._padding else self.attrs['padding'] if 'padding' in self.attrs else default_val + + @property + def dilation(self) -> typing.Tuple[int, int]: + default_val = (1,1) + return self._ if self._ else self.attrs[''] if '' in self.attrs else default_val + + @property + def kernel_size(self) -> typing.Tuple[int, int]: + default_val = (3,3) + return self._kernel_size if self._kernel_size else self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + + # TODO: what if strides not in attrs? 
+ self._strides = self.attrs['strides'] + if 'padding' in self.attrs: + self._padding = self.attrs['padding'] + if 'dilation' in self.attrs: + self._dilation = self.attrs['dilation'] + if 'kernel_size' in self.attrs: + self._kernel_size = self.attrs['kernel_size'] + + +@_register_op_map(opns.DROP_OUT) +@dataclass(init=False) +class Dropout(symbol.Symbol): + + op_name = opns.DROP_OUT + + @property + def rate(self) -> float: + default_val = 0.0 + return self._rate if self._rate else self.attrs['rate'] if 'rate' in self.attrs else default_val + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + + self._rate = self.attrs['rate'] + +@_register_op_map(opns.CLIP) +@dataclass(init=False) +class Clip(symbol.Symbol): + + op_name = opns.CLIP + + @property + def min(self) -> float: + default_val = None + return self._min if self._min else self.attrs['min'] if 'min' in self.attrs else default_val + + @property + def max(self) -> float: + default_val = None + return self._max if self._max else self.attrs['max'] if 'max' in self.attrs else default_val + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + + self._min = self.attrs['min'] + self._max = self.attrs['max'] + + +@_register_op_map(opns.BATCH_NORM) +@dataclass(init=False) +class BatchNorm(symbol.Symbol): + + op_name = opns.BATCH_NORM + + @property + def axis(self) -> float: + default_val = 1 + return self._axis if self._axis else self.attrs['axis'] if 'axis' in self.attrs else default_val + + @property + def epsilon(self) -> float: + default_val = 1e-5 + return self._epsilon if self._epsilon else self.attrs['epsilon'] if 'epsilon' in self.attrs else default_val + + @property + def center(self) -> float: + default_val = True + return self._center if self._center else self.attrs['center'] if 'center' in self.attrs else default_val + + @property + def scale(self) -> float: + default_val = True + return self._scale if self._scale else self.attrs['scale'] if 'scale' in self.attrs else default_val + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + + self._axis = self.attrs['axis'] + self._epsilon = self.attrs['epsilon'] + self._center = self.attrs['center'] + self._scale = self.attrs['scale'] + +@_register_op_map(opns.DENSE) +@dataclass(init=False) +class Dense(symbol.Symbol): + + op_name = opns.DENSE + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + +@_register_op_map(opns.TUPLE_GET_ITEM) +@dataclass(init=False) +class TupleGetItem(symbol.Symbol): + + op_name = opns.TUPLE_GET_ITEM + + @property + def index(self) -> float: + default_val = 0 + return self._index if self._index else self.attrs['index'] if 'index' in self.attrs else default_val + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + + self._index = self.attrs['index'] + +@_register_op_map(opns.MUL) +@dataclass(init=False) +class Multiply(symbol.Symbol): + + op_name = opns.MUL + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + 
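A minimal usage sketch for the registry above (it mirrors the included tests/mir/test.op_create.py; the symbol name and attribute values are illustrative only, and it assumes python/ has been added to sys.path as the tests do):

    # Look up the class registered for an op name and build a symbol from it.
    from mrt.mir import opns, opclass

    conv_cls = opclass.MRT_OP_MAP[opns.CONV2D]        # resolves to opclass.Conv2D
    conv = conv_cls('conv2d_b',                       # name is illustrative
                    args=[[], [], []],                # placeholder args, as in the test
                    attrs={'strides': (1, 1), 'padding': (0, 0, 0, 0)})
    assert isinstance(conv, opclass.Conv2D)
    print(conv.strides, conv.padding)                 # -> (1, 1) (0, 0, 0, 0)

The pass layer relies on the same registry when rewriting ops, e.g. FuseDividePass replacing a divide with a multiply.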
diff --git a/python/mrt/mir/opns.py b/python/mrt/mir/opns.py index 5b92822..be2f823 100644 --- a/python/mrt/mir/opns.py +++ b/python/mrt/mir/opns.py @@ -1,6 +1,14 @@ """ MRT operator names """ +import typing + +MRT_OP_SET = set() +def _register_op_list(*op_names: typing.List[str]): + for op_name in op_names: + if op_name not in MRT_OP_SET: + MRT_OP_SET.add(op_name) VAR = "var" +_register_op_list(VAR) DROP_OUT = "nn.dropout" CONV2D = "nn.conv2d" @@ -14,22 +22,29 @@ ADAPTIVE_AVG_POOL2D = "nn.adaptive_avg_pool2d" AVG_POOL2D = "nn.avg_pool2d" MAX_POOL2D = "nn.max_pool2d" +_register_op_list(DROP_OUT, CONV2D, DENSE, BATCH_NORM, RELU, + HARDTANH, SILU, LEAKY_RELU, ADAPTIVE_AVG_POOL2D, + AVG_POOL2D, MAX_POOL2D) SOFTMAX = "nn.softmax" LOG_SOFTMAX = "nn.log_softmax" +_register_op_list(SOFTMAX, LOG_SOFTMAX) EXP = "exp" SIGMOID = "sigmoid" +_register_op_list(EXP, SIGMOID) SUM = "sum" MEAN = "mean" MAX_AXIS = "max" MAXIMUM = "maximum" MINIMUM = "minimum" +_register_op_list(SUM, MEAN, MAX_AXIS, MAXIMUM, MINIMUM) # =========== NON-CALC ops =============== TUPLE = "Tuple" TUPLE_GET_ITEM = "TupleGetItem" +_register_op_list(TUPLE, TUPLE_GET_ITEM) REPEAT = "repeat" SQUEEZE = "squeeze" @@ -40,9 +55,12 @@ SPLIT = "split" TRANSPOSE = "transpose" BROADCAST_TO = "broadcast_to" +_register_op_list(REPEAT, SQUEEZE, FLATTEN, BATCH_FLATTEN, RESHAPE, + CONCAT, SPLIT, TRANSPOSE, BROADCAST_TO, ) EXPAND_DIMS = "expand_dims" TILE = "tile" +_register_op_list(EXPAND_DIMS, TILE) WHERE = "where" GREATER = "greater" @@ -50,6 +68,7 @@ SLICE_LIKE = "slice_like" GET_VALID_COUNT = "vision.get_valid_counts" NON_MAX_SUPRESSION = "vision.non_max_suppression" +_register_op_list(WHERE, GREATER, STRIDED_SLICE, SLICE_LIKE, GET_VALID_COUNT, NON_MAX_SUPRESSION) # relax clip attrs from a_min/a_max to min/max CLIP = "clip" @@ -58,11 +77,14 @@ # relax support astype instead of cast AS_TYPE = "astype" # CAST = "cast" +_register_op_list(CLIP, CEIL, RIGHT_SHIFT, AS_TYPE) ADV_INDEX = "adv_index" +_register_op_list(ADV_INDEX) CALL_TIR = "call_tir" CALL_DPS_PACKED = "call_dps_packed" +_register_op_list(CALL_TIR, CALL_DPS_PACKED) # ======= binary ops ============= @@ -71,6 +93,7 @@ MUL = "multiply" MATMUL = "matmul" DIV = "divide" +_register_op_list(ADD, SUB, MUL, MATMUL, DIV) # ======= unary ops ============== @@ -81,14 +104,17 @@ POW = "pow" PASS = "pass" +_register_op_list(NEGATIVE, ABS, LOG, SQRT, POW, PASS) # ======= auto generate op ========= ARANGE = "arange" ZEROS_LIKE = "zeros_like" ONES_LIKE = "ones_like" +_register_op_list(ARANGE, ZEROS_LIKE, ONES_LIKE) # ======= control flow op =========== IF = "if" ARGWHERE = "argwhere" +_register_op_list(IF, ARGWHERE) # ======= mrt requant op ========== REQUANT = "mrt.requant" @@ -98,4 +124,9 @@ """ right shift precision clip """ LUT = "mrt.lut" """ look up table, equals adv_index in tvm """ +_register_op_list(REQUANT, PCLIP, RS_PCLIP, LUT) + +def Opname2Funcname(op_name: str): + return op_name.replace('.', '_') +#print('MRT_OP_SET:', MRT_OP_SET) diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index 5c97cee..e27adcd 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -11,8 +11,7 @@ # from . import config # from .utils import * -# from .types import * -from .opns import * +from . 
import opns __ALL__ = [ "Symbol", @@ -277,34 +276,6 @@ def __hash__(self) -> int: def hash(self) -> int: return hash(str(self)) -# class Convolution2D(Symbol): -# strides: typing.Tuple[int, int] - -# class Dropout(Symbol): -# eps: float = 1e-5 - -# class Pass: -# symbol: Symbol - -# def visit(self, op: Symbol): -# env: typing.Dict[Symbol, Symbol] = {} -# for sym in sym2list(self.symbol): -# out = getattr(self, f"visit_{op.op_name}")(op) or op -# assert isinstance(sym, Symbol) -# env[sym] = out -# return env[op] - -# def _default_visit_op(op): -# return op - -# for op in op_list: -# setattr(Pass, f"visit_{op.op_name}", _default_visit_op) - -# class FuseDropoutPass(Pass): -# def visit_dropout(self, op: Dropout): -# op.eps -# return op.args[0] - def _topo_sort(symbol: Symbol, sym_list: typing.List[Symbol]): assert isinstance(symbol, Symbol), \ f"({type(symbol).__name__}){str(symbol)}" diff --git a/python/mrt/mir/symbolpass.py b/python/mrt/mir/symbolpass.py new file mode 100644 index 0000000..883af38 --- /dev/null +++ b/python/mrt/mir/symbolpass.py @@ -0,0 +1,84 @@ +from __future__ import annotations +import typing + +from functools import wraps +from dataclasses import dataclass, fields + +import mrt +from mrt.common import config +from mrt.common.utils import * +from mrt.common.types import * + +from . import opns, opclass +from . import symbol as _symbol + + +# mrt op visits +class SymbolPass: + symbol: _symbol.Symbol + params: ParametersT + + def __init__(self, symbol: _symbol.Symbol, params: ParametersT): + self.symbol = symbol + self.params = params + + def is_param(self, symbol: _symbol.Symbol) -> bool: + return symbol.op_name == opns.VAR and symbol.name in self.params + + def visit(self) -> _symbol.Symbol: + env: typing.Dict[str, _symbol.Symbol] = {} + for sym in _symbol.sym2list(self.symbol): + assert sym.name not in env, f'{sym.name} NotIn env!' + + # Updating args as passed symbol in env_dict + sym = sym.copy(args = [env[arg_sym.name] for arg_sym in sym.args]) + assert isinstance(sym, _symbol.Symbol), sym + + if sym.op_name == opns.DROP_OUT: + #print('ddrroopped_out', getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) or sym) + pass + out = getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) or sym + assert isinstance(out, _symbol.Symbol), out + env[sym.name] = out + return env[self.symbol.name] + + def _default_visit_op(self, op: _symbol.Symbol) -> _symbol.Symbol: + return op + + +# register mrt op default_visit +for op_name in opns.MRT_OP_SET: + funcSuffix = opns.Opname2Funcname(op_name) + setattr(SymbolPass, f"visit_{funcSuffix}", SymbolPass._default_visit_op) + #print(f"visit_, {op_name} => {funcSuffix}", getattr(SymbolPass, f"visit_{funcSuffix}")) + + +# mrt symbol pass +class FuseDropoutPass(SymbolPass): + def visit_nn_dropout(self, sym: _symbol.Symbol) -> _symbol.Symbol: + # make sure op fit again + if sym.op_name == opns.DROP_OUT: + return sym.args[0] + return sym + + +class FuseDividePass(SymbolPass): + def visit_divide(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.DIVIDE: + argA = self.args[0] + argB = self.args[1] + assert self.is_param(argB), f'NotParam: {argB}' + # TODO: fixit + #argB = argB.from_np_data(1. 
/ argB.numpy()) + return opclass.Multiply(sym.name, {'args':[argA, argB]}) + return sym + + +class FuseTupleGetItemPass(SymbolPass): + def visit_TupleGetItem(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.TUPLE_GET_ITEM: + sym_ : opclass.TupleGetItem = sym + assert sym_.index == 0 + return sym_.args[0] + return sym + diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py new file mode 100644 index 0000000..a36e16c --- /dev/null +++ b/tests/mir/test.op_create.py @@ -0,0 +1,60 @@ +""" +Test script for Alexnet PyTorch to MRT conversion. +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass + + +def test_create_conv2d_op(): + #class CONV2D(Symbol): + # strides: typing.Tuple[int, int] = (1,1) + # padding: typing.Optional[typing.Tuple[int, int, int, int]] = (0,0,0,0) + # create mrt op symbol, def func + print('mrt Conv2D Op Class:', opclass.Conv2D) + conv2d_b = opclass.MRT_OP_MAP[opns.CONV2D]('conv2d_b',args=[[],[],[]], attrs={'strides':(1,1), 'padding':None}) + assert isinstance(conv2d_b, sx.Symbol), 'not!con2d_b symbol' + assert isinstance(conv2d_b, opclass.Conv2D), 'not!2 -con2d_b' + + # attrs hint + assert conv2d_b.args != None + assert conv2d_b.attrs != None + assert conv2d_b.strides != None + + print(f'Got {conv2d_b.name} strides: {conv2d_b.strides}') + print(f'Got {conv2d_b.name} padding: {conv2d_b.padding}') + print(f'Show {conv2d_b.name} {conv2d_b}') + return True + + +# TODO: +#def test_create_symbol_graph(): + +if __name__ == "__main__": + print('MRT_OP_SET as:', opns.MRT_OP_SET) + assert len(opns.MRT_OP_SET) > 0 + + print('MRT_OP_MAP Class as:', opclass.MRT_OP_MAP) + assert len(opclass.MRT_OP_MAP) > 0 + assert opns.CONV2D in opclass.MRT_OP_MAP + + rltflag = test_create_conv2d_op() + print("\n" + "="*60 + "\n") + print('Passed Test!' if rltflag else 'Test Failed!') + print("\n" + "="*60 + "\n") + diff --git a/tests/mir/test.symbol_pass.py b/tests/mir/test.symbol_pass.py new file mode 100644 index 0000000..7d109ef --- /dev/null +++ b/tests/mir/test.symbol_pass.py @@ -0,0 +1,88 @@ +""" +Test script for MRT Alexnet FuseDropoutPass. 
+""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import symbolpass + +def _get_alexnet_model(): + """Get Alexnet MRT Model""" + + # Load pre-trained Alexnet + model = models.alexnet(pretrained=True) + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Alexnet to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + +def test_SymbolPass_FuseDropout(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseDropout Pass ===') + symlist = sx.sym2list(symbol) + dropout_op_cnt = 0 + for sym in symlist: + # print(sym) + dropout_op_cnt += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt>0, f'original model dropout op cnt {dropout_op_cnt} == zero!' + + # init FuseDropout Passer and execute visit + tfs : symbolpass.FuseDropoutPass = symbolpass.FuseDropoutPass(symbol, {}) + #print(getattr(tfs, f"visit_{opns.Opname2Funcname(opns.DROP_OUT)}")) + symbol_passed = tfs.visit() + + print('\n=== After FuseDropout Pass ===') + rlts = sx.sym2list(symbol_passed) + dropout_op_cnt_af = 0 + for sym in rlts: + # print(sym) + dropout_op_cnt_af += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt_af==0, f'passed model dropout op cnt {dropout_op_cnt_af} != zero!' + + #for sym in symdict: + # print(sym, symdict[sym]) + + #print('\n=== Back To SymList ===') + #rltlist = sx.sym2list(symdict[symbol.name]) + + return True + +if __name__ == "__main__": + + print("=== Testing SymbolPass ===") + mrt_graph, mrt_params = _get_alexnet_model() + + print("Testing FuseDropoutPass for Model AlexNet") + rltflag = test_SymbolPass_FuseDropout(mrt_graph, mrt_params) + + print("\n" + "="*60 + "\n") + print('Passed Test!' if rltflag else 'Test Failed!') + print("\n" + "="*60 + "\n") + From d830540788fd21f926a1da3c3cc4e57ef9264ddd Mon Sep 17 00:00:00 2001 From: corlfj Date: Fri, 12 Sep 2025 10:24:56 +0800 Subject: [PATCH 2/7] [mir]: fix last commit --- python/mrt/mir/opclass.py | 50 +++++++++++++++++++++--------------- python/mrt/mir/opns.py | 6 ++--- python/mrt/mir/symbol.py | 2 +- python/mrt/mir/symbolpass.py | 6 ++--- tests/mir/test.op_create.py | 25 +++++++++++------- 5 files changed, 53 insertions(+), 36 deletions(-) diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py index 3d5433a..a8fb932 100644 --- a/python/mrt/mir/opclass.py +++ b/python/mrt/mir/opclass.py @@ -1,4 +1,5 @@ import typing +import numpy as np from dataclasses import dataclass from . import opns from . 
import symbol @@ -19,7 +20,6 @@ def _wrapper(clss: typing.Any=None): return clss return _wrapper -@_register_op_map(opns.CONV2D) @dataclass(init=False) class Conv2D(symbol.Symbol): @@ -28,31 +28,40 @@ class Conv2D(symbol.Symbol): @property def strides(self) -> typing.Tuple[int, int]: default_val = (1,1) - return self._strides if self._strides else self.attrs['strides'] if 'strides' in self.attrs else default_val + return self._strides if self._strides else self.attrs['strides'] if 'strides' in self.attrs else default_val @property def padding(self) -> typing.Tuple[int, int, int, int]: default_val = (0,0,0,0) - return self._padding if self._padding else self.attrs['padding'] if 'padding' in self.attrs else default_val + return self._padding if self._padding else self.attrs['padding'] if 'padding' in self.attrs else default_val @property def dilation(self) -> typing.Tuple[int, int]: default_val = (1,1) - return self._ if self._ else self.attrs[''] if '' in self.attrs else default_val + return self._dilation if self._dilation else self.attrs['dilation'] if 'dilation' in self.attrs else default_val @property def kernel_size(self) -> typing.Tuple[int, int]: default_val = (3,3) - return self._kernel_size if self._kernel_size else self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val - - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} + return self._kernel_size if self._kernel_size else self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val + + def __init__(self, name_or_inst: typing.Union[str, symbol.Symbol], **kwargs): + assert isinstance(name_or_inst, str) or isinstance(name_or_inst, symbol.Symbol) + if isinstance(name_or_inst, str): + self.name = name_or_inst + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + else: + # clone mode + self.name = name_or_inst.name + self.args = [a for a in name_or_inst.args] + self.attrs = {k: v for k, v in name_or_inst.attrs.items()} + self.extra_attrs = {k: v for k, v in name_or_inst.extra_attrs.items()} # TODO: what if strides not in attrs? 
- self._strides = self.attrs['strides'] + if 'strides' in self.attrs: + self._strides = self.attrs['strides'] if 'padding' in self.attrs: self._padding = self.attrs['padding'] if 'dilation' in self.attrs: @@ -61,7 +70,6 @@ def __init__(self, name:str, **kwargs): self._kernel_size = self.attrs['kernel_size'] -@_register_op_map(opns.DROP_OUT) @dataclass(init=False) class Dropout(symbol.Symbol): @@ -80,7 +88,6 @@ def __init__(self, name:str, **kwargs): self._rate = self.attrs['rate'] -@_register_op_map(opns.CLIP) @dataclass(init=False) class Clip(symbol.Symbol): @@ -88,12 +95,12 @@ class Clip(symbol.Symbol): @property def min(self) -> float: - default_val = None + default_val = np.nan return self._min if self._min else self.attrs['min'] if 'min' in self.attrs else default_val @property def max(self) -> float: - default_val = None + default_val = np.nan return self._max if self._max else self.attrs['max'] if 'max' in self.attrs else default_val def __init__(self, name:str, **kwargs): @@ -106,7 +113,6 @@ def __init__(self, name:str, **kwargs): self._max = self.attrs['max'] -@_register_op_map(opns.BATCH_NORM) @dataclass(init=False) class BatchNorm(symbol.Symbol): @@ -143,7 +149,6 @@ def __init__(self, name:str, **kwargs): self._center = self.attrs['center'] self._scale = self.attrs['scale'] -@_register_op_map(opns.DENSE) @dataclass(init=False) class Dense(symbol.Symbol): @@ -155,7 +160,6 @@ def __init__(self, name:str, **kwargs): self.attrs = kwargs.pop('attrs', {}) self.extra_attrs = {} -@_register_op_map(opns.TUPLE_GET_ITEM) @dataclass(init=False) class TupleGetItem(symbol.Symbol): @@ -174,7 +178,6 @@ def __init__(self, name:str, **kwargs): self._index = self.attrs['index'] -@_register_op_map(opns.MUL) @dataclass(init=False) class Multiply(symbol.Symbol): @@ -186,3 +189,10 @@ def __init__(self, name:str, **kwargs): self.attrs = kwargs.pop('attrs', {}) self.extra_attrs = {} +_register_op_map(opns.CONV2D)(Conv2D) +_register_op_map(opns.DROP_OUT)(Dropout) +_register_op_map(opns.CLIP)(Clip) +_register_op_map(opns.BATCH_NORM)(BatchNorm) +_register_op_map(opns.DENSE)(Dense) +_register_op_map(opns.TUPLE_GET_ITEM)(TupleGetItem) +_register_op_map(opns.MUL)(Multiply) diff --git a/python/mrt/mir/opns.py b/python/mrt/mir/opns.py index be2f823..ed9ac2a 100644 --- a/python/mrt/mir/opns.py +++ b/python/mrt/mir/opns.py @@ -1,8 +1,8 @@ """ MRT operator names """ import typing -MRT_OP_SET = set() -def _register_op_list(*op_names: typing.List[str]): +MRT_OP_SET: typing.Set[str] = set() +def _register_op_list(*op_names: str): for op_name in op_names: if op_name not in MRT_OP_SET: MRT_OP_SET.add(op_name) @@ -127,6 +127,6 @@ def _register_op_list(*op_names: typing.List[str]): _register_op_list(REQUANT, PCLIP, RS_PCLIP, LUT) -def Opname2Funcname(op_name: str): +def Opname2Funcname(op_name: str) -> str: return op_name.replace('.', '_') #print('MRT_OP_SET:', MRT_OP_SET) diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index e27adcd..1832d87 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -462,7 +462,7 @@ def as_tuple(self) -> typing.Tuple[typing.List[str], Symbol]: @classmethod def from_tuple(cls, tuple_names, symbol): - assert symbol.is_op(TUPLE), symbol + assert symbol.is_op(opns.TUPLE), symbol mhs = cls(zip(tuple_names, symbol.args)) mhs.origin = symbol return mhs diff --git a/python/mrt/mir/symbolpass.py b/python/mrt/mir/symbolpass.py index 883af38..e9d65ee 100644 --- a/python/mrt/mir/symbolpass.py +++ b/python/mrt/mir/symbolpass.py @@ -64,9 +64,9 @@ def 
visit_nn_dropout(self, sym: _symbol.Symbol) -> _symbol.Symbol: class FuseDividePass(SymbolPass): def visit_divide(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.DIVIDE: - argA = self.args[0] - argB = self.args[1] + if sym.op_name == opns.DIV: + argA = sym.args[0] + argB = sym.args[1] assert self.is_param(argB), f'NotParam: {argB}' # TODO: fixit #argB = argB.from_np_data(1. / argB.numpy()) diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py index a36e16c..cf7e61c 100644 --- a/tests/mir/test.op_create.py +++ b/tests/mir/test.op_create.py @@ -27,18 +27,25 @@ def test_create_conv2d_op(): # padding: typing.Optional[typing.Tuple[int, int, int, int]] = (0,0,0,0) # create mrt op symbol, def func print('mrt Conv2D Op Class:', opclass.Conv2D) - conv2d_b = opclass.MRT_OP_MAP[opns.CONV2D]('conv2d_b',args=[[],[],[]], attrs={'strides':(1,1), 'padding':None}) - assert isinstance(conv2d_b, sx.Symbol), 'not!con2d_b symbol' - assert isinstance(conv2d_b, opclass.Conv2D), 'not!2 -con2d_b' + conv2d_a = opclass.MRT_OP_MAP[opns.CONV2D]('conv2d_a', args=[[],[],[]], attrs={'strides':(1,1), 'padding':None}) + assert isinstance(conv2d_a, sx.Symbol), 'conv2d_a isnot a symbol' + assert isinstance(conv2d_a, opclass.Conv2D), 'conv2d_a isnot a Conv2D' # attrs hint - assert conv2d_b.args != None - assert conv2d_b.attrs != None - assert conv2d_b.strides != None + assert conv2d_a.args != None + assert conv2d_a.attrs != None + assert conv2d_a.strides != None - print(f'Got {conv2d_b.name} strides: {conv2d_b.strides}') - print(f'Got {conv2d_b.name} padding: {conv2d_b.padding}') - print(f'Show {conv2d_b.name} {conv2d_b}') + print(f'Got {conv2d_a.name} strides: {conv2d_a.strides}') + print(f'Got {conv2d_a.name} padding: {conv2d_a.padding}') + print(f'Show {conv2d_a.name} {conv2d_a}') + + # test Conv2D clone mode + conv2d_b = opclass.MRT_OP_MAP[opns.CONV2D](conv2d_a) + assert isinstance(conv2d_b, sx.Symbol), 'conv2d_b isnot a symbol' + assert isinstance(conv2d_b, opclass.Conv2D), 'conv2d_b isnot a Conv2D' + + assert conv2d_b.attrs == conv2d_a.attrs return True From 69c36b9450f7eba511424df8135b356bf87a1ae0 Mon Sep 17 00:00:00 2001 From: corlfj Date: Wed, 17 Sep 2025 16:53:04 +0800 Subject: [PATCH 3/7] [mir]: opclass compatible --- python/mrt/mir/opclass.py | 338 ++++++++++++------ python/mrt/mir/opns.py | 28 -- .../mrt/mir/{symbolpass.py => simple_pass.py} | 102 ++++-- python/mrt/mir/symbol.py | 2 + tests/mir/test.op_create.py | 100 +++++- ...est.symbol_pass.py => test.simple_pass.py} | 71 +++- 6 files changed, 443 insertions(+), 198 deletions(-) rename python/mrt/mir/{symbolpass.py => simple_pass.py} (50%) rename tests/mir/{test.symbol_pass.py => test.simple_pass.py} (50%) diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py index a8fb932..0ded6ba 100644 --- a/python/mrt/mir/opclass.py +++ b/python/mrt/mir/opclass.py @@ -1,198 +1,306 @@ import typing import numpy as np -from dataclasses import dataclass +from dataclasses import dataclass, fields + +from mrt.common.utils import N from . import opns from . 
import symbol +from .symbol import SelfSymbol -MRT_OP_MAP: typing.Dict[str, typing.Any] = {} - -#def _register_op_map_(op_name: str, clss:typing.Any=None): -# if len(op_name)>0 and clss!=None: -# if op_name not in MRT_OP_MAP: -# MRT_OP_MAP[op_name] = clss -# return MRT_OP_MAP +#SelfSymbol = typing.TypeVar("SelfSymbol", bound="Symbol") +MRT_OP_MAP: typing.Dict[str, SelfSymbol] = {} -def _register_op_map(op_name: str): #, clss:typing.Any=None): - def _wrapper(clss: typing.Any=None): - if len(op_name)>0 and clss!=None: +def _register_op_map(op_name: str): + def _wrapper(clss: SelfSymbol = None) -> SelfSymbol: + if len(op_name) > 0 and clss != None: if op_name not in MRT_OP_MAP: MRT_OP_MAP[op_name] = clss + else: + print(f'Warning: "{op_name}" Alreary Registered In MRT_OP_MAP, IsBeing Overrided!') + MRT_OP_MAP[op_name] = clss return clss return _wrapper + @dataclass(init=False) -class Conv2D(symbol.Symbol): +class Variable(symbol.Symbol): + op_name = opns.VAR + + def __init__(self, name=None, op_name=None, shape:typing.Tuple = (), dtype=None, extra_attrs=None): + op_name = op_name or opns.VAR + assert op_name == opns.VAR + super().__init__(name=name or N.n(), op_name=op_name, args=[], attrs={}, extra_attrs=extra_attrs or {}) + self.shape = shape # will also update extra_attrs + self.dtype = dtype # will also update extra_attrs + + @classmethod + def from_dict(cls, d: dict, **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {k: data['extra_attrs'][k] for k in data['extra_attrs'] if k in ['shape', 'dtype']} + try: + out = cls(**attrsdata, **basedata) + except Exception as e: + raise e + return out + +@dataclass(init=False) +class Conv2D(symbol.Symbol): op_name = opns.CONV2D @property def strides(self) -> typing.Tuple[int, int]: default_val = (1,1) - return self._strides if self._strides else self.attrs['strides'] if 'strides' in self.attrs else default_val + return self.attrs['strides'] if 'strides' in self.attrs else default_val @property def padding(self) -> typing.Tuple[int, int, int, int]: default_val = (0,0,0,0) - return self._padding if self._padding else self.attrs['padding'] if 'padding' in self.attrs else default_val + return self.attrs['padding'] if 'padding' in self.attrs else default_val + + @property + def groups(self) -> int: + default_val = 1 + return self.attrs['groups'] if 'groups' in self.attrs else default_val @property def dilation(self) -> typing.Tuple[int, int]: default_val = (1,1) - return self._dilation if self._dilation else self.attrs['dilation'] if 'dilation' in self.attrs else default_val + return self.attrs['dilation'] if 'dilation' in self.attrs else default_val @property def kernel_size(self) -> typing.Tuple[int, int]: default_val = (3,3) - return self._kernel_size if self._kernel_size else self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val - - def __init__(self, name_or_inst: typing.Union[str, symbol.Symbol], **kwargs): - assert isinstance(name_or_inst, str) or isinstance(name_or_inst, symbol.Symbol) - if isinstance(name_or_inst, str): - self.name = name_or_inst - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} - else: - # clone mode - self.name = name_or_inst.name - self.args = [a for a in name_or_inst.args] - self.attrs = {k: v for k, v in name_or_inst.attrs.items()} - self.extra_attrs = {k: v for k, v in name_or_inst.extra_attrs.items()} 
- - # TODO: what if strides not in attrs? - if 'strides' in self.attrs: - self._strides = self.attrs['strides'] - if 'padding' in self.attrs: - self._padding = self.attrs['padding'] - if 'dilation' in self.attrs: - self._dilation = self.attrs['dilation'] - if 'kernel_size' in self.attrs: - self._kernel_size = self.attrs['kernel_size'] - + return self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val + + + # Follows (*args, name, **attrs) + def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_size=(3,3), extra_attrs=None): + op_name = op_name or opns.CONV2D + assert op_name == opns.CONV2D + super().__init__(name=name or N.n(), op_name=op_name, args=[X,W], attrs={'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size}, extra_attrs=extra_attrs or {}) + + + # Copy from other instance of same opclass, must have specific attrs (or with default value) + @classmethod + def from_dict(cls, d: dict, **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in ['strides', 'padding', 'groups', 'dilation', 'kernel_size']} + try: + out = cls(data['args'][0], data['args'][1], **attrsdata, **basedata) + except Exception as e: + raise e + return out @dataclass(init=False) class Dropout(symbol.Symbol): - op_name = opns.DROP_OUT @property def rate(self) -> float: default_val = 0.0 - return self._rate if self._rate else self.attrs['rate'] if 'rate' in self.attrs else default_val + return self.attrs['rate'] if 'rate' in self.attrs else default_val - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} - - self._rate = self.attrs['rate'] + def __init__(self, X, name=None, op_name=None, rate:float = 0, extra_attrs=None): + op_name = op_name or opns.DROP_OUT + assert op_name == opns.DROP_OUT + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'rate': rate}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {'rate': data['attrs']['rate']} + try: + out = cls(data['args'][0], **attrsdata, **basedata) + except Exception as e: + raise e + return out @dataclass(init=False) class Clip(symbol.Symbol): - op_name = opns.CLIP @property def min(self) -> float: default_val = np.nan - return self._min if self._min else self.attrs['min'] if 'min' in self.attrs else default_val + return self.attrs['min'] if 'min' in self.attrs else default_val @property def max(self) -> float: default_val = np.nan - return self._max if self._max else self.attrs['max'] if 'max' in self.attrs else default_val - - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} - - self._min = self.attrs['min'] - self._max = self.attrs['max'] + return self.attrs['max'] if 'max' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, min_:float = np.nan, max_:float = np.nan, extra_attrs=None): + op_name = op_name or opns.CLIP + assert op_name == opns.CLIP + 
super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'min': min_, 'max': max_}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {'min': data['attrs']['min'], 'max': data['attrs']['max']} + try: + out = cls(data['args'][0], **attrsdata, **basedata) + except Exception as e: + raise e + return out @dataclass(init=False) class BatchNorm(symbol.Symbol): - op_name = opns.BATCH_NORM @property - def axis(self) -> float: + def axis(self) -> int: default_val = 1 - return self._axis if self._axis else self.attrs['axis'] if 'axis' in self.attrs else default_val + return self.attrs['axis'] if 'axis' in self.attrs else default_val @property def epsilon(self) -> float: default_val = 1e-5 - return self._epsilon if self._epsilon else self.attrs['epsilon'] if 'epsilon' in self.attrs else default_val + return self.attrs['epsilon'] if 'epsilon' in self.attrs else default_val @property - def center(self) -> float: + def center(self) -> bool: default_val = True - return self._center if self._center else self.attrs['center'] if 'center' in self.attrs else default_val + return self.attrs['center'] if 'center' in self.attrs else default_val @property - def scale(self) -> float: + def scale(self) -> bool: default_val = True - return self._scale if self._scale else self.attrs['scale'] if 'scale' in self.attrs else default_val - - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} - - self._axis = self.attrs['axis'] - self._epsilon = self.attrs['epsilon'] - self._center = self.attrs['center'] - self._scale = self.attrs['scale'] + return self.attrs['scale'] if 'scale' in self.attrs else default_val + + def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, center:bool = True, scale:bool = True, extra_attrs=None): + op_name = op_name or opns.BATCH_NORM + assert op_name == opns.BATCH_NORM + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Gamma, Beta, Mean, Var], attrs={'axis': axis, 'epsilon': epsilon, 'center': center, 'scale': scale}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in ['axis', 'epsilon', 'center', 'scale']} + try: + out = cls(*data['args'], **attrsdata, **basedata) + except Exception as e: + raise e + return out -@dataclass(init=False) -class Dense(symbol.Symbol): - - op_name = opns.DENSE - - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} @dataclass(init=False) class TupleGetItem(symbol.Symbol): - op_name = opns.TUPLE_GET_ITEM @property def index(self) -> float: default_val = 0 - return self._index if self._index else self.attrs['index'] if 'index' in self.attrs else default_val - - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} - - self._index = self.attrs['index'] - 
-@dataclass(init=False) -class Multiply(symbol.Symbol): - - op_name = opns.MUL - - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} - + return self.attrs['index'] if 'index' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, index:int = 0, extra_attrs=None): + op_name = op_name or opns.TUPLE_GET_ITEM + assert op_name == opns.TUPLE_GET_ITEM + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'index': index}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {'index': data['attrs']['index']} + try: + out = cls(data['args'][0], **attrsdata, **basedata) + except Exception as e: + raise e + return out + + +_register_op_map(opns.VAR)(Variable) _register_op_map(opns.CONV2D)(Conv2D) _register_op_map(opns.DROP_OUT)(Dropout) _register_op_map(opns.CLIP)(Clip) _register_op_map(opns.BATCH_NORM)(BatchNorm) -_register_op_map(opns.DENSE)(Dense) _register_op_map(opns.TUPLE_GET_ITEM)(TupleGetItem) -_register_op_map(opns.MUL)(Multiply) + +# Add default register Class for MRT OP Not Implemented! +_register_op_map(opns.MUL)(symbol.Symbol) +_register_op_map(opns.DENSE)(symbol.Symbol) +_register_op_map(opns.RELU)(symbol.Symbol) +_register_op_map(opns.HARDTANH)(symbol.Symbol) +_register_op_map(opns.SILU)(symbol.Symbol) +_register_op_map(opns.LEAKY_RELU)(symbol.Symbol) +_register_op_map(opns.ADAPTIVE_AVG_POOL2D)(symbol.Symbol) +_register_op_map(opns.AVG_POOL2D)(symbol.Symbol) +_register_op_map(opns.MAX_POOL2D)(symbol.Symbol) +_register_op_map(opns.SOFTMAX)(symbol.Symbol) +_register_op_map(opns.LOG_SOFTMAX)(symbol.Symbol) +_register_op_map(opns.EXP)(symbol.Symbol) +_register_op_map(opns.SIGMOID)(symbol.Symbol) +_register_op_map(opns.SUM)(symbol.Symbol) +_register_op_map(opns.MEAN)(symbol.Symbol) +_register_op_map(opns.MAX_AXIS)(symbol.Symbol) +_register_op_map(opns.MAXIMUM)(symbol.Symbol) +_register_op_map(opns.MINIMUM)(symbol.Symbol) +_register_op_map(opns.TUPLE)(symbol.Symbol) +_register_op_map(opns.REPEAT)(symbol.Symbol) +_register_op_map(opns.SQUEEZE)(symbol.Symbol) +_register_op_map(opns.FLATTEN)(symbol.Symbol) +_register_op_map(opns.BATCH_FLATTEN)(symbol.Symbol) +_register_op_map(opns.RESHAPE)(symbol.Symbol) +_register_op_map(opns.CONCAT)(symbol.Symbol) +_register_op_map(opns.SPLIT)(symbol.Symbol) +_register_op_map(opns.TRANSPOSE)(symbol.Symbol) +_register_op_map(opns.BROADCAST_TO)(symbol.Symbol) +_register_op_map(opns.EXPAND_DIMS)(symbol.Symbol) +_register_op_map(opns.TILE)(symbol.Symbol) +_register_op_map(opns.WHERE)(symbol.Symbol) +_register_op_map(opns.GREATER)(symbol.Symbol) +_register_op_map(opns.STRIDED_SLICE)(symbol.Symbol) +_register_op_map(opns.SLICE_LIKE)(symbol.Symbol) +_register_op_map(opns.GET_VALID_COUNT)(symbol.Symbol) +_register_op_map(opns.NON_MAX_SUPRESSION)(symbol.Symbol) +_register_op_map(opns.CEIL)(symbol.Symbol) +_register_op_map(opns.RIGHT_SHIFT)(symbol.Symbol) +_register_op_map(opns.AS_TYPE)(symbol.Symbol) +_register_op_map(opns.ADV_INDEX)(symbol.Symbol) +_register_op_map(opns.CALL_TIR)(symbol.Symbol) +_register_op_map(opns.CALL_DPS_PACKED)(symbol.Symbol) +_register_op_map(opns.ADD)(symbol.Symbol) +_register_op_map(opns.SUB)(symbol.Symbol) +_register_op_map(opns.MATMUL)(symbol.Symbol) 
+_register_op_map(opns.DIV)(symbol.Symbol) +_register_op_map(opns.NEGATIVE)(symbol.Symbol) +_register_op_map(opns.ABS)(symbol.Symbol) +_register_op_map(opns.LOG)(symbol.Symbol) +_register_op_map(opns.SQRT)(symbol.Symbol) +_register_op_map(opns.POW)(symbol.Symbol) +_register_op_map(opns.PASS)(symbol.Symbol) +_register_op_map(opns.ARANGE)(symbol.Symbol) +_register_op_map(opns.ZEROS_LIKE)(symbol.Symbol) +_register_op_map(opns.ONES_LIKE)(symbol.Symbol) +_register_op_map(opns.IF)(symbol.Symbol) +_register_op_map(opns.ARGWHERE)(symbol.Symbol) +_register_op_map(opns.REQUANT)(symbol.Symbol) +_register_op_map(opns.PCLIP)(symbol.Symbol) +_register_op_map(opns.RS_PCLIP)(symbol.Symbol) +_register_op_map(opns.LUT)(symbol.Symbol) diff --git a/python/mrt/mir/opns.py b/python/mrt/mir/opns.py index ed9ac2a..31da253 100644 --- a/python/mrt/mir/opns.py +++ b/python/mrt/mir/opns.py @@ -1,14 +1,6 @@ """ MRT operator names """ -import typing - -MRT_OP_SET: typing.Set[str] = set() -def _register_op_list(*op_names: str): - for op_name in op_names: - if op_name not in MRT_OP_SET: - MRT_OP_SET.add(op_name) VAR = "var" -_register_op_list(VAR) DROP_OUT = "nn.dropout" CONV2D = "nn.conv2d" @@ -22,29 +14,22 @@ def _register_op_list(*op_names: str): ADAPTIVE_AVG_POOL2D = "nn.adaptive_avg_pool2d" AVG_POOL2D = "nn.avg_pool2d" MAX_POOL2D = "nn.max_pool2d" -_register_op_list(DROP_OUT, CONV2D, DENSE, BATCH_NORM, RELU, - HARDTANH, SILU, LEAKY_RELU, ADAPTIVE_AVG_POOL2D, - AVG_POOL2D, MAX_POOL2D) SOFTMAX = "nn.softmax" LOG_SOFTMAX = "nn.log_softmax" -_register_op_list(SOFTMAX, LOG_SOFTMAX) EXP = "exp" SIGMOID = "sigmoid" -_register_op_list(EXP, SIGMOID) SUM = "sum" MEAN = "mean" MAX_AXIS = "max" MAXIMUM = "maximum" MINIMUM = "minimum" -_register_op_list(SUM, MEAN, MAX_AXIS, MAXIMUM, MINIMUM) # =========== NON-CALC ops =============== TUPLE = "Tuple" TUPLE_GET_ITEM = "TupleGetItem" -_register_op_list(TUPLE, TUPLE_GET_ITEM) REPEAT = "repeat" SQUEEZE = "squeeze" @@ -55,12 +40,9 @@ def _register_op_list(*op_names: str): SPLIT = "split" TRANSPOSE = "transpose" BROADCAST_TO = "broadcast_to" -_register_op_list(REPEAT, SQUEEZE, FLATTEN, BATCH_FLATTEN, RESHAPE, - CONCAT, SPLIT, TRANSPOSE, BROADCAST_TO, ) EXPAND_DIMS = "expand_dims" TILE = "tile" -_register_op_list(EXPAND_DIMS, TILE) WHERE = "where" GREATER = "greater" @@ -68,7 +50,6 @@ def _register_op_list(*op_names: str): SLICE_LIKE = "slice_like" GET_VALID_COUNT = "vision.get_valid_counts" NON_MAX_SUPRESSION = "vision.non_max_suppression" -_register_op_list(WHERE, GREATER, STRIDED_SLICE, SLICE_LIKE, GET_VALID_COUNT, NON_MAX_SUPRESSION) # relax clip attrs from a_min/a_max to min/max CLIP = "clip" @@ -77,14 +58,11 @@ def _register_op_list(*op_names: str): # relax support astype instead of cast AS_TYPE = "astype" # CAST = "cast" -_register_op_list(CLIP, CEIL, RIGHT_SHIFT, AS_TYPE) ADV_INDEX = "adv_index" -_register_op_list(ADV_INDEX) CALL_TIR = "call_tir" CALL_DPS_PACKED = "call_dps_packed" -_register_op_list(CALL_TIR, CALL_DPS_PACKED) # ======= binary ops ============= @@ -93,7 +71,6 @@ def _register_op_list(*op_names: str): MUL = "multiply" MATMUL = "matmul" DIV = "divide" -_register_op_list(ADD, SUB, MUL, MATMUL, DIV) # ======= unary ops ============== @@ -104,17 +81,14 @@ def _register_op_list(*op_names: str): POW = "pow" PASS = "pass" -_register_op_list(NEGATIVE, ABS, LOG, SQRT, POW, PASS) # ======= auto generate op ========= ARANGE = "arange" ZEROS_LIKE = "zeros_like" ONES_LIKE = "ones_like" -_register_op_list(ARANGE, ZEROS_LIKE, ONES_LIKE) # ======= control flow op =========== IF 
= "if" ARGWHERE = "argwhere" -_register_op_list(IF, ARGWHERE) # ======= mrt requant op ========== REQUANT = "mrt.requant" @@ -124,9 +98,7 @@ def _register_op_list(*op_names: str): """ right shift precision clip """ LUT = "mrt.lut" """ look up table, equals adv_index in tvm """ -_register_op_list(REQUANT, PCLIP, RS_PCLIP, LUT) def Opname2Funcname(op_name: str) -> str: return op_name.replace('.', '_') -#print('MRT_OP_SET:', MRT_OP_SET) diff --git a/python/mrt/mir/symbolpass.py b/python/mrt/mir/simple_pass.py similarity index 50% rename from python/mrt/mir/symbolpass.py rename to python/mrt/mir/simple_pass.py index e9d65ee..2fc362e 100644 --- a/python/mrt/mir/symbolpass.py +++ b/python/mrt/mir/simple_pass.py @@ -14,18 +14,13 @@ # mrt op visits -class SymbolPass: +class SimplePass: symbol: _symbol.Symbol - params: ParametersT - - def __init__(self, symbol: _symbol.Symbol, params: ParametersT): - self.symbol = symbol - self.params = params - def is_param(self, symbol: _symbol.Symbol) -> bool: - return symbol.op_name == opns.VAR and symbol.name in self.params + def __init__(self, symbol: _symbol.Symbol): + self.symbol = symbol - def visit(self) -> _symbol.Symbol: + def visit(self, custom_func: typing.Callable[[Symbol], typing.Optional[Symbol]] = None) -> _symbol.Symbol: env: typing.Dict[str, _symbol.Symbol] = {} for sym in _symbol.sym2list(self.symbol): assert sym.name not in env, f'{sym.name} NotIn env!' @@ -33,11 +28,8 @@ def visit(self) -> _symbol.Symbol: # Updating args as passed symbol in env_dict sym = sym.copy(args = [env[arg_sym.name] for arg_sym in sym.args]) assert isinstance(sym, _symbol.Symbol), sym - - if sym.op_name == opns.DROP_OUT: - #print('ddrroopped_out', getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) or sym) - pass - out = getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) or sym + out = custom_func(sym) if custom_func else getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) + out = out or sym assert isinstance(out, _symbol.Symbol), out env[sym.name] = out return env[self.symbol.name] @@ -46,15 +38,31 @@ def _default_visit_op(self, op: _symbol.Symbol) -> _symbol.Symbol: return op -# register mrt op default_visit -for op_name in opns.MRT_OP_SET: +# mrt op visits with params, variables +class InferPass(SimplePass): + params: ParametersT + + def is_param(self, symbol: _symbol.Symbol) -> bool: + return symbol.op_name == opns.VAR and symbol.name in self.params + + def get_param(self, symbol: _symbol.Symbol) -> OpNumpyT: + assert self.is_param(symbol) + return self.params[symbol.name] if self.is_param(symbol) else [] + + def __init__(self, symbol: _symbol.Symbol, params: ParametersT): + self.symbol = symbol + self.params = params + + +# Register MRT all op's default_visit_op function +for op_name in opclass.MRT_OP_MAP.keys(): funcSuffix = opns.Opname2Funcname(op_name) - setattr(SymbolPass, f"visit_{funcSuffix}", SymbolPass._default_visit_op) - #print(f"visit_, {op_name} => {funcSuffix}", getattr(SymbolPass, f"visit_{funcSuffix}")) + setattr(SimplePass, f"visit_{funcSuffix}", SimplePass._default_visit_op) + #print(f"visit_, {op_name} => {funcSuffix}", getattr(SimplePass, f"visit_{funcSuffix}")) -# mrt symbol pass -class FuseDropoutPass(SymbolPass): +# mrt symbol simple pass +class FuseDropoutPass(SimplePass): def visit_nn_dropout(self, sym: _symbol.Symbol) -> _symbol.Symbol: # make sure op fit again if sym.op_name == opns.DROP_OUT: @@ -62,7 +70,48 @@ def visit_nn_dropout(self, sym: _symbol.Symbol) -> _symbol.Symbol: return sym -class 
FuseDividePass(SymbolPass): +class FuseTupleGetItemPass(SimplePass): + def visit_TupleGetItem(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.TUPLE_GET_ITEM: + return sym + sym_ : opclass.TupleGetItem = sym + assert sym_.index == 0 + return sym_.args[0] + return sym + + +class FuseBatchNormPass(InferPass): + def visit_nn_batch_norm(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.BATCH_NORM: + X, Gamma, Beta, Mean, Var = sym.args + Gamma = self.get_param(Gamma) + Beta = self.get_param(Beta) + Mean = self.get_param(Mean) + Var = self.get_param(Var) + return sym + return sym + + +class FuseSoftmaxPass(SimplePass): + def visit_nn_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.SOFTMAX: + return self.args[0] + return sym + + def visit_nn_log_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.LOG_SOFTMAX: + return self.args[0] + return sym + + +class FuseMeanPass(SimplePass): + def visit_mean(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.MEAN: + return sym + return sym + + +class FuseDividePass(InferPass): def visit_divide(self, sym: _symbol.Symbol) -> _symbol.Symbol: if sym.op_name == opns.DIV: argA = sym.args[0] @@ -70,15 +119,6 @@ def visit_divide(self, sym: _symbol.Symbol) -> _symbol.Symbol: assert self.is_param(argB), f'NotParam: {argB}' # TODO: fixit #argB = argB.from_np_data(1. / argB.numpy()) - return opclass.Multiply(sym.name, {'args':[argA, argB]}) - return sym - - -class FuseTupleGetItemPass(SymbolPass): - def visit_TupleGetItem(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.TUPLE_GET_ITEM: - sym_ : opclass.TupleGetItem = sym - assert sym_.index == 0 - return sym_.args[0] + return opclass.MRT_OP_MAP[opns.MUL](argA, argB) return sym diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index 1832d87..c1e6bd7 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -19,6 +19,8 @@ "filter_operators", ] +SelfSymbol = typing.TypeVar("SelfSymbol", bound="Symbol") + def _format_printer(data): if isinstance(data, dict): data = ["{}={}".format(k, _format_printer(v)) \ diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py index cf7e61c..eab2d90 100644 --- a/tests/mir/test.op_create.py +++ b/tests/mir/test.op_create.py @@ -22,12 +22,14 @@ def test_create_conv2d_op(): - #class CONV2D(Symbol): - # strides: typing.Tuple[int, int] = (1,1) - # padding: typing.Optional[typing.Tuple[int, int, int, int]] = (0,0,0,0) - # create mrt op symbol, def func - print('mrt Conv2D Op Class:', opclass.Conv2D) - conv2d_a = opclass.MRT_OP_MAP[opns.CONV2D]('conv2d_a', args=[[],[],[]], attrs={'strides':(1,1), 'padding':None}) + + X = opclass.Variable(name="x", shape=(1, 3, 224, 224,), dtype="float") + W = opclass.Variable(name="w", shape=(32, 3, 10, 10,), dtype="float") + assert [shp for shp in X.shape] == [shp for shp in (1, 3, 224, 224)], f'Wrong X shape {X.shape}' + assert X.dtype == "float", f'Wrong X dtype {X.dtype}' + + # Symbol Init using opclass OP + conv2d_a = opclass.Conv2D(X, W, name='conv2d_a', strides=(2,2)) assert isinstance(conv2d_a, sx.Symbol), 'conv2d_a isnot a symbol' assert isinstance(conv2d_a, opclass.Conv2D), 'conv2d_a isnot a Conv2D' @@ -41,27 +43,87 @@ def test_create_conv2d_op(): print(f'Show {conv2d_a.name} {conv2d_a}') # test Conv2D clone mode - conv2d_b = opclass.MRT_OP_MAP[opns.CONV2D](conv2d_a) + conv2d_b = conv2d_a.copy() assert isinstance(conv2d_b, sx.Symbol), 'conv2d_b isnot a symbol' assert 
isinstance(conv2d_b, opclass.Conv2D), 'conv2d_b isnot a Conv2D' - assert conv2d_b.attrs == conv2d_a.attrs + assert conv2d_b.attrs == conv2d_a.attrs, f'a: {conv2d_b.attrs} != b: {conv2d_a.attrs}' + + # test Dict to Find Class and Init + conv2d_c = opclass.MRT_OP_MAP[opns.CONV2D](X, W, strides=(2,2)) + assert isinstance(conv2d_c, opclass.Conv2D), 'conv2d_c isnot a Conv2D' + + # test Variable clone mode + X1 = X.copy() + assert X1.shape == X.shape + assert X1.dtype == X.dtype + + # test: Symbol Compatible Mode + args = [X1, W] + attrs = {'strides':(3,3)} + + # Symbol Compatible Init + conv2d_d = opclass.Conv2D(*args, name='conv2d_d', **attrs) + conv2d_e = opclass.Conv2D(*args, **attrs) + assert isinstance(conv2d_d, opclass.Conv2D), 'conv2d_d isnot a Conv2D' + assert isinstance(conv2d_e, opclass.Conv2D), 'conv2d_e isnot a Conv2D' + + return True + + +def test_create_symbol_graph(): + X0 = opclass.Variable(name="x", shape=(1, 3, 224, 224,), dtype="float") + W0 = opclass.Variable(name="w", shape=(32, 3, 10, 10,), dtype="float") + conv2d_a = opclass.Conv2D(X0, W0, name='conv2d_a', strides=(1,1)) + + W1 = opclass.Variable(shape=(16, 3, 12, 12,), dtype="float") + conv2d_b = opclass.Conv2D(conv2d_a, W1, name='conv2d_b', strides=(1,1)) + symlist = sx.sym2list(conv2d_b) + + assert symlist[0] == X0 + assert symlist[1] == W0 + + for id_ in range(len(symlist)): + print(id_, symlist[id_]) + return True -# TODO: -#def test_create_symbol_graph(): +def test_create_batch_norm_op(): + X = opclass.Variable(name="x", shape=(1, 32, 128, 128,), dtype="float") + Gamma = opclass.Variable(name="gamma", shape=(32,), dtype="float") + Beta = opclass.Variable(name="beta", shape=(32,), dtype="float") + Mean = opclass.Variable(name="mean", shape=(32,), dtype="float") + Var = opclass.Variable(name="var", shape=(32,), dtype="float") + batch_norm_a = opclass.BatchNorm(X, Gamma, Beta, Mean, Var, axis=1, epsilon=1e-4) + + # attrs hint + assert batch_norm_a.args != None + assert batch_norm_a.attrs != None + assert batch_norm_a.axis != 0 + + # test clone mode + batch_norm_b = batch_norm_a.copy() + assert isinstance(batch_norm_b , opclass.BatchNorm) + + assert batch_norm_a.attrs == batch_norm_b.attrs, f'a: {batch_norm_a.attrs} != b: {batch_norm_b.attrs}' + assert len(batch_norm_a.args) == len(batch_norm_b.args), f'a: {len(batch_norm_a.args)} != b: {len(batch_norm_b.args)}' + + return True + if __name__ == "__main__": - print('MRT_OP_SET as:', opns.MRT_OP_SET) - assert len(opns.MRT_OP_SET) > 0 + print('MRT_OP_SET as:', opclass.MRT_OP_MAP.keys()) + assert len(opclass.MRT_OP_MAP.keys()) > 0 - print('MRT_OP_MAP Class as:', opclass.MRT_OP_MAP) - assert len(opclass.MRT_OP_MAP) > 0 assert opns.CONV2D in opclass.MRT_OP_MAP - - rltflag = test_create_conv2d_op() - print("\n" + "="*60 + "\n") - print('Passed Test!' if rltflag else 'Test Failed!') - print("\n" + "="*60 + "\n") + print('MRT_OP_MAP Conv2D Class as:', opclass.MRT_OP_MAP[opns.CONV2D]) + + test_id = 0 + for func_ in [test_create_conv2d_op, test_create_symbol_graph, test_create_batch_norm_op]: + rltflag = func_() + test_id += 1 + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id}!' 
if rltflag else f'Test{test_id} Failed!') + print("\n" + "="*60 + "\n") diff --git a/tests/mir/test.symbol_pass.py b/tests/mir/test.simple_pass.py similarity index 50% rename from tests/mir/test.symbol_pass.py rename to tests/mir/test.simple_pass.py index 7d109ef..9d20ee0 100644 --- a/tests/mir/test.symbol_pass.py +++ b/tests/mir/test.simple_pass.py @@ -19,7 +19,7 @@ from mrt.mir import helper, symbol as sx from mrt.mir import opns from mrt.mir import opclass -from mrt.mir import symbolpass +from mrt.mir import simple_pass def _get_alexnet_model(): """Get Alexnet MRT Model""" @@ -41,7 +41,7 @@ def _get_alexnet_model(): mrt_graph, mrt_params = pytorch_to_mrt(ep) return mrt_graph, mrt_params -def test_SymbolPass_FuseDropout(mrt_graph, mrt_params): +def test_SimplePass_FuseDropout(mrt_graph, mrt_params): symbol = mrt_graph['main'] #print(symbol) @@ -54,7 +54,7 @@ def test_SymbolPass_FuseDropout(mrt_graph, mrt_params): assert dropout_op_cnt>0, f'original model dropout op cnt {dropout_op_cnt} == zero!' # init FuseDropout Passer and execute visit - tfs : symbolpass.FuseDropoutPass = symbolpass.FuseDropoutPass(symbol, {}) + tfs : simple_pass.FuseDropoutPass = simple_pass.FuseDropoutPass(symbol) #print(getattr(tfs, f"visit_{opns.Opname2Funcname(opns.DROP_OUT)}")) symbol_passed = tfs.visit() @@ -74,15 +74,76 @@ def test_SymbolPass_FuseDropout(mrt_graph, mrt_params): return True + +def test_SimplePass_CustomFunc(mrt_graph): + symbol = mrt_graph['main'] + + print('\n=== Before CustomFunc Pass ===') + symlist = sx.sym2list(symbol) + + tfs : simple_pass.SimplePass = simple_pass.SimplePass(symbol) + conv2d_name_list = [] + def _filter_op(sym: sx.Symbol) -> sx.Symbol: + if sym.op_name == opns.CONV2D: + conv2d_name_list.append(sym.name) + return sym + + symbol_passed = tfs.visit(_filter_op) + + print('\n=== After CustomFunc Pass ===') + assert len(conv2d_name_list) > 0 + print(conv2d_name_list) + rlts = sx.sym2list(symbol_passed) + + return True + + +def test_SimplePass_FuseDropout_CustomFunc(mrt_graph): + symbol = mrt_graph['main'] + + print('\n=== Before FuseDropout CustomFunc Pass ===') + symlist = sx.sym2list(symbol) + dropout_op_cnt = 0 + for sym in symlist: + dropout_op_cnt += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt > 0, f'ori model dropout op cnt {dropout_op_cnt} == zero!' + + tfs : simple_pass.SimplePass = simple_pass.SimplePass(symbol) + def _nn_dropout(sym: sx.Symbol) -> sx.Symbol: + if sym.op_name == opns.DROP_OUT: + return sym.args[0] + return sym + symbol_passed = tfs.visit(_nn_dropout) + + print('\n=== After FuseDropout CustomFunc Pass ===') + rlts = sx.sym2list(symbol_passed) + dropout_op_cnt_af = 0 + for sym in rlts: + dropout_op_cnt_af += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt_af == 0, f'passed model dropout op cnt {dropout_op_cnt_af} != zero!' + + return True + + if __name__ == "__main__": print("=== Testing SymbolPass ===") mrt_graph, mrt_params = _get_alexnet_model() print("Testing FuseDropoutPass for Model AlexNet") - rltflag = test_SymbolPass_FuseDropout(mrt_graph, mrt_params) + rltflag = test_SimplePass_FuseDropout(mrt_graph, mrt_params) + print("\n" + "="*60 + "\n") + print('Passed Test1!' if rltflag else 'Test1 Failed!') + print("\n" + "="*60 + "\n") + + rltflag = test_SimplePass_CustomFunc(mrt_graph) + print("\n" + "="*60 + "\n") + print('Passed Test2!' 
if rltflag else 'Test2 Failed!') + print("\n" + "="*60 + "\n") + print("Testing FuseDropout CustomFunc for Model AlexNet") + rltflag = test_SimplePass_FuseDropout_CustomFunc(mrt_graph) print("\n" + "="*60 + "\n") - print('Passed Test!' if rltflag else 'Test Failed!') + print('Passed Test3!' if rltflag else 'Test3 Failed!') print("\n" + "="*60 + "\n") From 74177d252500ba9e79ed6b560514ec73980a488e Mon Sep 17 00:00:00 2001 From: corlfj Date: Mon, 22 Sep 2025 10:26:14 +0800 Subject: [PATCH 4/7] [mir]: opclass opfunc, more op --- python/mrt/mir/opclass.py | 904 +++++++++++++++++++++++++++++------- tests/mir/test.op_create.py | 20 +- 2 files changed, 758 insertions(+), 166 deletions(-) diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py index 0ded6ba..2e4cdf9 100644 --- a/python/mrt/mir/opclass.py +++ b/python/mrt/mir/opclass.py @@ -1,6 +1,6 @@ import typing import numpy as np -from dataclasses import dataclass, fields +from dataclasses import dataclass from mrt.common.utils import N from . import opns @@ -8,10 +8,14 @@ from .symbol import SelfSymbol #SelfSymbol = typing.TypeVar("SelfSymbol", bound="Symbol") -MRT_OP_MAP: typing.Dict[str, SelfSymbol] = {} + +SymbolCreator = typing.Union[typing.Callable[[typing.Any, typing.Any], typing.Type[symbol.Symbol]], SelfSymbol] +#SymbolCreator = typing.Union[typing.Callable[[...], symbol.Symbol], SelfSymbol] + +MRT_OP_MAP: typing.Dict[str, SymbolCreator] = {} def _register_op_map(op_name: str): - def _wrapper(clss: SelfSymbol = None) -> SelfSymbol: + def _wrapper(clss: SymbolCreator = None) -> SymbolCreator: if len(op_name) > 0 and clss != None: if op_name not in MRT_OP_MAP: MRT_OP_MAP[op_name] = clss @@ -22,30 +26,43 @@ def _wrapper(clss: SelfSymbol = None) -> SelfSymbol: return _wrapper -@dataclass(init=False) -class Variable(symbol.Symbol): - op_name = opns.VAR - - def __init__(self, name=None, op_name=None, shape:typing.Tuple = (), dtype=None, extra_attrs=None): - op_name = op_name or opns.VAR - assert op_name == opns.VAR - super().__init__(name=name or N.n(), op_name=op_name, args=[], attrs={}, extra_attrs=extra_attrs or {}) - self.shape = shape # will also update extra_attrs - self.dtype = dtype # will also update extra_attrs - - @classmethod - def from_dict(cls, d: dict, **kwargs): - data = cls.default_dict() - data.update(d) - data.update(kwargs) - data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} - attrsdata = {k: data['extra_attrs'][k] for k in data['extra_attrs'] if k in ['shape', 'dtype']} - try: - out = cls(**attrsdata, **basedata) - except Exception as e: - raise e - return out +# OPs from external (not in MRT op), using custom op_name with default op_func +#y = extern_opfunc("tanh")(X) +def extern_opfunc(op_name: str): + def op_func(*args, **attrs): + return symbol.Symbol(*args, op_name=op_name, **attrs) + return op_func + + +def _from_dict_attrs(cls, d: dict, attr_keys:typing.List[str]=[], **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in attr_keys} + try: + out = cls(*data['args'], **attrsdata, **basedata) + except Exception as e: + raise e + return out + +# OPs without attrs, just register function (funcName should be lower case) +def var(name=None, op_name=None, shape=(), dtype=float) -> symbol.Symbol: + op_name = op_name or opns.VAR + assert op_name == 
opns.VAR + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[], attrs={}, extra_attrs={'shape': shape or (), 'dtype': dtype or float}) + +#def _return_func_single_arg(op_name: op_name): +def relu(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.RELU + assert op_name == opns.RELU + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def silu(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.SILU + assert op_name == opns.SILU + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) @dataclass(init=False) @@ -84,49 +101,27 @@ def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0, assert op_name == opns.CONV2D super().__init__(name=name or N.n(), op_name=op_name, args=[X,W], attrs={'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size}, extra_attrs=extra_attrs or {}) - - # Copy from other instance of same opclass, must have specific attrs (or with default value) @classmethod def from_dict(cls, d: dict, **kwargs): - data = cls.default_dict() - data.update(d) - data.update(kwargs) - data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} - attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in ['strides', 'padding', 'groups', 'dilation', 'kernel_size']} - try: - out = cls(data['args'][0], data['args'][1], **attrsdata, **basedata) - except Exception as e: - raise e - return out + return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation', 'kernel_size'], **kwargs) @dataclass(init=False) class Dropout(symbol.Symbol): op_name = opns.DROP_OUT @property - def rate(self) -> float: - default_val = 0.0 - return self.attrs['rate'] if 'rate' in self.attrs else default_val + def p(self) -> float: + default_val = 0.5 + return self.attrs['p'] if 'p' in self.attrs else default_val - def __init__(self, X, name=None, op_name=None, rate:float = 0, extra_attrs=None): + def __init__(self, X, name=None, op_name=None, p:float = 0.5, extra_attrs=None): op_name = op_name or opns.DROP_OUT assert op_name == opns.DROP_OUT - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'rate': rate}, extra_attrs=extra_attrs or {}) + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'p': p}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): - data = cls.default_dict() - data.update(d) - data.update(kwargs) - data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} - attrsdata = {'rate': data['attrs']['rate']} - try: - out = cls(data['args'][0], **attrsdata, **basedata) - except Exception as e: - raise e - return out + return _from_dict_attrs(cls, d, ['p'], **kwargs) @dataclass(init=False) class Clip(symbol.Symbol): @@ -149,17 +144,7 @@ def __init__(self, X, name=None, op_name=None, min_:float = np.nan, max_:float = @classmethod def from_dict(cls, d: dict, **kwargs): - data = cls.default_dict() - data.update(d) - data.update(kwargs) - data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} - attrsdata = {'min': data['attrs']['min'], 'max': data['attrs']['max']} - try: - out = cls(data['args'][0], **attrsdata, **basedata) - except Exception as e: - raise e - return out + return 
_from_dict_attrs(cls, d, ['min', 'max'], **kwargs) @dataclass(init=False) @@ -177,33 +162,18 @@ def epsilon(self) -> float: return self.attrs['epsilon'] if 'epsilon' in self.attrs else default_val @property - def center(self) -> bool: - default_val = True + def momentum(self) -> float: + default_val = 0.1 return self.attrs['center'] if 'center' in self.attrs else default_val - @property - def scale(self) -> bool: - default_val = True - return self.attrs['scale'] if 'scale' in self.attrs else default_val - - def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, center:bool = True, scale:bool = True, extra_attrs=None): + def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, extra_attrs=None): op_name = op_name or opns.BATCH_NORM assert op_name == opns.BATCH_NORM - super().__init__(name=name or N.n(), op_name=op_name, args=[X, Gamma, Beta, Mean, Var], attrs={'axis': axis, 'epsilon': epsilon, 'center': center, 'scale': scale}, extra_attrs=extra_attrs or {}) + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Gamma, Beta, Mean, Var], attrs={'axis': axis, 'epsilon': epsilon, 'momentum': momentum}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): - data = cls.default_dict() - data.update(d) - data.update(kwargs) - data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} - attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in ['axis', 'epsilon', 'center', 'scale']} - try: - out = cls(*data['args'], **attrsdata, **basedata) - except Exception as e: - raise e - return out + return _from_dict_attrs(cls, d, ['axis', 'epsilon', 'momentum'], **kwargs) @dataclass(init=False) @@ -222,85 +192,707 @@ def __init__(self, X, name=None, op_name=None, index:int = 0, extra_attrs=None): @classmethod def from_dict(cls, d: dict, **kwargs): - data = cls.default_dict() - data.update(d) - data.update(kwargs) - data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} - attrsdata = {'index': data['attrs']['index']} - try: - out = cls(data['args'][0], **attrsdata, **basedata) - except Exception as e: - raise e - return out - - -_register_op_map(opns.VAR)(Variable) + return _from_dict_attrs(cls, d, ['index'], **kwargs) + + +@dataclass(init=False) +class LeakyRelu(symbol.Symbol): + op_name = opns.LEAKY_RELU + + @property + def negative_slope(self) -> float: + default_val = 1e-2 + return self.attrs['negative_slope'] if 'negative_slope' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, negative_slope:float = 1e-2, extra_attrs=None): + op_name = op_name or opns.LEAKY_RELU + assert op_name == opns.LEAKY_RELU + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'negative_slope': negative_slope}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['negative_slope'], **kwargs) + + +def dense(X, W, B, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.DENSE + assert op_name == opns.DENSE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, W, B], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Hardtanh(symbol.Symbol): + op_name = opns.HARDTANH + + @property + def min_val(self) -> float: + default_val = -1.0 + return self.attrs['min_val'] if 
'min_val' in self.attrs else default_val + + @property + def max_val(self) -> float: + default_val = 1.0 + return self.attrs['max_val'] if 'max_val' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, min_val:float = -1.0, max_val:float = 1.0, extra_attrs=None): + op_name = op_name or opns.HARDTANH + assert op_name == opns.HARDTANH + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'min_val': min_val, 'max_val':max_val}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['min_val', 'max_val'], **kwargs) + + +@dataclass(init=False) +class AdaptiveAvgPool2D(symbol.Symbol): + op_name = opns.ADAPTIVE_AVG_POOL2D + + @property + def output_size(self) -> typing.Union[int, typing.Tuple[int, int]]: + default_val = 0 + return self.attrs['output_size'] if 'output_size' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=0, extra_attrs=None): + op_name = op_name or opns.ADAPTIVE_AVG_POOL2D + assert op_name == opns.ADAPTIVE_AVG_POOL2D + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'output_size': output_size}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['output_size'], **kwargs) + +@dataclass(init=False) +class AvgPool2D(symbol.Symbol): + op_name = opns.AVG_POOL2D + + @property + def pool_size(self) -> typing.Tuple[int, int]: + default_val = (2, 2) + return self.attrs['pool_size'] if 'pool_size' in self.attrs else default_val + @property + def strides(self): + default_val = None + return self.attrs['strides'] if 'strides' in self.attrs else default_val + @property + def padding(self) -> int: + default_val = 0 + return self.attrs['padding'] if 'padding' in self.attrs else default_val + @property + def ceil_mode(self) -> bool: + default_val = False + return self.attrs['ceil_mode'] if 'ceil_mode' in self.attrs else default_val + @property + def layout(self) -> str: + default_val = 'NCHW' + return self.attrs['layout'] if 'layout' in self.attrs else default_val + @property + def count_include_pad(self) -> bool: + default_val = True + return self.attrs['count_include_pad'] if 'count_include_pad' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, pool_size=(2,2), strides=None, padding=0, ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): + op_name = op_name or opns.AVG_POOL2D + assert op_name == opns.AVG_POOL2D + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout, 'count_include_pad':count_include_pad}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['pool_size', 'strides', 'padding', 'ceil_mode', 'layout', 'count_include_pad'], **kwargs) + +@dataclass(init=False) +class MaxPool2D(symbol.Symbol): + op_name = opns.MAX_POOL2D + + @property + def pool_size(self) -> typing.Tuple[int, int]: + default_val = (2, 2) + return self.attrs['pool_size'] if 'pool_size' in self.attrs else default_val + @property + def layout(self) -> str: + default_val = 'NCHW' + return self.attrs['layout'] if 'layout' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, pool_size=(2,2), layout='NCHW', extra_attrs=None): + op_name = op_name or 
opns.MAX_POOL2D + assert op_name == opns.MAX_POOL2D + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'layout':layout}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['pool_size', 'layout'], **kwargs) + + +@dataclass(init=False) +class Softmax(symbol.Symbol): + op_name = opns.SOFTMAX + + @property + def axis(self) -> typing.Optional[int]: + default_val = None + return self.attrs['axis'] if 'axis' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): + op_name = op_name or opns.SOFTMAX + assert op_name == opns.SOFTMAX + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'axis':axis}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis'], **kwargs) + + +@dataclass(init=False) +class LogSoftmax(symbol.Symbol): + op_name = opns.LOG_SOFTMAX + + @property + def axis(self) -> typing.Optional[int]: + default_val = None + return self.attrs['axis'] if 'axis' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): + op_name = op_name or opns.LOG_SOFTMAX + assert op_name == opns.LOG_SOFTMAX + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'axis':axis}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis'], **kwargs) + + +def exp(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.EXP + assert op_name == opns.EXP + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def sigmoid(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.SIGMOID + assert op_name == opns.SIGMOID + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Sum(symbol.Symbol): + op_name = opns.SUM + + @property + def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: + default_val = None + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + @property + def keepdim(self) -> typing.Optional[bool]: + default_val = None + return self.attrs['keepdim'] if 'keepdim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + op_name = op_name or opns.SUM + assert op_name == opns.SUM + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim, 'keepdim': keepdim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) + + +@dataclass(init=False) +class Mean(symbol.Symbol): + op_name = opns.MEAN + + @property + def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: + default_val = None + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + @property + def keepdim(self) -> typing.Optional[bool]: + default_val = None + return self.attrs['keepdim'] if 'keepdim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + op_name = op_name or opns.MEAN + assert op_name == opns.MEAN + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim, 'keepdim': keepdim}, 
extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) + + +@dataclass(init=False) +class MaxAxis(symbol.Symbol): + op_name = opns.MAX_AXIS + + @property + def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: + default_val = None + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + @property + def keepdim(self) -> typing.Optional[bool]: + default_val = None + return self.attrs['keepdim'] if 'keepdim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + op_name = op_name or opns.MAX_AXIS + assert op_name == opns.MAX_AXIS + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim, 'keepdim': keepdim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) + +def maximum(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.MAXIMUM + assert op_name == opns.MAXIMUM + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def minimum(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.MINIMUM + assert op_name == opns.MINIMUM + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def repeat(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.REPEAT + assert op_name == opns.REPEAT + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Squeeze(symbol.Symbol): + op_name = opns.SQUEEZE + + @property + def dim(self) -> typing.Optional[int]: + default_val = None + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim=None, extra_attrs=None): + op_name = op_name or opns.SQUEEZE + assert op_name == opns.SQUEEZE + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim'], **kwargs) + +@dataclass(init=False) +class Flatten(symbol.Symbol): + op_name = opns.FLATTEN + + @property + def start_dim(self) -> int: + default_val = 0 + return self.attrs['start_dim'] if 'start_dim' in self.attrs else default_val + + @property + def end_dim(self) -> int: + default_val = -1 + return self.attrs['end_dim'] if 'end_dim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, start_dim=0, end_dim=-1, extra_attrs=None): + op_name = op_name or opns.FLATTEN + assert op_name == opns.FLATTEN + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'start_dim': start_dim, 'end_dim':end_dim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['start_dim', 'end_dim'], **kwargs) + + +@dataclass(init=False) +class Reshape(symbol.Symbol): + op_name = opns.RESHAPE + + @property + def newshape(self) -> typing.Tuple[int,...]: + default_val = None + return self.attrs['newshape'] if 'newshape' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): + op_name = op_name or opns.RESHAPE + assert op_name == opns.RESHAPE + 
super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['newshape'], **kwargs) + + +@dataclass(init=False) +class Concat(symbol.Symbol): + op_name = opns.CONCAT + + @property + def axis(self) -> int: + default_val = 0 + return self.attrs['axis'] if 'axis' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): + op_name = op_name or opns.CONCAT + assert op_name == opns.CONCAT + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'axis': axis}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis'], **kwargs) + + +@dataclass(init=False) +class Split(symbol.Symbol): + op_name = opns.SPLIT + + @property + def split_size(self) -> typing.List[int]: + default_val = [] + return self.attrs['split_size'] if 'split_size' in self.attrs else default_val + + @property + def dim(self) -> int: + default_val = 0 + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, split_size=[], dim=0, extra_attrs=None): + op_name = op_name or opns.SPLIT + assert op_name == opns.SPLIT + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'split_size': split_size, 'dim': dim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['split_size', 'dim'], **kwargs) + + +@dataclass(init=False) +class Transpose(symbol.Symbol): + op_name = opns.TRANSPOSE + + @property + def dim0(self) -> int: + default_val = 0 + return self.attrs['dim0'] if 'dim0' in self.attrs else default_val + + @property + def dim1(self) -> int: + default_val = 0 + return self.attrs['dim1'] if 'dim1' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim0=0, dim1=0, extra_attrs=None): + op_name = op_name or opns.TRANSPOSE + assert op_name == opns.TRANSPOSE + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim0': dim0, 'dim1': dim1}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim0', 'dim1'], **kwargs) + + +@dataclass(init=False) +class BroadcastTo(symbol.Symbol): + op_name = opns.BROADCAST_TO + + @property + def newshape(self) -> typing.Tuple[int,...]: + default_val = None + return self.attrs['newshape'] if 'newshape' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): + op_name = op_name or opns.BROADCAST_TO + assert op_name == opns.BROADCAST_TO + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['newshape'], **kwargs) + + +@dataclass(init=False) +class ExpandDims(symbol.Symbol): + op_name = opns.EXPAND_DIMS + + @property + def newshape(self) -> typing.Tuple[int,...]: + default_val = None + return self.attrs['newshape'] if 'newshape' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): + op_name = op_name or opns.EXPAND_DIMS + assert op_name == opns.EXPAND_DIMS + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, 
extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['newshape'], **kwargs) + + +@dataclass(init=False) +class Tile(symbol.Symbol): + op_name = opns.TILE + + @property + def dims(self) -> typing.Tuple[int,...]: + default_val = None + return self.attrs['dims'] if 'dims' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dims=None, extra_attrs=None): + op_name = op_name or opns.TILE + assert op_name == opns.TILE + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dims': dims}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dims'], **kwargs) + +def where(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.WHERE + assert op_name == opns.WHERE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def greater(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.GREATER + assert op_name == opns.GREATER + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class NonMaxSuppression(symbol.Symbol): + op_name = opns.NON_MAX_SUPRESSION + + @property + def iou_threshold(self) -> float: + default_val = 0.5 + return self.attrs['iou_threshold'] if 'iou_threshold' in self.attrs else default_val + @property + def score_threshold(self) -> typing.Optional[float]: + default_val = None + return self.attrs['score_threshold'] if 'score_threshold' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, iou_threshold=0.5, score_threshold=None, extra_attrs=None): + op_name = op_name or opns.NON_MAX_SUPRESSION + assert op_name == opns.NON_MAX_SUPRESSION + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'iou_threshold': iou_threshold,'score_threshold':score_threshold}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dims'], **kwargs) + + +def ceil(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.CEIL + assert op_name == opns.CEIL + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def rightShift(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.RIGHT_SHIFT + assert op_name == opns.RIGHT_SHIFT + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Add(symbol.Symbol): + op_name = opns.ADD + + @property + def alpha(self) -> int: + default_val = 1 + return self.attrs['alpha'] if 'alpha' in self.attrs else default_val + + def __init__(self, X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): + op_name = op_name or opns.ADD + assert op_name == opns.ADD + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={'alpha': alpha}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['alpha'], **kwargs) + +@dataclass(init=False) +class Sub(symbol.Symbol): + op_name = opns.SUB + + @property + def alpha(self) -> int: + default_val = 1 + return self.attrs['alpha'] if 'alpha' in self.attrs else default_val + + def __init__(self, X, Y, name=None, 
op_name=None, alpha=1, extra_attrs=None): + op_name = op_name or opns.SUB + assert op_name == opns.SUB + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={'alpha': alpha}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['alpha'], **kwargs) + +def mul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.MUL + assert op_name == opns.MUL + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +def matMul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.MATMUL + assert op_name == opns.MATMUL + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Div(symbol.Symbol): + op_name = opns.DIV + + @property + def rounding_mode(self) -> typing.Optional[str]: + default_val = None + return self.attrs['rounding_mode'] if 'rounding_mode' in self.attrs else default_val + + def __init__(self, X, Y, name=None, op_name=None, rounding_mode=None, extra_attrs=None): + op_name = op_name or opns.DIV + assert op_name == opns.DIV + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={'rounding_mode': rounding_mode}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['rounding_mode'], **kwargs) + +def negative(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.NEGATIVE + assert op_name == opns.NEGATIVE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def abs(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.ABS + assert op_name == opns.ABS + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def log(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.LOG + assert op_name == opns.LOG + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def sqrt(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.SQRT + assert op_name == opns.SQRT + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def pow(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.POW + assert op_name == opns.POW + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +def pass_(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.PASS + assert op_name == opns.PASS + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Arange(symbol.Symbol): + op_name = opns.ARANGE + + @property + def end(self) -> int: + default_val = 0 + return self.attrs['end'] if 'end' in self.attrs else default_val + + @property + def start(self) -> int: + default_val = 0 + return self.attrs['start'] if 'start' in self.attrs else default_val + + @property + def step(self) -> int: + default_val = 1 + return self.attrs['step'] if 'step' in self.attrs else default_val + + def __init__(self, name=None, op_name=None, end=0, start=0, 
step=1, extra_attrs=None): + op_name = op_name or opns.ARANGE + assert op_name == opns.ARANGE + super().__init__(name=name or N.n(), op_name=op_name, args=[], attrs={'end': end, 'start': start, 'step': step}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['end', 'start', 'step'], **kwargs) + +def zerosLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.ZEROS_LIKE + assert op_name == opns.ZEROS_LIKE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def onesLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.ONES_LIKE + assert op_name == opns.ONES_LIKE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + + +_register_op_map(opns.VAR)(var) +_register_op_map(opns.RELU)(relu) + _register_op_map(opns.CONV2D)(Conv2D) _register_op_map(opns.DROP_OUT)(Dropout) _register_op_map(opns.CLIP)(Clip) _register_op_map(opns.BATCH_NORM)(BatchNorm) _register_op_map(opns.TUPLE_GET_ITEM)(TupleGetItem) +_register_op_map(opns.LEAKY_RELU)(LeakyRelu) + +_register_op_map(opns.MUL)(mul) +_register_op_map(opns.DENSE)(dense) +_register_op_map(opns.HARDTANH)(Hardtanh) +_register_op_map(opns.SILU)(silu) +_register_op_map(opns.ADAPTIVE_AVG_POOL2D)(AdaptiveAvgPool2D) +_register_op_map(opns.AVG_POOL2D)(AvgPool2D) +_register_op_map(opns.MAX_POOL2D)(MaxPool2D) +_register_op_map(opns.SOFTMAX)(Softmax) +_register_op_map(opns.LOG_SOFTMAX)(LogSoftmax) +_register_op_map(opns.EXP)(exp) +_register_op_map(opns.SIGMOID)(sigmoid) +_register_op_map(opns.SUM)(Sum) +_register_op_map(opns.MEAN)(Mean) +_register_op_map(opns.MAX_AXIS)(MaxAxis) +_register_op_map(opns.MAXIMUM)(maximum) +_register_op_map(opns.MINIMUM)(minimum) + + +_register_op_map(opns.REPEAT)(repeat) +_register_op_map(opns.SQUEEZE)(Squeeze) +_register_op_map(opns.FLATTEN)(Flatten) +_register_op_map(opns.RESHAPE)(Reshape) +_register_op_map(opns.CONCAT)(Concat) +_register_op_map(opns.SPLIT)(Split) +_register_op_map(opns.TRANSPOSE)(Transpose) +_register_op_map(opns.BROADCAST_TO)(BroadcastTo) +_register_op_map(opns.EXPAND_DIMS)(ExpandDims) +_register_op_map(opns.TILE)(Tile) +_register_op_map(opns.WHERE)(where) +_register_op_map(opns.GREATER)(greater) +_register_op_map(opns.NON_MAX_SUPRESSION)(NonMaxSuppression) + +_register_op_map(opns.CEIL)(ceil) +_register_op_map(opns.RIGHT_SHIFT)(rightShift) + +_register_op_map(opns.ADD)(Add) +_register_op_map(opns.SUB)(Sub) +_register_op_map(opns.MATMUL)(matMul) +_register_op_map(opns.DIV)(Div) +_register_op_map(opns.NEGATIVE)(negative) +_register_op_map(opns.ABS)(abs) +_register_op_map(opns.LOG)(log) +_register_op_map(opns.SQRT)(sqrt) +_register_op_map(opns.POW)(pow) +_register_op_map(opns.PASS)(pass_) +_register_op_map(opns.ARANGE)(Arange) +_register_op_map(opns.ZEROS_LIKE)(zerosLike) +_register_op_map(opns.ONES_LIKE)(onesLike) + + # Add default register Class for MRT OP Not Implemented! 
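
# Usage sketch (not part of the patch): how the registrations above are meant
# to be consumed. The module paths mrt.mir.opclass / mrt.mir.opns follow the
# tests in this series; the tensor names and shapes below are illustrative only.
from mrt.mir import opclass, opns

x = opclass.var(name="x", shape=(1, 3, 224, 224), dtype="float")
w = opclass.var(name="w", shape=(32, 3, 10, 10), dtype="float")

# Constructors can be used directly ...
y = opclass.relu(opclass.Conv2D(x, w, strides=(2, 2)))

# ... or resolved dynamically through the op-name -> creator map.
creator = opclass.MRT_OP_MAP[opns.CONV2D]
z = creator(x, w, strides=(1, 1))
assert z.op_name == opns.CONV2D
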
-_register_op_map(opns.MUL)(symbol.Symbol) -_register_op_map(opns.DENSE)(symbol.Symbol) -_register_op_map(opns.RELU)(symbol.Symbol) -_register_op_map(opns.HARDTANH)(symbol.Symbol) -_register_op_map(opns.SILU)(symbol.Symbol) -_register_op_map(opns.LEAKY_RELU)(symbol.Symbol) -_register_op_map(opns.ADAPTIVE_AVG_POOL2D)(symbol.Symbol) -_register_op_map(opns.AVG_POOL2D)(symbol.Symbol) -_register_op_map(opns.MAX_POOL2D)(symbol.Symbol) -_register_op_map(opns.SOFTMAX)(symbol.Symbol) -_register_op_map(opns.LOG_SOFTMAX)(symbol.Symbol) -_register_op_map(opns.EXP)(symbol.Symbol) -_register_op_map(opns.SIGMOID)(symbol.Symbol) -_register_op_map(opns.SUM)(symbol.Symbol) -_register_op_map(opns.MEAN)(symbol.Symbol) -_register_op_map(opns.MAX_AXIS)(symbol.Symbol) -_register_op_map(opns.MAXIMUM)(symbol.Symbol) -_register_op_map(opns.MINIMUM)(symbol.Symbol) -_register_op_map(opns.TUPLE)(symbol.Symbol) -_register_op_map(opns.REPEAT)(symbol.Symbol) -_register_op_map(opns.SQUEEZE)(symbol.Symbol) -_register_op_map(opns.FLATTEN)(symbol.Symbol) -_register_op_map(opns.BATCH_FLATTEN)(symbol.Symbol) -_register_op_map(opns.RESHAPE)(symbol.Symbol) -_register_op_map(opns.CONCAT)(symbol.Symbol) -_register_op_map(opns.SPLIT)(symbol.Symbol) -_register_op_map(opns.TRANSPOSE)(symbol.Symbol) -_register_op_map(opns.BROADCAST_TO)(symbol.Symbol) -_register_op_map(opns.EXPAND_DIMS)(symbol.Symbol) -_register_op_map(opns.TILE)(symbol.Symbol) -_register_op_map(opns.WHERE)(symbol.Symbol) -_register_op_map(opns.GREATER)(symbol.Symbol) -_register_op_map(opns.STRIDED_SLICE)(symbol.Symbol) -_register_op_map(opns.SLICE_LIKE)(symbol.Symbol) -_register_op_map(opns.GET_VALID_COUNT)(symbol.Symbol) -_register_op_map(opns.NON_MAX_SUPRESSION)(symbol.Symbol) -_register_op_map(opns.CEIL)(symbol.Symbol) -_register_op_map(opns.RIGHT_SHIFT)(symbol.Symbol) -_register_op_map(opns.AS_TYPE)(symbol.Symbol) -_register_op_map(opns.ADV_INDEX)(symbol.Symbol) -_register_op_map(opns.CALL_TIR)(symbol.Symbol) -_register_op_map(opns.CALL_DPS_PACKED)(symbol.Symbol) -_register_op_map(opns.ADD)(symbol.Symbol) -_register_op_map(opns.SUB)(symbol.Symbol) -_register_op_map(opns.MATMUL)(symbol.Symbol) -_register_op_map(opns.DIV)(symbol.Symbol) -_register_op_map(opns.NEGATIVE)(symbol.Symbol) -_register_op_map(opns.ABS)(symbol.Symbol) -_register_op_map(opns.LOG)(symbol.Symbol) -_register_op_map(opns.SQRT)(symbol.Symbol) -_register_op_map(opns.POW)(symbol.Symbol) -_register_op_map(opns.PASS)(symbol.Symbol) -_register_op_map(opns.ARANGE)(symbol.Symbol) -_register_op_map(opns.ZEROS_LIKE)(symbol.Symbol) -_register_op_map(opns.ONES_LIKE)(symbol.Symbol) +_register_op_map(opns.TUPLE)(extern_opfunc(opns.TUPLE)) +_register_op_map(opns.AS_TYPE)(extern_opfunc(opns.AS_TYPE)) +_register_op_map(opns.ADV_INDEX)(extern_opfunc(opns.ADV_INDEX)) +_register_op_map(opns.CALL_TIR)(extern_opfunc(opns.CALL_TIR)) +_register_op_map(opns.CALL_DPS_PACKED)(extern_opfunc(opns.CALL_DPS_PACKED)) + _register_op_map(opns.IF)(symbol.Symbol) _register_op_map(opns.ARGWHERE)(symbol.Symbol) _register_op_map(opns.REQUANT)(symbol.Symbol) _register_op_map(opns.PCLIP)(symbol.Symbol) _register_op_map(opns.RS_PCLIP)(symbol.Symbol) _register_op_map(opns.LUT)(symbol.Symbol) + +_register_op_map(opns.BATCH_FLATTEN)(symbol.Symbol) +_register_op_map(opns.STRIDED_SLICE)(symbol.Symbol) +_register_op_map(opns.SLICE_LIKE)(symbol.Symbol) +_register_op_map(opns.GET_VALID_COUNT)(symbol.Symbol) + diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py index eab2d90..2e1136e 100644 --- a/tests/mir/test.op_create.py +++ 
b/tests/mir/test.op_create.py @@ -23,8 +23,8 @@ def test_create_conv2d_op(): - X = opclass.Variable(name="x", shape=(1, 3, 224, 224,), dtype="float") - W = opclass.Variable(name="w", shape=(32, 3, 10, 10,), dtype="float") + X = opclass.var(name="x", shape=(1, 3, 224, 224,), dtype="float") + W = opclass.var(name="w", shape=(32, 3, 10, 10,), dtype="float") assert [shp for shp in X.shape] == [shp for shp in (1, 3, 224, 224)], f'Wrong X shape {X.shape}' assert X.dtype == "float", f'Wrong X dtype {X.dtype}' @@ -72,11 +72,11 @@ def test_create_conv2d_op(): def test_create_symbol_graph(): - X0 = opclass.Variable(name="x", shape=(1, 3, 224, 224,), dtype="float") - W0 = opclass.Variable(name="w", shape=(32, 3, 10, 10,), dtype="float") + X0 = opclass.var(name="x", shape=(1, 3, 224, 224,), dtype="float") + W0 = opclass.var(name="w", shape=(32, 3, 10, 10,), dtype="float") conv2d_a = opclass.Conv2D(X0, W0, name='conv2d_a', strides=(1,1)) - W1 = opclass.Variable(shape=(16, 3, 12, 12,), dtype="float") + W1 = opclass.var(shape=(16, 3, 12, 12,), dtype="float") conv2d_b = opclass.Conv2D(conv2d_a, W1, name='conv2d_b', strides=(1,1)) symlist = sx.sym2list(conv2d_b) @@ -90,11 +90,11 @@ def test_create_symbol_graph(): def test_create_batch_norm_op(): - X = opclass.Variable(name="x", shape=(1, 32, 128, 128,), dtype="float") - Gamma = opclass.Variable(name="gamma", shape=(32,), dtype="float") - Beta = opclass.Variable(name="beta", shape=(32,), dtype="float") - Mean = opclass.Variable(name="mean", shape=(32,), dtype="float") - Var = opclass.Variable(name="var", shape=(32,), dtype="float") + X = opclass.var(name="x", shape=(1, 32, 128, 128,), dtype="float") + Gamma = opclass.var(name="gamma", shape=(32,), dtype="float") + Beta = opclass.var(name="beta", shape=(32,), dtype="float") + Mean = opclass.var(name="mean", shape=(32,), dtype="float") + Var = opclass.var(name="var", shape=(32,), dtype="float") batch_norm_a = opclass.BatchNorm(X, Gamma, Beta, Mean, Var, axis=1, epsilon=1e-4) # attrs hint From a99ecb7e5101aed29c03a425ef5393c8bbb18320 Mon Sep 17 00:00:00 2001 From: corlfj Date: Tue, 23 Sep 2025 14:36:44 +0800 Subject: [PATCH 5/7] [mir]: add inferpass(with params) --- python/mrt/mir/simple_pass.py | 177 ++++++++++++++++++++++++++-------- python/mrt/mir/symbol.py | 5 +- tests/mir/test.simple_pass.py | 8 +- 3 files changed, 142 insertions(+), 48 deletions(-) diff --git a/python/mrt/mir/simple_pass.py b/python/mrt/mir/simple_pass.py index 2fc362e..a876b95 100644 --- a/python/mrt/mir/simple_pass.py +++ b/python/mrt/mir/simple_pass.py @@ -2,25 +2,27 @@ import typing from functools import wraps -from dataclasses import dataclass, fields +from dataclasses import dataclass -import mrt from mrt.common import config +#from mrt.runtime import inference from mrt.common.utils import * from mrt.common.types import * -from . import opns, opclass +from . import op, opns, opclass from . import symbol as _symbol # mrt op visits +@dataclass class SimplePass: symbol: _symbol.Symbol - def __init__(self, symbol: _symbol.Symbol): - self.symbol = symbol - - def visit(self, custom_func: typing.Callable[[Symbol], typing.Optional[Symbol]] = None) -> _symbol.Symbol: + """op-level visit of graph + infer different visit function with different op_name + return: head symbol processed + """ + def graph_visits(self) -> _symbol.Symbol: env: typing.Dict[str, _symbol.Symbol] = {} for sym in _symbol.sym2list(self.symbol): assert sym.name not in env, f'{sym.name} NotIn env!' 
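
# Sketch (not part of the diff) of how a concrete pass hooks into the
# graph_visits() dispatch above: visit_<Opname2Funcname(op_name)> is looked up
# per node, so overriding a single method is enough. Names follow the test
# files in this series; the commented call at the end is illustrative usage.
from mrt.mir import opns, simple_pass
from mrt.mir import symbol as sx

class DropDropoutPass(simple_pass.SimplePass):
    def visit_nn_dropout(self, sym: sx.Symbol) -> sx.Symbol:
        # Forward the single input, removing the dropout node from the graph.
        return sym.args[0]

# head = DropDropoutPass(mrt_graph['main']).graph_visits()
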
@@ -28,7 +30,7 @@ def visit(self, custom_func: typing.Callable[[Symbol], typing.Optional[Symbol]] # Updating args as passed symbol in env_dict sym = sym.copy(args = [env[arg_sym.name] for arg_sym in sym.args]) assert isinstance(sym, _symbol.Symbol), sym - out = custom_func(sym) if custom_func else getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) + out = getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) out = out or sym assert isinstance(out, _symbol.Symbol), out env[sym.name] = out @@ -37,21 +39,72 @@ def visit(self, custom_func: typing.Callable[[Symbol], typing.Optional[Symbol]] def _default_visit_op(self, op: _symbol.Symbol) -> _symbol.Symbol: return op + """custom visit of graph + calling custom_func for all op_name + return: head symbol processed + """ + def custom_visits(self, custom_run: _symbol._TransformerParamT, name: str = "", once: bool = False) -> _symbol.Symbol: + with N(name): + if once: + return custom_run(self.symbol) + return _symbol.transform(self.symbol, custom_run) + # mrt op visits with params, variables +@dataclass class InferPass(SimplePass): params: ParametersT - def is_param(self, symbol: _symbol.Symbol) -> bool: - return symbol.op_name == opns.VAR and symbol.name in self.params - - def get_param(self, symbol: _symbol.Symbol) -> OpNumpyT: - assert self.is_param(symbol) - return self.params[symbol.name] if self.is_param(symbol) else [] - - def __init__(self, symbol: _symbol.Symbol, params: ParametersT): - self.symbol = symbol - self.params = params + def is_input(self, op_: _symbol.Symbol) -> bool: + return op.is_input(op_, self.params) + def is_variable(self, op_: _symbol.Symbol) -> bool: + return op.is_variable(op_, self.params) + def is_operator(self, op_: _symbol.Symbol) -> bool: + return op.is_operator(op_, self.params) + def is_param(self, op_: _symbol.Symbol) -> bool: + return op_.op_name == opns.VAR and op_.name in self.params + + def get_param(self, op_: _symbol.Symbol) -> OpNumpyT: + return self.params[op_.name] if self.is_param(op_) else [] + def get_as_numpy(self, op_: _symbol.Symbol) -> OpNumpyT: + assert self.is_param(op_), f"{op_.name} is not parameter." + data = self.params[op_.name] + assert isinstance(data, (tuple, list, np.ndarray)), \ + f"param:{op_.name} not OpNumpyT, get {type(data)}" + return data + + """custom visit of graph + calling custom_func for all op_name + return: head symbol processed + """ + def custom_visits_with_params(self, custom_run: _symbol._TransformerParamT, name: str = "", once: bool = False) -> _symbol.Symbol: + with N(name): + if once: + return custom_run(self.symbol, self.params) + return _symbol.transform(self.symbol, custom_run, params=self.params) + + # From original quantization.Transformer + def as_parameter(self, data: OpNumpyT, name:str, dtype): + def _f(data, dtype): + if isinstance(data, list): + assert len(data) == len(dtype) + return [_f(d, t) for d, t in zip(data, dtype)] + assert isinstance(data, np.ndarray), type(data) + return data.astype(dtype) + array = _f(data, dtype) + shape = np.array(array).shape + self.params[name] = array + return opclass.var(array, shape=shape, dtype=dtype) + + def from_np_data(self, data: np.ndarray, dtype, prefix=None) -> _symbol.Symbol: + name = N.n(prefix=prefix) + # some data is np.float/int type, use np.array to wrap it. 
+ data = np.array(data) + self.params[name] = data.astype(dtype) + return opclass.var(name, shape=data.shape, dtype=dtype)#.like(self) + + def from_const_data(self, data: typing.Union[int, float], dtype) -> _symbol.Symbol: + return self.from_np_data(data, dtype) # Register MRT all op's default_visit_op function @@ -71,42 +124,83 @@ def visit_nn_dropout(self, sym: _symbol.Symbol) -> _symbol.Symbol: class FuseTupleGetItemPass(SimplePass): - def visit_TupleGetItem(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.TUPLE_GET_ITEM: - return sym - sym_ : opclass.TupleGetItem = sym - assert sym_.index == 0 - return sym_.args[0] - return sym - - -class FuseBatchNormPass(InferPass): - def visit_nn_batch_norm(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.BATCH_NORM: - X, Gamma, Beta, Mean, Var = sym.args - Gamma = self.get_param(Gamma) - Beta = self.get_param(Beta) - Mean = self.get_param(Mean) - Var = self.get_param(Var) - return sym + def visit_TupleGetItem(self, sym: opclass.TupleGetItem) -> _symbol.Symbol: + #if sym.op_name == opns.TUPLE_GET_ITEM: + # assert sym.index == 0 + # return sym.args[0] return sym -class FuseSoftmaxPass(SimplePass): +class FuseNaiveSoftmaxPass(SimplePass): def visit_nn_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: if sym.op_name == opns.SOFTMAX: - return self.args[0] + return sym.args[0] return sym def visit_nn_log_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: if sym.op_name == opns.LOG_SOFTMAX: - return self.args[0] + return sym.args[0] return sym -class FuseMeanPass(SimplePass): +class FuseMeanPass(InferPass): def visit_mean(self, sym: _symbol.Symbol) -> _symbol.Symbol: if sym.op_name == opns.MEAN: + X = sym.args[0] + out = opclass.Sum(X, **sym.attrs) + scale = self.from_np_data(np.array( + 1. 
* product(out.shape) / product(X.shape)), dtype=out.dtype) + out = opclass.mul(out, scale) + return out + return sym + + +class FuseConstantPass(InferPass): + threshold: typing.ClassVar[float] = 1e-5 + + def np_is_zero(self, data) -> float: + return np.abs(data).max() < self.threshold + + + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol:#: _symbol._TransformerParamT + if self.is_operator(sym) and all([self.is_param(arg) for arg in sym.args]): + data = inference.run_single_params( + sym, [self.get_as_numpy(a) for a in sym.args]) + return self.as_parameter(data, name=sym.name, dtype=sym.dtype) + elif sym.is_op(opns.ADD, opns.SUB): # , BIAS_ADD): + strips = [] + for arg in sym.args: + if self.is_param(arg) and self.np_is_zero(self.get_as_numpy(arg)): + strips.append(arg) + args = [a for a in sym.args if a not in strips] + if len(args) == 1: + return args[0] + elif sym.is_op(opns.SLICE_LIKE): + if not self.is_param(sym.args[0]): + return None + a, b = sym.args + data = inference.run_single_params( + sym, [self.get_as_numpy(a), np.zeros(b.shape, b.dtype)]) + return self.as_parameter(data, name=sym.name, dtype=sym.dtype) + elif sym.is_op(opns.REQUANT): + if sym.rescale == 1: + return sym.args[0] + elif sym.is_op(opns.ZEROS_LIKE, opns.ONES_LIKE): + data = inference.run_single_params(sym, []) + return self.as_parameter(data, name=sym.name, dtype=sym.dtype) + return sym + return custom_run + + +class FuseBatchNormPass(InferPass): + def visit_nn_batch_norm(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.BATCH_NORM: + X, Gamma, Beta, Mean, Var = sym.args + Gamma = self.get_param(Gamma) + Beta = self.get_param(Beta) + Mean = self.get_param(Mean) + Var = self.get_param(Var) return sym return sym @@ -117,8 +211,7 @@ def visit_divide(self, sym: _symbol.Symbol) -> _symbol.Symbol: argA = sym.args[0] argB = sym.args[1] assert self.is_param(argB), f'NotParam: {argB}' - # TODO: fixit - #argB = argB.from_np_data(1. / argB.numpy()) + argB = self.from_np_data(1. / self.get_as_numpy(argB), dtype=argB.dtype) return opclass.MRT_OP_MAP[opns.MUL](argA, argB) return sym diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index c1e6bd7..315b70d 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -322,6 +322,7 @@ def load_json(data: _SymbolJsonT, **extra_attrs) -> Symbol: _VisitorT = typing.Callable[[Symbol], None] _TransformerT = typing.Callable[[Symbol], typing.Optional[Symbol]] +_TransformerParamT = typing.Callable[[Symbol, typing.Optional[ParametersT]], Symbol] """ Symbol Transformer Return new symbol to transform old symbol into updated one, @@ -338,7 +339,7 @@ def visit(symbol: Symbol, callback: _VisitorT): if callback.__name__ in C.log_vot_cbs: config.log(callback.__name__, f">> {sym}") -def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: +def transform(symbol: Symbol, callback: _TransformerParamT, params:typing.Optional[ParametersT] = None) -> Symbol: """ Transform symbol from old to new, with inputs updated. 
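
# Sketch of a params-aware transformer enabled by the optional `params`
# argument added to transform() in this file (illustrative only; it assumes
# mrt_params maps variable names to numpy arrays and that opns.ADD names the
# add operator). When `params` is passed, transform() forwards it to the
# callback; existing single-argument callbacks keep working unchanged.
import typing
import numpy as np
from mrt.mir import opns
from mrt.mir import symbol as sx

def strip_zero_add(sym: sx.Symbol, params: typing.Optional[dict] = None) -> sx.Symbol:
    # Fold `x + 0` when the right-hand input is a known all-zero parameter.
    if params and sym.op_name == opns.ADD:
        data = params.get(sym.args[1].name)
        if data is not None and np.abs(np.asarray(data)).max() < 1e-5:
            return sym.args[0]
    return sym

# new_head = sx.transform(mrt_graph['main'], strip_zero_add, params=mrt_params)
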
Only the return value indicates mutation, while changing @@ -355,7 +356,7 @@ def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: if callback.__name__ in C.log_vot_cbs: config.log(callback.__name__, f"<< {sym}") - out = callback(sym) or sym + out = (callback(sym, params) if params else callback(sym)) or sym assert isinstance(out, Symbol), out # default const_ prefix symbol means parameters assert sym.name not in sym_map, sym.name diff --git a/tests/mir/test.simple_pass.py b/tests/mir/test.simple_pass.py index 9d20ee0..33139d4 100644 --- a/tests/mir/test.simple_pass.py +++ b/tests/mir/test.simple_pass.py @@ -56,7 +56,7 @@ def test_SimplePass_FuseDropout(mrt_graph, mrt_params): # init FuseDropout Passer and execute visit tfs : simple_pass.FuseDropoutPass = simple_pass.FuseDropoutPass(symbol) #print(getattr(tfs, f"visit_{opns.Opname2Funcname(opns.DROP_OUT)}")) - symbol_passed = tfs.visit() + symbol_passed = tfs.graph_visits() print('\n=== After FuseDropout Pass ===') rlts = sx.sym2list(symbol_passed) @@ -83,12 +83,12 @@ def test_SimplePass_CustomFunc(mrt_graph): tfs : simple_pass.SimplePass = simple_pass.SimplePass(symbol) conv2d_name_list = [] - def _filter_op(sym: sx.Symbol) -> sx.Symbol: + def _filter_op(sym: sx.Symbol, params=None) -> sx.Symbol: if sym.op_name == opns.CONV2D: conv2d_name_list.append(sym.name) return sym - symbol_passed = tfs.visit(_filter_op) + symbol_passed = tfs.custom_visits(_filter_op) print('\n=== After CustomFunc Pass ===') assert len(conv2d_name_list) > 0 @@ -113,7 +113,7 @@ def _nn_dropout(sym: sx.Symbol) -> sx.Symbol: if sym.op_name == opns.DROP_OUT: return sym.args[0] return sym - symbol_passed = tfs.visit(_nn_dropout) + symbol_passed = tfs.custom_visits(_nn_dropout) print('\n=== After FuseDropout CustomFunc Pass ===') rlts = sx.sym2list(symbol_passed) From 677ea6cc4dd08add3ba6e61c48b8bb3e87a8a317 Mon Sep 17 00:00:00 2001 From: corlfj Date: Mon, 29 Sep 2025 17:46:02 +0800 Subject: [PATCH 6/7] [mir]: opclass redefine names, op compulsory attrs check --- python/mrt/mir/opclass.py | 237 +++++++++++++++++++++++++++--------- tests/mir/test.op_create.py | 54 +++++++- 2 files changed, 232 insertions(+), 59 deletions(-) diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py index 2e4cdf9..e3b40f0 100644 --- a/python/mrt/mir/opclass.py +++ b/python/mrt/mir/opclass.py @@ -9,7 +9,7 @@ #SelfSymbol = typing.TypeVar("SelfSymbol", bound="Symbol") -SymbolCreator = typing.Union[typing.Callable[[typing.Any, typing.Any], typing.Type[symbol.Symbol]], SelfSymbol] +SymbolCreator = typing.Union[typing.Callable[[typing.Any, ...], typing.Type[symbol.Symbol]], SelfSymbol] #SymbolCreator = typing.Union[typing.Callable[[...], symbol.Symbol], SelfSymbol] MRT_OP_MAP: typing.Dict[str, SymbolCreator] = {} @@ -29,8 +29,9 @@ def _wrapper(clss: SymbolCreator = None) -> SymbolCreator: # OPs from external (not in MRT op), using custom op_name with default op_func #y = extern_opfunc("tanh")(X) def extern_opfunc(op_name: str): - def op_func(*args, **attrs): - return symbol.Symbol(*args, op_name=op_name, **attrs) + def op_func(name, args, attrs, extra_attrs): + #return symbol.Symbol(op_name=op_name, *args, **attrs) + return symbol.Symbol(name, op_name, args, attrs, extra_attrs) return op_func @@ -91,19 +92,26 @@ def dilation(self) -> typing.Tuple[int, int]: @property def kernel_size(self) -> typing.Tuple[int, int]: - default_val = (3,3) - return self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val + assert 'kernel_size' in self.attrs + return 
self.attrs['kernel_size'] # Follows (*args, name, **attrs) - def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_size=(3,3), extra_attrs=None): + def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), extra_attrs=None): op_name = op_name or opns.CONV2D assert op_name == opns.CONV2D + assert len(W.shape) == 4, f'Wrong Weight Shape for Conv2D: {W.shape}' + kernel_size = (W.shape[2], W.shape[3]) super().__init__(name=name or N.n(), op_name=op_name, args=[X,W], attrs={'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): - return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation', 'kernel_size'], **kwargs) + # Auto inferred 'kernel_size' + return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation'], **kwargs) + +def conv2d(X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), extra_attrs=None): + return Conv2D(X, W, name, op_name, strides, padding, groups, dilation, extra_attrs) + @dataclass(init=False) class Dropout(symbol.Symbol): @@ -123,29 +131,38 @@ def __init__(self, X, name=None, op_name=None, p:float = 0.5, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['p'], **kwargs) +def dropout(X, name=None, op_name=None, p:float = 0.5, extra_attrs=None): + return Dropout(X, name, op_name, p, extra_attrs) + + @dataclass(init=False) class Clip(symbol.Symbol): op_name = opns.CLIP @property def min(self) -> float: - default_val = np.nan - return self.attrs['min'] if 'min' in self.attrs else default_val + assert 'min' in self.attrs + return self.attrs['min'] @property def max(self) -> float: - default_val = np.nan - return self.attrs['max'] if 'max' in self.attrs else default_val + assert 'max' in self.attrs + return self.attrs['max'] def __init__(self, X, name=None, op_name=None, min_:float = np.nan, max_:float = np.nan, extra_attrs=None): op_name = op_name or opns.CLIP assert op_name == opns.CLIP + assert min_ != np.nan + assert max_ != np.nan super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'min': min_, 'max': max_}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['min', 'max'], **kwargs) +def clip(X, name=None, op_name=None, min_:float = np.nan, max_:float = np.nan, extra_attrs=None): + return Clip(X, name, op_name, min_, max_, extra_attrs) + @dataclass(init=False) class BatchNorm(symbol.Symbol): @@ -175,6 +192,9 @@ def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['axis', 'epsilon', 'momentum'], **kwargs) +def batch_norm(X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, extra_attrs=None): + return BatchNorm(X, Gamma, Beta, Mean, Var, name, op_name, axis, epsilon, momentum, extra_attrs) + @dataclass(init=False) class TupleGetItem(symbol.Symbol): @@ -194,6 +214,9 @@ def __init__(self, X, name=None, op_name=None, index:int = 0, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['index'], **kwargs) +def tuple_get_item(X, name=None, op_name=None, index:int = 0, extra_attrs=None): + return TupleGetItem(X, name, op_name, index, extra_attrs) + 
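
# Usage sketch for the compulsory-attribute convention in this patch
# (illustrative; tensor names and shapes are made up). kernel_size is no longer
# passed to Conv2D -- it is derived from the 4-D weight shape -- and Clip now
# requires explicit bounds. Note that a NaN sentinel cannot be rejected with
# `!=`, since nan != nan is always True; np.isnan() is the reliable guard.
import numpy as np
from mrt.mir import opclass

x = opclass.var(name="x", shape=(1, 3, 32, 32), dtype="float")
w = opclass.var(name="w", shape=(8, 3, 5, 5), dtype="float")

conv = opclass.conv2d(x, w, strides=(1, 1))   # kernel_size inferred as (5, 5)
act = opclass.clip(conv, min_=0.0, max_=6.0)  # min/max are compulsory here
assert act.min == 0.0 and act.max == 6.0
assert not np.isnan(act.min)                  # NaN check via np.isnan, not `!=`
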
@dataclass(init=False) class LeakyRelu(symbol.Symbol): @@ -213,6 +236,9 @@ def __init__(self, X, name=None, op_name=None, negative_slope:float = 1e-2, extr def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['negative_slope'], **kwargs) +def leaky_relu(X, name=None, op_name=None, negative_slope:float = 1e-2, extra_attrs=None): + return LeakyRelu(X, name, op_name, negative_slope, extra_attrs) + def dense(X, W, B, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.DENSE @@ -242,6 +268,8 @@ def __init__(self, X, name=None, op_name=None, min_val:float = -1.0, max_val:flo def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['min_val', 'max_val'], **kwargs) +def hard_tanh(X, name=None, op_name=None, min_val:float = -1.0, max_val:float = 1.0, extra_attrs=None): + return Hardtanh(X, name, op_name, min_val, max_val, extra_attrs) @dataclass(init=False) class AdaptiveAvgPool2D(symbol.Symbol): @@ -249,33 +277,41 @@ class AdaptiveAvgPool2D(symbol.Symbol): @property def output_size(self) -> typing.Union[int, typing.Tuple[int, int]]: - default_val = 0 - return self.attrs['output_size'] if 'output_size' in self.attrs else default_val + assert 'output_size' in self.attrs + return self.attrs['output_size'] - def __init__(self, X, name=None, op_name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=0, extra_attrs=None): + def __init__(self, X, name=None, op_name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=None, extra_attrs=None): op_name = op_name or opns.ADAPTIVE_AVG_POOL2D assert op_name == opns.ADAPTIVE_AVG_POOL2D + assert output_size != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'output_size': output_size}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['output_size'], **kwargs) +def adaptive_avg_pool2d(X, name=None, op_name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=0, extra_attrs=None): + return AdaptiveAvgPool2D(X, name, op_name, output_size, extra_attrs) + @dataclass(init=False) class AvgPool2D(symbol.Symbol): op_name = opns.AVG_POOL2D @property def pool_size(self) -> typing.Tuple[int, int]: - default_val = (2, 2) - return self.attrs['pool_size'] if 'pool_size' in self.attrs else default_val + assert 'pool_size' in self.attrs + return self.attrs['pool_size'] @property - def strides(self): - default_val = None + def strides(self) -> typing.Tuple[int, int]: + default_val = (0, 0) return self.attrs['strides'] if 'strides' in self.attrs else default_val @property - def padding(self) -> int: - default_val = 0 + def dilation(self) -> typing.Tuple[int, int]: + default_val = (1, 1) + return self.attrs['dilation'] if 'dilation' in self.attrs else default_val + @property + def padding(self) -> typing.Tuple[int, int, int, int]: + default_val = (0, 0, 0, 0) return self.attrs['padding'] if 'padding' in self.attrs else default_val @property def ceil_mode(self) -> bool: @@ -290,14 +326,19 @@ def count_include_pad(self) -> bool: default_val = True return self.attrs['count_include_pad'] if 'count_include_pad' in self.attrs else default_val - def __init__(self, X, name=None, op_name=None, pool_size=(2,2), strides=None, padding=0, ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): + def __init__(self, X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): op_name = 
op_name or opns.AVG_POOL2D assert op_name == opns.AVG_POOL2D - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout, 'count_include_pad':count_include_pad}, extra_attrs=extra_attrs or {}) + assert pool_size != None + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'dilation':dilation, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout, 'count_include_pad':count_include_pad}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): - return _from_dict_attrs(cls, d, ['pool_size', 'strides', 'padding', 'ceil_mode', 'layout', 'count_include_pad'], **kwargs) + return _from_dict_attrs(cls, d, ['pool_size', 'dilation', 'strides', 'padding', 'ceil_mode', 'layout', 'count_include_pad'], **kwargs) + +def avg_pool2d(X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): + return AvgPool2D(X, name, op_name, pool_size, dilation, strides, padding, ceil_mode, layout, count_include_pad, extra_attrs) + @dataclass(init=False) class MaxPool2D(symbol.Symbol): @@ -305,21 +346,41 @@ class MaxPool2D(symbol.Symbol): @property def pool_size(self) -> typing.Tuple[int, int]: - default_val = (2, 2) - return self.attrs['pool_size'] if 'pool_size' in self.attrs else default_val + assert 'pool_size' in self.attrs + return self.attrs['pool_size'] + @property + def strides(self) -> typing.Tuple[int, int]: + default_val = (0, 0) + return self.attrs['strides'] if 'strides' in self.attrs else default_val + @property + def dilation(self) -> typing.Tuple[int, int]: + default_val = (1, 1) + return self.attrs['dilation'] if 'dilation' in self.attrs else default_val + @property + def padding(self) -> typing.Tuple[int, int, int, int]: + default_val = (0, 0, 0, 0) + return self.attrs['padding'] if 'padding' in self.attrs else default_val + @property + def ceil_mode(self) -> bool: + default_val = False + return self.attrs['ceil_mode'] if 'ceil_mode' in self.attrs else default_val @property def layout(self) -> str: default_val = 'NCHW' return self.attrs['layout'] if 'layout' in self.attrs else default_val - def __init__(self, X, name=None, op_name=None, pool_size=(2,2), layout='NCHW', extra_attrs=None): + def __init__(self, X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', extra_attrs=None): op_name = op_name or opns.MAX_POOL2D assert op_name == opns.MAX_POOL2D - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'layout':layout}, extra_attrs=extra_attrs or {}) + assert pool_size != None + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'dilation':dilation, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): - return _from_dict_attrs(cls, d, ['pool_size', 'layout'], **kwargs) + return _from_dict_attrs(cls, d, ['pool_size', 'dilation', 'strides', 'padding', 'ceil_mode', 'layout'], **kwargs) + +def max_pool2d(X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', extra_attrs=None): + return MaxPool2D(X, name, op_name, pool_size, dilation, strides, padding, ceil_mode, 
layout, extra_attrs) @dataclass(init=False) @@ -340,6 +401,8 @@ def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['axis'], **kwargs) +def softmax(X, name=None, op_name=None, axis=None, extra_attrs=None): + return Softmax(X, name, op_name, axis, extra_attrs) @dataclass(init=False) class LogSoftmax(symbol.Symbol): @@ -359,6 +422,9 @@ def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['axis'], **kwargs) +def log_softmax(X, name=None, op_name=None, axis=None, extra_attrs=None): + return LogSoftmax(X, name, op_name, axis, extra_attrs) + def exp(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.EXP @@ -393,6 +459,9 @@ def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_att def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) +def sum(X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + return Sum(X, name, op_name, dim, keepdim, extra_attrs) + @dataclass(init=False) class Mean(symbol.Symbol): @@ -417,6 +486,9 @@ def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_att def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) +def mean(X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + return Mean(X, name, op_name, dim, keepdim, extra_attrs) + @dataclass(init=False) class MaxAxis(symbol.Symbol): @@ -441,6 +513,10 @@ def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_att def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) +def max_axis(X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + return MaxAxis(X, name, op_name, dim, keepdim, extra_attrs) + + def maximum(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.MAXIMUM assert op_name == opns.MAXIMUM @@ -474,6 +550,9 @@ def __init__(self, X, name=None, op_name=None, dim=None, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim'], **kwargs) +def squeeze(X, name=None, op_name=None, dim=None, extra_attrs=None): + return Squeeze(X, name, op_name, dim, extra_attrs) + @dataclass(init=False) class Flatten(symbol.Symbol): op_name = opns.FLATTEN @@ -497,6 +576,9 @@ def __init__(self, X, name=None, op_name=None, start_dim=0, end_dim=-1, extra_at def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['start_dim', 'end_dim'], **kwargs) +def flatten(X, name=None, op_name=None, start_dim=0, end_dim=-1, extra_attrs=None): + return Flatten(X, name, op_name, start_dim, end_dim, extra_attrs) + @dataclass(init=False) class Reshape(symbol.Symbol): @@ -504,18 +586,21 @@ class Reshape(symbol.Symbol): @property def newshape(self) -> typing.Tuple[int,...]: - default_val = None - return self.attrs['newshape'] if 'newshape' in self.attrs else default_val + assert 'newshape' in self.attrs + return self.attrs['newshape'] def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): op_name = op_name or opns.RESHAPE assert op_name == opns.RESHAPE + assert newshape != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return 
_from_dict_attrs(cls, d, ['newshape'], **kwargs) +def reshape(X, name=None, op_name=None, newshape=None, extra_attrs=None): + return Reshape(X, name, op_name, newshape, extra_attrs) @dataclass(init=False) class Concat(symbol.Symbol): @@ -535,6 +620,8 @@ def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['axis'], **kwargs) +def concat(X, name=None, op_name=None, axis=None, extra_attrs=None): + return Concat(X, name, op_name, axis, extra_attrs) @dataclass(init=False) class Split(symbol.Symbol): @@ -542,23 +629,27 @@ class Split(symbol.Symbol): @property def split_size(self) -> typing.List[int]: - default_val = [] - return self.attrs['split_size'] if 'split_size' in self.attrs else default_val + assert 'split_size' in self.attrs + return self.attrs['split_size'] @property def dim(self) -> int: default_val = 0 return self.attrs['dim'] if 'dim' in self.attrs else default_val - def __init__(self, X, name=None, op_name=None, split_size=[], dim=0, extra_attrs=None): + def __init__(self, X, name=None, op_name=None, split_size=None, dim=0, extra_attrs=None): op_name = op_name or opns.SPLIT assert op_name == opns.SPLIT + assert split_size != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'split_size': split_size, 'dim': dim}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['split_size', 'dim'], **kwargs) +def split(X, name=None, op_name=None, split_size=[], dim=0, extra_attrs=None): + return Split(X, name, op_name, split_size, dim, extra_attrs) + @dataclass(init=False) class Transpose(symbol.Symbol): @@ -566,23 +657,28 @@ class Transpose(symbol.Symbol): @property def dim0(self) -> int: - default_val = 0 - return self.attrs['dim0'] if 'dim0' in self.attrs else default_val + assert 'dim0' in self.attrs + return self.attrs['dim0'] @property def dim1(self) -> int: - default_val = 0 - return self.attrs['dim1'] if 'dim1' in self.attrs else default_val + assert 'dim1' in self.attrs + return self.attrs['dim1'] - def __init__(self, X, name=None, op_name=None, dim0=0, dim1=0, extra_attrs=None): + def __init__(self, X, name=None, op_name=None, dim0=None, dim1=None, extra_attrs=None): op_name = op_name or opns.TRANSPOSE assert op_name == opns.TRANSPOSE + assert dim0 != None + assert dim1 != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim0': dim0, 'dim1': dim1}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim0', 'dim1'], **kwargs) +def transpose(X, name=None, op_name=None, dim0=None, dim1=None, extra_attrs=None): + return Transpose(X, name, op_name, dim0, dim1, extra_attrs) + @dataclass(init=False) class BroadcastTo(symbol.Symbol): @@ -590,18 +686,21 @@ class BroadcastTo(symbol.Symbol): @property def newshape(self) -> typing.Tuple[int,...]: - default_val = None - return self.attrs['newshape'] if 'newshape' in self.attrs else default_val + assert 'newshape' in self.attrs + return self.attrs['newshape'] def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): op_name = op_name or opns.BROADCAST_TO assert op_name == opns.BROADCAST_TO + assert newshape != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['newshape'], **kwargs) +def 
broadcast_to(X, name=None, op_name=None, newshape=None, extra_attrs=None): + return BroadcastTo(X, name, op_name, newshape, extra_attrs) @dataclass(init=False) class ExpandDims(symbol.Symbol): @@ -609,18 +708,21 @@ class ExpandDims(symbol.Symbol): @property def newshape(self) -> typing.Tuple[int,...]: - default_val = None - return self.attrs['newshape'] if 'newshape' in self.attrs else default_val + assert 'newshape' in self.attrs + return self.attrs['newshape'] def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): op_name = op_name or opns.EXPAND_DIMS assert op_name == opns.EXPAND_DIMS + assert newshape != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['newshape'], **kwargs) +def expand_dims(X, name=None, op_name=None, newshape=None, extra_attrs=None): + return ExpandDims(X, name, op_name, newshape, extra_attrs) @dataclass(init=False) class Tile(symbol.Symbol): @@ -628,18 +730,23 @@ class Tile(symbol.Symbol): @property def dims(self) -> typing.Tuple[int,...]: - default_val = None - return self.attrs['dims'] if 'dims' in self.attrs else default_val + assert 'dims' in self.attrs + return self.attrs['dims'] def __init__(self, X, name=None, op_name=None, dims=None, extra_attrs=None): op_name = op_name or opns.TILE assert op_name == opns.TILE + assert dims != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dims': dims}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dims'], **kwargs) +def tile(X, name=None, op_name=None, dims=None, extra_attrs=None): + return Tile(X, name, op_name, dims, extra_attrs) + + def where(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.WHERE assert op_name == opns.WHERE @@ -672,13 +779,16 @@ def __init__(self, X, name=None, op_name=None, iou_threshold=0.5, score_threshol def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dims'], **kwargs) +def non_max_suppression(X, name=None, op_name=None, iou_threshold=0.5, score_threshold=None, extra_attrs=None): + return NonMaxSuppression(X, name, op_name, iou_threshold, score_threshold, extra_attrs) + def ceil(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.CEIL assert op_name == opns.CEIL return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) -def rightShift(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: +def right_shift(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.RIGHT_SHIFT assert op_name == opns.RIGHT_SHIFT return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) @@ -701,6 +811,9 @@ def __init__(self, X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['alpha'], **kwargs) +def add(X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): + return Add(X, Y, name, op_name, alpha, extra_attrs) + @dataclass(init=False) class Sub(symbol.Symbol): op_name = opns.SUB @@ -719,12 +832,16 @@ def __init__(self, X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['alpha'], **kwargs) +def sub(X, Y, name=None, 
op_name=None, alpha=1, extra_attrs=None): + return Sub(X, Y, name, op_name, alpha, extra_attrs) + + def mul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.MUL assert op_name == opns.MUL return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) -def matMul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: +def mat_mul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.MATMUL assert op_name == opns.MATMUL return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) @@ -747,6 +864,10 @@ def __init__(self, X, Y, name=None, op_name=None, rounding_mode=None, extra_attr def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['rounding_mode'], **kwargs) +def div(X, Y, name=None, op_name=None, rounding_mode=None, extra_attrs=None): + return Div(X, Y, name, op_name, rounding_mode, extra_attrs) + + def negative(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.NEGATIVE assert op_name == opns.NEGATIVE @@ -783,8 +904,8 @@ class Arange(symbol.Symbol): @property def end(self) -> int: - default_val = 0 - return self.attrs['end'] if 'end' in self.attrs else default_val + assert 'end' in self.attrs + return self.attrs['end'] @property def start(self) -> int: @@ -796,21 +917,26 @@ def step(self) -> int: default_val = 1 return self.attrs['step'] if 'step' in self.attrs else default_val - def __init__(self, name=None, op_name=None, end=0, start=0, step=1, extra_attrs=None): + def __init__(self, name=None, op_name=None, end=None, start=0, step=1, extra_attrs=None): op_name = op_name or opns.ARANGE assert op_name == opns.ARANGE + assert end != None super().__init__(name=name or N.n(), op_name=op_name, args=[], attrs={'end': end, 'start': start, 'step': step}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['end', 'start', 'step'], **kwargs) -def zerosLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: +def arange(name=None, op_name=None, end=None, start=0, step=1, extra_attrs=None): + return Arange(name, op_name, end, start, step, extra_attrs) + + +def zeros_like(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.ZEROS_LIKE assert op_name == opns.ZEROS_LIKE return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) -def onesLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: +def ones_like(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.ONES_LIKE assert op_name == opns.ONES_LIKE return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) @@ -860,11 +986,11 @@ def onesLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: _register_op_map(opns.NON_MAX_SUPRESSION)(NonMaxSuppression) _register_op_map(opns.CEIL)(ceil) -_register_op_map(opns.RIGHT_SHIFT)(rightShift) +_register_op_map(opns.RIGHT_SHIFT)(right_shift) _register_op_map(opns.ADD)(Add) _register_op_map(opns.SUB)(Sub) -_register_op_map(opns.MATMUL)(matMul) +_register_op_map(opns.MATMUL)(mat_mul) _register_op_map(opns.DIV)(Div) _register_op_map(opns.NEGATIVE)(negative) _register_op_map(opns.ABS)(abs) @@ -873,8 +999,8 @@ def onesLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: 
_register_op_map(opns.POW)(pow) _register_op_map(opns.PASS)(pass_) _register_op_map(opns.ARANGE)(Arange) -_register_op_map(opns.ZEROS_LIKE)(zerosLike) -_register_op_map(opns.ONES_LIKE)(onesLike) +_register_op_map(opns.ZEROS_LIKE)(zeros_like) +_register_op_map(opns.ONES_LIKE)(ones_like) # Add default register Class for MRT OP Not Implemented! @@ -895,4 +1021,3 @@ def onesLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: _register_op_map(opns.STRIDED_SLICE)(symbol.Symbol) _register_op_map(opns.SLICE_LIKE)(symbol.Symbol) _register_op_map(opns.GET_VALID_COUNT)(symbol.Symbol) - diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py index 2e1136e..7707afb 100644 --- a/tests/mir/test.op_create.py +++ b/tests/mir/test.op_create.py @@ -21,6 +21,20 @@ from mrt.mir import opclass +def test_op_func(): + X = opclass.var(name="var2", shape=(16, 128, 128), dtype="float") + ceil0 = opclass.ceil(X) + assert isinstance(ceil0, sx.Symbol), 'ceil0 isnot a symbol' + assert ceil0.op_name == opns.CEIL + assert len(ceil0.name) > 0 + + ceil1 = opclass.ceil(X, 'ceil_1') + assert ceil1.op_name == opns.CEIL + assert ceil1.name == 'ceil_1' + + return True + + def test_create_conv2d_op(): X = opclass.var(name="x", shape=(1, 3, 224, 224,), dtype="float") @@ -68,6 +82,10 @@ def test_create_conv2d_op(): assert isinstance(conv2d_d, opclass.Conv2D), 'conv2d_d isnot a Conv2D' assert isinstance(conv2d_e, opclass.Conv2D), 'conv2d_e isnot a Conv2D' + # alias function Init + conv2d_f = opclass.conv2d(*args, **attrs) + assert isinstance(conv2d_f, opclass.Conv2D), 'conv2d_f isnot a Conv2D' + return True @@ -104,7 +122,7 @@ def test_create_batch_norm_op(): # test clone mode batch_norm_b = batch_norm_a.copy() - assert isinstance(batch_norm_b , opclass.BatchNorm) + assert isinstance(batch_norm_b, opclass.BatchNorm) assert batch_norm_a.attrs == batch_norm_b.attrs, f'a: {batch_norm_a.attrs} != b: {batch_norm_b.attrs}' assert len(batch_norm_a.args) == len(batch_norm_b.args), f'a: {len(batch_norm_a.args)} != b: {len(batch_norm_b.args)}' @@ -112,6 +130,32 @@ def test_create_batch_norm_op(): return True +def test_create_reshape_op(): + X = opclass.var(name="x", shape=(16, 32, 64, 64,), dtype="float") + try: + reshape0 = opclass.Reshape(X, name="reshape_0") + assert False, "Reshape Must have attr 'newshape', Should already Fail!" 
+ except: + pass + + reshape1 = opclass.Reshape(X, name="reshape_1", newshape=(16, 8, 128, 128)) + assert isinstance(reshape1, opclass.Reshape) + + return True + + +def test_op_extern_func(): + + # extern_func Do not need to fill 'op_name' + args = [opclass.var(name="var2", shape=(16, 128, 128), dtype="float")] + attrs = {} + extra_attrs = {} + call_dps_packed = opclass.MRT_OP_MAP[opns.CALL_DPS_PACKED]('packed_0', args, attrs, extra_attrs) + assert isinstance(call_dps_packed, sx.Symbol), 'call_dps_packed isnot a symbol' + assert call_dps_packed.op_name == opns.CALL_DPS_PACKED + return True + + if __name__ == "__main__": print('MRT_OP_SET as:', opclass.MRT_OP_MAP.keys()) assert len(opclass.MRT_OP_MAP.keys()) > 0 @@ -120,10 +164,14 @@ def test_create_batch_norm_op(): print('MRT_OP_MAP Conv2D Class as:', opclass.MRT_OP_MAP[opns.CONV2D]) test_id = 0 - for func_ in [test_create_conv2d_op, test_create_symbol_graph, test_create_batch_norm_op]: + passed_cnt = 0 + test_funcs = [test_op_func, test_create_conv2d_op, test_create_symbol_graph, test_create_batch_norm_op, test_create_reshape_op, test_op_extern_func] + for func_ in test_funcs: rltflag = func_() test_id += 1 + passed_cnt += rltflag print("\n" + "="*60 + "\n") - print(f'Passed Test{test_id}!' if rltflag else f'Test{test_id} Failed!') + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') From de5df5e81e426a2e5e949f5427e3487560b59a54 Mon Sep 17 00:00:00 2001 From: corlfj Date: Tue, 30 Sep 2025 14:57:07 +0800 Subject: [PATCH 7/7] [mir]: testing infer_pass, fix some opclass issue --- python/mrt/mir/opclass.py | 34 ++++-- python/mrt/mir/simple_pass.py | 192 +++++++++++++++++++++++++----- python/mrt/mir/symbol.py | 5 +- tests/mir/test.infer_pass.py | 103 ++++++++++++++++ tests/mir/test.infer_pass_div.py | 88 ++++++++++++++ tests/mir/test.infer_pass_mean.py | 89 ++++++++++++++ 6 files changed, 468 insertions(+), 43 deletions(-) create mode 100644 tests/mir/test.infer_pass.py create mode 100644 tests/mir/test.infer_pass_div.py create mode 100644 tests/mir/test.infer_pass_mean.py diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py index e3b40f0..02fb929 100644 --- a/python/mrt/mir/opclass.py +++ b/python/mrt/mir/opclass.py @@ -95,22 +95,26 @@ def kernel_size(self) -> typing.Tuple[int, int]: assert 'kernel_size' in self.attrs return self.attrs['kernel_size'] + @property + def kernel_layout(self) -> str: + default_val = 'OIHW' + return self.attrs['kernel_layout'] if 'kernel_layout' in self.attrs else default_val # Follows (*args, name, **attrs) - def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), extra_attrs=None): + def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_layout='OIHW', extra_attrs=None): op_name = op_name or opns.CONV2D assert op_name == opns.CONV2D assert len(W.shape) == 4, f'Wrong Weight Shape for Conv2D: {W.shape}' kernel_size = (W.shape[2], W.shape[3]) - super().__init__(name=name or N.n(), op_name=op_name, args=[X,W], attrs={'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size}, extra_attrs=extra_attrs or {}) + super().__init__(name=name or N.n(), op_name=op_name, args=[X,W], 
attrs={'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size, 'kernel_layout': kernel_layout}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): # Auto inferred 'kernel_size' - return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation'], **kwargs) + return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation', 'kernel_layout'], **kwargs) -def conv2d(X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), extra_attrs=None): - return Conv2D(X, W, name, op_name, strides, padding, groups, dilation, extra_attrs) +def conv2d(X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_layout='OIHW', extra_attrs=None): + return Conv2D(X, W, name, op_name, strides, padding, groups, dilation, kernel_layout, extra_attrs) @dataclass(init=False) @@ -181,19 +185,29 @@ def epsilon(self) -> float: @property def momentum(self) -> float: default_val = 0.1 + return self.attrs['momentum'] if 'momentum' in self.attrs else default_val + + @property + def center(self) -> bool: + default_val = True return self.attrs['center'] if 'center' in self.attrs else default_val - def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, extra_attrs=None): + @property + def scale(self) -> bool: + default_val = True + return self.attrs['scale'] if 'scale' in self.attrs else default_val + + def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, center=True, scale=True, extra_attrs=None): op_name = op_name or opns.BATCH_NORM assert op_name == opns.BATCH_NORM - super().__init__(name=name or N.n(), op_name=op_name, args=[X, Gamma, Beta, Mean, Var], attrs={'axis': axis, 'epsilon': epsilon, 'momentum': momentum}, extra_attrs=extra_attrs or {}) + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Gamma, Beta, Mean, Var], attrs={'axis': axis, 'epsilon': epsilon, 'momentum': momentum, 'center': center, 'scale': scale}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): - return _from_dict_attrs(cls, d, ['axis', 'epsilon', 'momentum'], **kwargs) + return _from_dict_attrs(cls, d, ['axis', 'epsilon', 'momentum', 'center', 'scale'], **kwargs) -def batch_norm(X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, extra_attrs=None): - return BatchNorm(X, Gamma, Beta, Mean, Var, name, op_name, axis, epsilon, momentum, extra_attrs) +def batch_norm(X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, center=True, scale=True, extra_attrs=None): + return BatchNorm(X, Gamma, Beta, Mean, Var, name, op_name, axis, epsilon, momentum, center, scale, extra_attrs) @dataclass(init=False) diff --git a/python/mrt/mir/simple_pass.py b/python/mrt/mir/simple_pass.py index a876b95..302da1b 100644 --- a/python/mrt/mir/simple_pass.py +++ b/python/mrt/mir/simple_pass.py @@ -75,6 +75,7 @@ def get_as_numpy(self, op_: _symbol.Symbol) -> OpNumpyT: """custom visit of graph calling custom_func for all op_name + according to how custom_run implemented, params is from argument or class_property return: head symbol processed """ def custom_visits_with_params(self, custom_run: _symbol._TransformerParamT, name: str = "", once: bool = False) -> _symbol.Symbol: @@ -96,15 +97,15 @@ def _f(data, 
dtype): self.params[name] = array return opclass.var(array, shape=shape, dtype=dtype) - def from_np_data(self, data: np.ndarray, dtype, prefix=None) -> _symbol.Symbol: + def from_np_data(self, sym:_symbol.Symbol, data: np.ndarray, dtype, prefix=None) -> _symbol.Symbol: name = N.n(prefix=prefix) # some data is np.float/int type, use np.array to wrap it. data = np.array(data) self.params[name] = data.astype(dtype) - return opclass.var(name, shape=data.shape, dtype=dtype)#.like(self) + return opclass.var(name, shape=data.shape, dtype=dtype).like(sym) - def from_const_data(self, data: typing.Union[int, float], dtype) -> _symbol.Symbol: - return self.from_np_data(data, dtype) + def from_const_data(self, sym:_symbol.Symbol, data: typing.Union[int, float], dtype) -> _symbol.Symbol: + return self.from_np_data(sym, data, dtype) # Register MRT all op's default_visit_op function @@ -144,15 +145,17 @@ def visit_nn_log_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: class FuseMeanPass(InferPass): - def visit_mean(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.MEAN: - X = sym.args[0] - out = opclass.Sum(X, **sym.attrs) - scale = self.from_np_data(np.array( - 1. * product(out.shape) / product(X.shape)), dtype=out.dtype) - out = opclass.mul(out, scale) - return out - return sym + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.MEAN: + X = sym.args[0] + out = opclass.Sum(X, **sym.attrs).like(sym) + scale = self.from_np_data(sym, np.array( + 1. * product(out.shape) / product(X.shape)), dtype=out.dtype) + out = opclass.mul(out, scale) + return out + return sym + return custom_run class FuseConstantPass(InferPass): @@ -161,9 +164,8 @@ class FuseConstantPass(InferPass): def np_is_zero(self, data) -> float: return np.abs(data).max() < self.threshold - def get_run(self) -> _symbol._TransformerParamT: - def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol:#: _symbol._TransformerParamT + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: if self.is_operator(sym) and all([self.is_param(arg) for arg in sym.args]): data = inference.run_single_params( sym, [self.get_as_numpy(a) for a in sym.args]) @@ -178,7 +180,7 @@ def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) return args[0] elif sym.is_op(opns.SLICE_LIKE): if not self.is_param(sym.args[0]): - return None + return sym a, b = sym.args data = inference.run_single_params( sym, [self.get_as_numpy(a), np.zeros(b.shape, b.dtype)]) @@ -194,24 +196,150 @@ def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) class FuseBatchNormPass(InferPass): - def visit_nn_batch_norm(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.BATCH_NORM: - X, Gamma, Beta, Mean, Var = sym.args - Gamma = self.get_param(Gamma) - Beta = self.get_param(Beta) - Mean = self.get_param(Mean) - Var = self.get_param(Var) + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: opclass.BatchNorm, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.BATCH_NORM: + X, Gamma, Beta, Mean, Var = sym.args + Gamma = self.get_param(Gamma) + Beta = self.get_param(Beta) + Mean = self.get_param(Mean) + Var = self.get_param(Var) + + assert sym.axis == 1 + Beta = Beta if sym.center else 0 + Gamma = Gamma if sym.scale else 1 + + # (x - mean) / 
sqrt(var + epsilon) * gamma + beta + Gamma = Gamma / np.sqrt(Var + sym.epsilon) + # (x - mean) * gamma + beta + # x * gamma + (beta - mean * gamma) + bias: np.ndarray = (Beta - Mean * Gamma) + K = Gamma.shape[0] + + if X.is_op(opns.CONV2D): + A, W = X.args + assert X.kernel_layout == "OIHW" + assert W.shape[0] == K + # (A * W) * gamma + bias + # A * (W * gamma) + bias + W_data = self.get_as_numpy(W) * Gamma.reshape(K, 1, 1, 1) + W_sym = self.from_np_data(W, W_data, W.dtype) + out = op.nn_conv2d(A, W_sym, **X.attrs) + elif X.is_op(opns.DENSE): + A, W = X.args + # (A * W) * gamma + bias + # A * (W * gamma) + bias + W_data = self.get_as_numpy(W) * Gamma.reshape(K, 1) + W_sym = self.from_np_data(W, W_data, W.dtype) + out = op.nn_dense(A, W_sym, **X.attrs) + else: + reshp = [s if i == sym.axis else 1 \ + for i, s in enumerate(X.shape)] + W = self.from_np_data(X, Gamma.reshape(reshp), X.dtype) + out = opclass.mul(X, W) + + bias = bias.reshape([s if i == sym.axis else 1 \ + for i, s in enumerate(out.shape)]) + B = out.like(sym) + B = self.from_np_data(B, bias, dtype=B.dtype) + return opclass.add(out, B).like(sym) + return sym - return sym + return custom_run class FuseDividePass(InferPass): - def visit_divide(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.DIV: - argA = sym.args[0] - argB = sym.args[1] - assert self.is_param(argB), f'NotParam: {argB}' - argB = self.from_np_data(1. / self.get_as_numpy(argB), dtype=argB.dtype) - return opclass.MRT_OP_MAP[opns.MUL](argA, argB) - return sym + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.DIV: + argA = sym.args[0] + argB = sym.args[1] + assert self.is_param(argB), f'NotParam: {argB}' + argB = self.from_np_data(sym, 1. 
/ self.get_as_numpy(argB), dtype=argB.dtype) + out = opclass.mul(argA, argB) + return out.like(sym) + return sym + return custom_run + +class FuseLeakyReLU(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.LEAKY_RELU: + alpha = self.from_const_data(sym, sym.alpha, dtype=float) + X = sym.args[0] + out = opclass.relu(opclass.negative(X)) + out = opclass.mul(alpha, out) + return opclass.sub(opclass.relu(X), out) + return sym + return custom_run + +class FuseAdaptiveAvgPool2D(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.ADAPTIVE_AVG_POOL2D: + X = sym.args[0] + assert sym.layout == "NCHW" + inp_shap = X.shape[2:] + out_size = sym.output_size or inp_shap + if not isinstance(out_size, (list, tuple)): + out_size = (out_size, out_size) + sym.output_size = out_size + + assert len(X.shape) == 4 + if all([s == 1 for s in sym.output_size]): + scale = np.array(1 / np.prod(X.shape[-2:])) + out = opclass.Sum(X, dim=list(range(4))[-2:], keepdims=True) + scale = self.from_np_data(sym, scale.astype(X.dtype)) + return opclass.mul(out, scale).like(self) + elif out_size[0] > inp_shap[0] or out_size[1] > inp_shap[1]: + assert all([s == 1 for s in inp_shap]) + # TODO: fix opclass repeat + out = opclass.repeat(X, repeats=out_size[0], axis=-2) + out = opclass.repeat(out, repeats=out_size[1], axis=-1) + return out.like(self) + + # calculate the attributes refers to: + # https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work + strides = [i // o for i, o in zip(inp_shap, out_size)] + kernel = [i-(o-1)*s for i, o, s in zip(inp_shap, out_size, strides)] + attrs = { + "kernel_size": kernel, + "strides": strides, + "padding": (0, 0), + "dilation": (1, 1), + "data_layout": sym.layout, + "groups": X.shape[1], + "channels": X.shape[1], + } + W_shape = (X.shape[1], 1, *kernel) + W = self.from_np_data(X, np.full(W_shape, 1 / product(kernel)), dtype=X.dtype) + out = opclass.Conv2D(X, W, **attrs) + return out.like(sym) + return sym + return custom_run + + +class FuseAvgPool2D(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + return sym + return custom_run + +class Spliter(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + return sym + return custom_run + +class Merger(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + return sym + return custom_run + +class Calibrator(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + return sym + return custom_run diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index 315b70d..07bd8af 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -113,7 +113,10 @@ def like(self, other: Symbol, **kwargs) -> Symbol: # assert self.shape == other.shape, "%s vs.\n %s" % (self, other) # assert self.dtype == other.dtype , "%s vs.\n %s" % (self, other) data = other.to_dict() - data.update(self.to_dict()) + 
data_new = self.to_dict() + data.update(data_new) + + data["extra_attrs"] = other.extra_attrs if self.extra_attrs == {} else data["extra_attrs"] # copy extra attrs by default. # data["extra_attrs"] = other.extra_attrs return type(other).from_dict(data, **kwargs) diff --git a/tests/mir/test.infer_pass.py b/tests/mir/test.infer_pass.py new file mode 100644 index 0000000..3d94e93 --- /dev/null +++ b/tests/mir/test.infer_pass.py @@ -0,0 +1,103 @@ +""" +Test script for MRT InferPass +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_resnet18_model(): + """Get Resnet18 MRT Model""" + + # Load pre-trained ResNet18 + model = models.resnet18(weights='IMAGENET1K_V1') + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Model to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + + +def test_InferPass_FuseBatchNorm(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + print('\n=== Before FuseBatchNorm Pass ===') + symlist = sx.sym2list(symbol) + return True + + +def test_InferPass_FuseAdaptiveAvgPool2D(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + print('\n=== Before FuseAdaptiveAvgPool2D Pass ===') + symlist = sx.sym2list(symbol) + return True + + +def test_InferPass_FuseTupleGetItem(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseTuple Pass ===') + symlist = sx.sym2list(symbol) + #for x in symlist: + #print(x) + + op_cnt = 0 + for sym in symlist: + op_cnt += 1 if sym.op_name == opns.TUPLE_GET_ITEM else 0 + assert op_cnt > 0, f'ori model TupleGetItem op cnt {op_cnt} == zero!' + + # init Passer and execute visit + tfs : simple_pass.FuseTupleGetItemPass = simple_pass.FuseTupleGetItemPass(symbol, mrt_params) + symbol_passed = tfs.custom_visits_with_params(tfs.get_run()) + + print('\n=== After FuseTuple Pass ===') + rlts = sx.sym2list(symbol_passed) + op_cnt_af = 0 + for sym in rlts: + # print(sym) + op_cnt_af += 1 if sym.op_name == opns.TUPLE_GET_ITEM else 0 + assert op_cnt_af==0, f'passed model op cnt {op_cnt_af} != zero!' + + return True + + +if __name__ == "__main__": + print("=== Testing InferPass ===") + mrt_graph, mrt_params = _get_resnet18_model() + + test_id = 0 + passed_cnt = 0 + test_funcs = [test_InferPass_FuseBatchNorm, test_InferPass_FuseAdaptiveAvgPool2D, test_InferPass_FuseTupleGetItem] + for func_ in test_funcs: + rltflag = func_(mrt_graph, mrt_params) + test_id += 1 + passed_cnt += rltflag + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! 
Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') + print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') + diff --git a/tests/mir/test.infer_pass_div.py b/tests/mir/test.infer_pass_div.py new file mode 100644 index 0000000..547363c --- /dev/null +++ b/tests/mir/test.infer_pass_div.py @@ -0,0 +1,88 @@ +""" +Test script for MRT InferPass +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_fasterrcnn_resnet50_fpn_model(): + """Get Fasterrcnn_resnet50_fpn MRT Model""" + + # Load pre-trained model + model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Model to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + + +def test_InferPass_FuseDivide(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseDivide Pass ===') + symlist = sx.sym2list(symbol) + + divide_op_cnt = 0 + for sym in symlist: + divide_op_cnt += 1 if sym.op_name == opns.DIV else 0 + assert divide_op_cnt > 0, f'ori model divide op cnt {divide_op_cnt} == zero!' + + # init FuseDivide Passer and execute visit + tfs : simple_pass.FuseDividePass = simple_pass.FuseDividePass(symbol, mrt_params) + symbol_passed = tfs.custom_visits_with_params(tfs.get_run()) + + print('\n=== After FuseDivide Pass ===') + rlts = sx.sym2list(symbol_passed) + divide_op_cnt_af = 0 + for sym in rlts: + # print(sym) + divide_op_cnt_af += 1 if sym.op_name == opns.DIV else 0 + assert divide_op_cnt_af==0, f'passed model divide op cnt {divide_op_cnt_af} != zero!' + + return True + + +if __name__ == "__main__": + print("=== Testing InferPass Divide ===") + mrt_graph, mrt_params = _get_fasterrcnn_resnet50_fpn_model() + + test_id = 0 + passed_cnt = 0 + test_funcs = [test_InferPass_FuseDivide] + for func_ in test_funcs: + rltflag = func_(mrt_graph, mrt_params) + test_id += 1 + passed_cnt += rltflag + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! 
Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') + print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') + diff --git a/tests/mir/test.infer_pass_mean.py b/tests/mir/test.infer_pass_mean.py new file mode 100644 index 0000000..d8586ec --- /dev/null +++ b/tests/mir/test.infer_pass_mean.py @@ -0,0 +1,89 @@ +""" +Test script for MRT InferPass +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_shufflenet_model(): + """Get Shufflenet MRT Model""" + + # Load pre-trained + model = models.shufflenet_v2_x1_0(pretrained=True) + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Model to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + + +def test_InferPass_FuseMean(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseMean Pass ===') + symlist = sx.sym2list(symbol) + #for x in symlist: + #print(x) + + op_cnt = 0 + for sym in symlist: + op_cnt += 1 if sym.op_name == opns.MEAN else 0 + assert op_cnt > 0, f'ori model mean op cnt {op_cnt} == zero!' + + # init Passer and execute visit + tfs : simple_pass.FuseMeanPass = simple_pass.FuseMeanPass(symbol, mrt_params) + symbol_passed = tfs.custom_visits_with_params(tfs.get_run()) + + print('\n=== After FuseMean Pass ===') + rlts = sx.sym2list(symbol_passed) + op_cnt_af = 0 + for sym in rlts: + # print(sym) + op_cnt_af += 1 if sym.op_name == opns.MEAN else 0 + assert op_cnt_af==0, f'passed model op cnt {op_cnt_af} != zero!' + + return True + + +if __name__ == "__main__": + print("=== Testing InferPass Mean ===") + mrt_graph, mrt_params = _get_shufflenet_model() + + test_id = 0 + passed_cnt = 0 + test_funcs = [test_InferPass_FuseMean] + for func_ in test_funcs: + rltflag = func_(mrt_graph, mrt_params) + test_id += 1 + passed_cnt += rltflag + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') + print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') +
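
The three new InferPass tests all drive a pass the same way: construct it from (symbol, params) and call custom_visits_with_params(tfs.get_run()). The FuseMeanPass rewrite exercised by test.infer_pass_mean.py rests on the identity mean(X, dims) = sum(X, dims) * (prod(out.shape) / prod(X.shape)), which is exactly the scale that simple_pass.py materializes as a constant parameter. Below is a small NumPy check of that identity under a keepdim-style reduction; it is a sanity sketch, not part of the test suite.

import numpy as np

# FuseMeanPass identity: mean(X, dims) == sum(X, dims) * prod(out.shape) / prod(X.shape)
X = np.random.randn(1, 3, 224, 224).astype("float32")
dims = (2, 3)

out = X.sum(axis=dims, keepdims=True)
# Same scale FuseMeanPass injects via from_np_data(): prod(out.shape) / prod(X.shape).
scale = np.float32(1.0 * out.size / X.size)

np.testing.assert_allclose(out * scale,
                           X.mean(axis=dims, keepdims=True),
                           rtol=1e-5, atol=1e-6)

FuseDividePass relies on the analogous rewrite of div(A, B) into mul(A, 1/B) for parameter divisors, which is what test.infer_pass_div.py asserts by counting DIV ops before and after the pass.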