diff --git a/python/mrt/api.py b/python/mrt/api.py index fe511cb..f095be0 100644 --- a/python/mrt/api.py +++ b/python/mrt/api.py @@ -12,7 +12,7 @@ from .runtime.analysis import * from .mir import op, helper -# from .mir.model import MultiHeadSymbol +from .mir.mhsymbol import MultiHeadSymbol from .mir.symbol import * from .dataset.base import Dataset @@ -26,7 +26,7 @@ from .quantization.discrete import Discretor from .quantization.precision import PrecisionRevisor -from .quantization.transform import TransformerT +from .mir.symbol_pass import SymTransformerT @dataclass class TraceConfig(config._BaseConfig): @@ -174,7 +174,7 @@ def _new(self, tr_name: str, _stat_type = self._stat_type) def checkpoint_run(self, - *callbacks: typing.List[TransformerT], + *callbacks: typing.List[SymTransformerT], tr_name: typing.Optional[str] = None, **kwargs) -> Trace: C = TraceConfig.G() @@ -200,7 +200,7 @@ def checkpoint_run(self, for cb in callbacks: # deep copy params to avoid conflict status params = {k: v for k, v in out.params.items()} - print("Apply Trace: {:25} Transformer: {}".format( + print("Apply Trace: {:25} SymbolTransformer: {}".format( tr_name, cb.__name__)) if cb.__name__ in C.log_before_tr_or_cbs: @@ -223,7 +223,14 @@ def checkpoint_run(self, def discrete(self) -> Trace: fuse_tr = self.fuse() + + """Must pass params inside a dict, + Cause it will be unfolded separately + """ seg_tr = fuse_tr.checkpoint_run(seg.Spliter.get_transformer()) + kwargs_seg = {"ptr": {"head": seg_tr.symbol.extra_attrs.get("head"), + "head_params": seg_tr.symbol.extra_attrs.get("head_params"), + "seg_names": seg_tr.symbol.extra_attrs.get("seg_names")}} C = TraceConfig.G() calib_tr = seg_tr.calibrate( @@ -232,7 +239,8 @@ def discrete(self) -> Trace: quant_tr = calib_tr.quantize() quant_tr = quant_tr.checkpoint_run( seg.Merger.get_transformer(), - spliter=seg_tr.symbol) + spliter=seg_tr.symbol, + **kwargs_seg) return quant_tr def fuse(self, **kwargs) -> Trace: @@ -247,6 +255,7 @@ def fuse(self, **kwargs) -> Trace: fuse.FuseDropout.get_transformer(), fuse.FuseMean.get_transformer(), fuse.FuseNaiveSoftmax.get_transformer(), + fuse.FuseIdentity.get_transformer(), fuse.FuseConstant.get_transformer(), **kwargs, ) @@ -254,13 +263,13 @@ def fuse(self, **kwargs) -> Trace: def calibrate(self, repeats: int = 1, **kwargs) -> Trace: assert self._dataset is not None tr_name = kwargs.pop("tr_name", "calibrate") + out = self for i in range(repeats): data, _ = self._dataset.next() out = out.checkpoint_run( calib.Calibrator.get_transformer(), data = data, - # tr_name = tr_name, tr_name = f"{tr_name}_run_{i}", **kwargs) out = out.checkpoint_run( diff --git a/python/mrt/frontend/api.py b/python/mrt/frontend/api.py index 0b55096..c9c9692 100644 --- a/python/mrt/frontend/api.py +++ b/python/mrt/frontend/api.py @@ -4,6 +4,7 @@ from functools import wraps from mrt.mir.symbol import * +from mrt.mir.mhsymbol import MultiHeadSymbol, Graph from mrt.common.types import * from mrt.common.config import MRTConfig diff --git a/python/mrt/frontend/expr.py b/python/mrt/frontend/expr.py index 9515db4..ff8d089 100644 --- a/python/mrt/frontend/expr.py +++ b/python/mrt/frontend/expr.py @@ -14,6 +14,7 @@ from ..symbol import * from ..types import * from .. import op +from .. 
import opclass __ALL__ = [ "expr2symbol", "symbol2expr", "tvm_type_infer" ] @@ -62,7 +63,7 @@ def _cast_expr(node: RelayExpr): if isinstance(node, relay.expr.Constant): name = N.n("const_") params[name] = node.data - symbol_map[node] = op.variable(name, + symbol_map[node] = opclass.var(name, node.data.shape, node.data.dtype) return @@ -85,11 +86,11 @@ def _cast_expr(node: RelayExpr): if isinstance(node, relay.expr.Var): name = node.name_hint or N.n(prefix="input_") - symbol_map[node] = op.variable(name, shape, dtype) + symbol_map[node] = opclass.var(name, shape, dtype) elif isinstance(node, relay.expr.If): args = [ node.cond, node.true_branch, node.false_branch ] args = [symbol_map[i] for i in args] - symbol_map[node] = op._new_op(IF, *args, **attrs) + symbol_map[node] = opclass.extern_opfunc(IF)(*args, **attrs) elif isinstance(node, relay.expr.Call): op_name = node.op.name if op_name in [CONCAT, ADV_INDEX]: @@ -108,15 +109,14 @@ def _cast_expr(node: RelayExpr): attrs.pop("dtype") elif op_name == GET_VALID_COUNT: attrs.pop("score_threshold") - symbol_map[node] = op._new_op(op_name, *args, **attrs) + symbol_map[node] = opclass.extern_opfunc(op_name)(*args, **attrs) elif isinstance(node, relay.TupleGetItem): args = [ symbol_map[node.tuple_value], ] attrs['index'] = node.index - symbol_map[node] = op._new_op( - TUPLE_GET_ITEM, *args, **attrs) + symbol_map[node] = opclass.extern_opfunc(TUPLE_GET_ITEM)(*args, **attrs) elif isinstance(node, relay.Tuple): args = [ symbol_map[f] for f in node.fields ] - symbol_map[node] = op._new_op(TUPLE, *args, **attrs) + symbol_map[node] = opclass.extern_opfunc(TUPLE)(*args, **attrs) else: raise RuntimeError( "MRT not support expr type:{}".format(type(node))) diff --git a/python/mrt/frontend/pytorch/converter.py b/python/mrt/frontend/pytorch/converter.py index 2318fc6..8c5ddc8 100644 --- a/python/mrt/frontend/pytorch/converter.py +++ b/python/mrt/frontend/pytorch/converter.py @@ -9,8 +9,9 @@ import torch.nn.functional as F import sys -from mrt.mir.symbol import Symbol, MultiHeadSymbol, sym2list, transform -from mrt.mir import op +from mrt.mir.symbol import Symbol, sym2list, transform +from mrt.mir.mhsymbol import MultiHeadSymbol +from mrt.mir import op, opclass from mrt.mir.opns import * from mrt.common.types import ParametersT from mrt.common.utils import N @@ -46,7 +47,7 @@ class _T: "adaptive_avg_pool2d.default": _T(ADAPTIVE_AVG_POOL2D, 1, [ Attr("output_size", (1,1)) ]), "max_pool2d.default": _T(MAX_POOL2D, 1, [ - Attr("kernel_size", (1,1)), Attr("strides", (1,1)), Attr("padding", (0,0)) ]), + Attr("kernel_size", (1,1)), Attr("strides", (1,1)), Attr("padding", (0,0)), Attr("dilation", (1,1)), Attr("ceil_mode", False) ]), "mean.dim": _T(MEAN, 1, [ Attr("dim", None), Attr("keepdim", False) ]), "add.Tensor": _T(ADD, 2), "add_.Tensor": _T(ADD, 2), @@ -60,7 +61,7 @@ class _T: "cat.default": _T(CONCAT, 1, [ Attr("dim", 0) ]), "view.default": _T(RESHAPE, 1, [ Attr("shape", ()) ]), "transpose.int": _T(TRANSPOSE, 1, [ Attr("dim0", 0), Attr("dim1", 0) ]), - "contiguous.default": _T(PASS, 1), + "contiguous.default": _T(IDENTITY, 1), "chunk.default": _T(SPLIT, 1, [ Attr("chunks", 1), Attr("dim", 0) ]), "getitem": _T(TUPLE_GET_ITEM, 1, [ Attr("index", 0) ]), @@ -100,7 +101,7 @@ class _T: ), RESHAPE: torch.reshape, TRANSPOSE: torch.transpose, - PASS: lambda x: x, + IDENTITY: lambda x: x, SPLIT: torch.chunk, ADD: torch.add, @@ -156,7 +157,7 @@ def create_parameters(ep: torch.export.ExportedProgram): dshape = data_to_mrt(torch_shape) dtype = data_to_mrt(torch_dtype) - out 
= op.variable(name_hint, dshape, dtype) + out = opclass.var(name=name_hint, shape=dshape, dtype=dtype) params[name_hint] = to_bind_parameters[spec.target].detach().numpy().astype(dtype) assert dshape == list(params[name_hint].shape) # print(">> vars: ", out) @@ -207,7 +208,7 @@ def _retrieve_args(node): continue if node.name not in param_vars: # input - env[node] = op.variable(node.name, shape, dtype) + env[node] = opclass.var(name=node.name, shape=shape, dtype=dtype) else: env[node] = param_vars[node.name] elif node.op == "output": # [[ out1, out2, out3 ]] @@ -234,13 +235,15 @@ def _retrieve_args(node): if mapper.op_name == CONCAT: args = args[0] + if mapper.op_name == SPLIT: + shape = data_to_mrt([ t.shape for t in node.meta['val']]) + dtype = data_to_mrt([ t.dtype for t in node.meta['val']]) + if mapper.op_name == TUPLE_GET_ITEM and args[0].op_name == BATCH_NORM: out = args[0] else: - out = op._new_op( - mapper.op_name, *args, - name=node.name, extra_attrs={ "shape": shape, "dtype": dtype }, - **attrs) + out = opclass.extern_opfunc(mapper.op_name)(*args, name=node.name, + extra_attrs={"shape": shape, "dtype": dtype}, **attrs) env[node] = out else: raise ValueError(f"Unsupported op {node.op}") diff --git a/python/mrt/frontend/pytorch/vm.py b/python/mrt/frontend/pytorch/vm.py index 31c44e8..bf353a2 100644 --- a/python/mrt/frontend/pytorch/vm.py +++ b/python/mrt/frontend/pytorch/vm.py @@ -6,6 +6,7 @@ from .types import * from mrt.mir.symbol import * +from mrt.mir.mhsymbol import MultiHeadSymbol from mrt.common.types import * Executor = namedtuple("Executor", ["vm", "device"]) diff --git a/python/mrt/mir/mhsymbol.py b/python/mrt/mir/mhsymbol.py new file mode 100644 index 0000000..f2b5ada --- /dev/null +++ b/python/mrt/mir/mhsymbol.py @@ -0,0 +1,35 @@ +import typing + +from mrt.common.utils import * +from mrt.common.types import * + +from . import opns, opclass, optype +from . import symbol + +#from mrt.mir.mhsymbol import MultiHeadSymbol, Graph +class MultiHeadSymbol(dict): + """ { "main": F(X) } """ + origin: typing.Optional[symbol.Symbol] = None + + @classmethod + def from_symbol(cls, symbol: symbol.Symbol, name: str = "main"): + return MultiHeadSymbol({ name: symbol }) + + def as_tuple(self) -> typing.Tuple[typing.List[str], symbol.Symbol]: + # args = list(self.values()) + # sym_type = type(args[0]) if args else Symbol + mhs = self.origin or optype.infer_single(opclass.MRT_OP_MAP[opns.TUPLE](*list(self.values()))) + return list(self.keys()), mhs + + @classmethod + def from_tuple(cls, tuple_names, symbol): + assert symbol.is_op(opns.TUPLE), symbol + mhs = cls(zip(tuple_names, symbol.args)) + mhs.origin = symbol + return mhs + +Graph = typing.Union[symbol.Symbol, MultiHeadSymbol] +""" Notice that Symbol and MultiHeadSymbol can both + be regarded as a model Graph. 
+""" + diff --git a/python/mrt/mir/op.py b/python/mrt/mir/op.py index 1eab62b..ef2cab3 100644 --- a/python/mrt/mir/op.py +++ b/python/mrt/mir/op.py @@ -32,7 +32,8 @@ def variable(name, shape, dtype) -> Symbol: def as_variable(symbol: Symbol, shape=None, dtype=None) -> Symbol: """ inherit extra attrs """ - out = symbol.copy(op_name=VAR, args=[], attrs={}) + # out = symbol.copy(op_name=VAR, args=[], attrs={}) + out = symbol.as_variable() out.shape = shape or out.shape out.dtype = dtype or out.dtype return out @@ -40,69 +41,68 @@ def as_variable(symbol: Symbol, shape=None, dtype=None) -> Symbol: def retrieve_operator(symbol: Symbol) -> Symbol: return symbol.copy(args=[as_variable(c) for c in symbol.args]) -def _new_op(op_name, *args, extra_attrs=None, **attrs) -> Symbol: - name = attrs.pop("name", N.n()) - return Symbol.from_dict({}, - name=name, op_name=op_name, - args=args or [], attrs=attrs or {}, - extra_attrs=extra_attrs or {}) - -def _register_op(op_name): - def _op(*args, **attrs) -> Symbol: - op = _new_op(op_name, *args, **attrs) - from . import optype - out = optype.infer_single(op) - return out - return _op - -Tuple = _register_op(TUPLE) -TupleGetItem = _register_op(TUPLE_GET_ITEM) - -# class Conv2D(Symbol): -# strides: - -# TODO: define op function -# def conv2d(X, weight, bias, strides=(1,1)...): -# return Symbol(args=[X, weight, bias], -# attrs={ "strides": strides }) -nn_conv2d = _register_op(CONV2D) -nn_dense = _register_op(DENSE) -nn_batch_norm = _register_op(BATCH_NORM) -# bias_add = _register_op(BIAS_ADD) - -nn_relu = _register_op(RELU) - -sum = _register_op(SUM) -# mean = _register_op(MEAN) -clip = _register_op(CLIP) -ceil = _register_op(CEIL) -right_shift = _register_op(RIGHT_SHIFT) -# relax api from cast to astype -# astype = _register_op(AS_TYPE) -cast = _register_op(AS_TYPE) -# flatten = _register_op(FLATTEN) -adv_index = _register_op(ADV_INDEX) -zeros_like = _register_op(ZEROS_LIKE) - -repeat = _register_op(REPEAT) -reshape = _register_op(RESHAPE) - -add = _register_op(ADD) -sub = _register_op(SUB) -max_axis = _register_op(MAX_AXIS) -mul = _register_op(MUL) -div = _register_op(DIV) -matmul = _register_op(MATMUL) -exp = _register_op(EXP) -negative = _register_op(NEGATIVE) - -sigmoid = _register_op(SIGMOID) -softmax = _register_op(SOFTMAX) - -requant = _register_op(REQUANT) -pclip = _register_op(PCLIP) -rs_pclip = _register_op(RS_PCLIP) -lut = _register_op(LUT) +# def _new_op(*args, op_name='', extra_attrs=None, **attrs) -> Symbol: +# name = attrs.pop("name", N.n()) +# return Symbol(*args, +# name=name, op_name=op_name, +# extra_attrs=extra_attrs or {}, +# **attrs) +# +# def _register_op(op_name): +# def _op(*args, **attrs) -> Symbol: +# op = _new_op(*args, op_name=op_name, **attrs) +# from . 
import optype +# out = optype.infer_single(op) +# return out +# return _op +# +# Tuple = _register_op(TUPLE) +# TupleGetItem = _register_op(TUPLE_GET_ITEM) +# +# # class Conv2D(Symbol): +# # strides: +# +# # def conv2d(X, weight, bias, strides=(1,1)...): +# # return Symbol(args=[X, weight, bias], +# # attrs={ "strides": strides }) +# nn_conv2d = _register_op(CONV2D) +# nn_dense = _register_op(DENSE) +# nn_batch_norm = _register_op(BATCH_NORM) +# # bias_add = _register_op(BIAS_ADD) +# +# nn_relu = _register_op(RELU) +# +# sum = _register_op(SUM) +# # mean = _register_op(MEAN) +# clip = _register_op(CLIP) +# ceil = _register_op(CEIL) +# right_shift = _register_op(RIGHT_SHIFT) +# # relax api from cast to astype +# # astype = _register_op(AS_TYPE) +# cast = _register_op(AS_TYPE) +# # flatten = _register_op(FLATTEN) +# adv_index = _register_op(ADV_INDEX) +# zeros_like = _register_op(ZEROS_LIKE) +# +# repeat = _register_op(REPEAT) +# reshape = _register_op(RESHAPE) +# +# add = _register_op(ADD) +# sub = _register_op(SUB) +# max_axis = _register_op(MAX_AXIS) +# mul = _register_op(MUL) +# div = _register_op(DIV) +# matmul = _register_op(MATMUL) +# exp = _register_op(EXP) +# negative = _register_op(NEGATIVE) +# +# sigmoid = _register_op(SIGMOID) +# softmax = _register_op(SOFTMAX) +# +# requant = _register_op(REQUANT) +# pclip = _register_op(PCLIP) +# rs_pclip = _register_op(RS_PCLIP) +# lut = _register_op(LUT) def is_operator(symbol: Symbol, params: ParametersT = {}): return symbol.op_name != VAR diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py new file mode 100644 index 0000000..67eaa04 --- /dev/null +++ b/python/mrt/mir/opclass.py @@ -0,0 +1,861 @@ +import typing +import numpy as np + +from mrt.common.utils import N +from . import opns +from . 
import symbol
+
+SymbolCreator = typing.Union[typing.Callable[..., symbol.Symbol], typing.Type[symbol.Symbol]]
+
+MRT_OP_MAP: typing.Dict[str, SymbolCreator] = {}
+
+def _register_op_map(op_name: str):
+    def _wrapper(clss: SymbolCreator = None) -> SymbolCreator:
+        if len(op_name) > 0 and clss != None:
+            if op_name not in MRT_OP_MAP:
+                MRT_OP_MAP[op_name] = clss
+            else:
+                print(f'Warning: "{op_name}" already registered in MRT_OP_MAP, overriding!')
+                MRT_OP_MAP[op_name] = clss
+        return clss
+    return _wrapper
+
+
+# OPs from external (not in MRT op), using custom op_name with default op_func
+# y = extern_opfunc("tanh")(X)
+def extern_opfunc(op_name: str):
+    def op_func(*args, name=None, extra_attrs=None, **kwargs):
+        return symbol.Symbol(*args, name=name or N.n(), op_name=op_name, extra_attrs=extra_attrs or {}, **kwargs)
+    return op_func
+
+def _from_dict_attrs(cls, d: dict, attr_keys:typing.List[str]=[], **kwargs):
+    data = cls.default_dict()
+    data.update(d)
+    data.update(kwargs)
+    data = cls.update_dict(data)
+    basedata = {k: data[k] for k in data if k in ['name', 'extra_attrs']}
+    attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in attr_keys}
+    try:
+        out = cls(*data['args'], **basedata, **attrsdata)
+    except Exception as e:
+        raise e
+    return out
+
+# OPs without attrs, just register function (funcName should be lower case)
+def var(name=None, shape=(), dtype=float) -> symbol.Symbol:
+    return symbol.Symbol(name=name or N.n(), op_name=opns.VAR, extra_attrs={'shape': shape or (), 'dtype': dtype or float})
+
+#def _return_func_single_arg(op_name: op_name):
+def relu(X, name=None, extra_attrs=None) -> symbol.Symbol:
+    return symbol.Symbol(X, name=name or N.n(), op_name=opns.RELU, extra_attrs=extra_attrs or {})
+
+def silu(X, name=None, extra_attrs=None) -> symbol.Symbol:
+    return symbol.Symbol(X, name=name or N.n(), op_name=opns.SILU, extra_attrs=extra_attrs or {})
+
+
+class Conv2D(symbol.Symbol):
+    op_name = opns.CONV2D
+
+    @property
+    def strides(self) -> typing.Tuple[int, int]:
+        return self.attrs['strides']
+
+    @property
+    def padding(self) -> typing.Tuple[int, int, int, int]:
+        return self.attrs['padding']
+
+    @property
+    def groups(self) -> int:
+        return self.attrs['groups']
+
+    @property
+    def dilation(self) -> typing.Tuple[int, int]:
+        return self.attrs['dilation']
+
+    # Follows (*args, name, **attrs)
+    def __init__(self, X, W, name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_layout='OIHW', extra_attrs=None):
+        assert len(W.shape) == 4, f'Wrong Weight Shape for Conv2D: {W.shape}'
+        kernel_size = (W.shape[2], W.shape[3])
+        #attrs = {'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size, 'kernel_layout': kernel_layout}
+        attrs = {'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation}
+        super().__init__(X, W, name=name or N.n(), op_name=opns.CONV2D, extra_attrs=extra_attrs or {}, **attrs)
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        # Auto inferred 'kernel_size'
+        return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation'], **kwargs)
+
+def conv2d(*args, **kwargs):
+    #def conv2d(X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_layout='OIHW', extra_attrs=None):
+    return Conv2D(*args, **kwargs) #X, W, name, op_name, strides, padding, groups, dilation, kernel_layout, extra_attrs)
+
+
+class Dropout(symbol.Symbol):
+    op_name = opns.DROP_OUT
+
+    @property
+    def p(self) -> float:
+        return self.attrs['p']
+
+    def __init__(self, X, name=None, p:float = 0.5, extra_attrs=None):
+        super().__init__(X, name=name or N.n(), op_name=opns.DROP_OUT, extra_attrs=extra_attrs or {}, **{'p': p})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['p'], **kwargs)
+
+def dropout(*args, **kwargs):
+    return Dropout(*args, **kwargs)
+
+
+class Clip(symbol.Symbol):
+    op_name = opns.CLIP
+
+    @property
+    def min(self) -> float:
+        assert 'min' in self.attrs
+        return self.attrs['min']
+
+    @property
+    def max(self) -> float:
+        assert 'max' in self.attrs
+        return self.attrs['max']
+
+    def __init__(self, X, name=None, min:float = np.nan, max:float = np.nan, extra_attrs=None):
+        assert not np.isnan(min)
+        assert not np.isnan(max)
+        super().__init__(X, name=name or N.n(), op_name=opns.CLIP, extra_attrs=extra_attrs or {}, **{'min': min, 'max': max})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['min', 'max'], **kwargs)
+
+def clip(*args, **kwargs):
+    return Clip(*args, **kwargs)
+
+class BatchNorm(symbol.Symbol):
+    op_name = opns.BATCH_NORM
+
+    @property
+    def axis(self) -> int:
+        return self.attrs['axis']
+
+    @property
+    def epsilon(self) -> float:
+        return self.attrs['epsilon']
+
+    @property
+    def momentum(self) -> float:
+        return self.attrs['momentum']
+
+    @property
+    def center(self) -> bool:
+        return self.attrs['center']
+
+    @property
+    def scale(self) -> bool:
+        return self.attrs['scale']
+
+    def __init__(self, X, Gamma, Beta, Mean, Var, name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, center=True, scale=True, extra_attrs=None):
+        super().__init__(*[X, Gamma, Beta, Mean, Var], name=name or N.n(), op_name=opns.BATCH_NORM, extra_attrs=extra_attrs or {}, **{'axis': axis, 'epsilon': epsilon, 'momentum': momentum, 'center': center, 'scale': scale})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['axis', 'epsilon', 'momentum', 'center', 'scale'], **kwargs)
+
+def batch_norm(*args, **kwargs):
+    return BatchNorm(*args, **kwargs)
+
+
+class TupleGetItem(symbol.Symbol):
+    op_name = opns.TUPLE_GET_ITEM
+
+    @property
+    def index(self) -> int:
+        return self.attrs['index']
+
+    def __init__(self, X, name=None, index:int = 0, extra_attrs=None):
+        super().__init__(X, name=name or N.n(), op_name=opns.TUPLE_GET_ITEM, extra_attrs=extra_attrs or {}, **{'index': index})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['index'], **kwargs)
+
+def tuple_get_item(*args, **kwargs):
+    return TupleGetItem(*args, **kwargs)
+
+
+class LeakyRelu(symbol.Symbol):
+    op_name = opns.LEAKY_RELU
+
+    @property
+    def negative_slope(self) -> float:
+        return self.attrs['negative_slope']
+
+    def __init__(self, X, name=None, negative_slope:float = 1e-2, extra_attrs=None):
+        super().__init__(X, name=name or N.n(), op_name=opns.LEAKY_RELU, extra_attrs=extra_attrs or {}, **{'negative_slope': negative_slope})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['negative_slope'], **kwargs)
+
+def leaky_relu(*args, **kwargs):
+    return LeakyRelu(*args, **kwargs)
+
+
+def dense(X, W, B, name=None, extra_attrs=None) -> symbol.Symbol:
+    return symbol.Symbol(*[X, W, B], name=name or N.n(), op_name=opns.DENSE, extra_attrs=extra_attrs or {})
+
+class Hardtanh(symbol.Symbol):
+    op_name = opns.HARDTANH
+
+    @property
+    def min_val(self) -> float:
+        return self.attrs['min_val']
+
+    @property
+    def max_val(self) -> 
float: + return self.attrs['max_val'] + + def __init__(self, X, name=None, min_val:float = -1.0, max_val:float = 1.0, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.HARDTANH, extra_attrs=extra_attrs or {}, **{'min_val': min_val, 'max_val':max_val}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['min_val', 'max_val'], **kwargs) + +def hard_tanh(*args, **kwargs): + return Hardtanh(*args, **kwargs) + +class AdaptiveAvgPool2D(symbol.Symbol): + op_name = opns.ADAPTIVE_AVG_POOL2D + + @property + def output_size(self) -> typing.Union[int, typing.Tuple[int, int]]: + assert 'output_size' in self.attrs + return self.attrs['output_size'] + + def __init__(self, X, name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=None, extra_attrs=None): + assert output_size != None + super().__init__(X, name=name or N.n(), op_name=opns.ADAPTIVE_AVG_POOL2D, extra_attrs=extra_attrs or {}, **{'output_size': output_size}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['output_size'], **kwargs) + +def adaptive_avg_pool2d(*args, **kwargs): + return AdaptiveAvgPool2D(*args, **kwargs) + +class AvgPool2D(symbol.Symbol): + op_name = opns.AVG_POOL2D + + @property + def pool_size(self) -> typing.Tuple[int, int]: + assert 'pool_size' in self.attrs + return self.attrs['pool_size'] + @property + def strides(self) -> typing.Tuple[int, int]: + return self.attrs['strides'] + @property + def dilation(self) -> typing.Tuple[int, int]: + return self.attrs['dilation'] + @property + def padding(self) -> typing.Tuple[int, int, int, int]: + return self.attrs['padding'] + @property + def ceil_mode(self) -> bool: + return self.attrs['ceil_mode'] + @property + def layout(self) -> str: + return self.attrs['layout'] + @property + def count_include_pad(self) -> bool: + return self.attrs['count_include_pad'] + + def __init__(self, X, name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): + assert pool_size != None + super().__init__(X, name=name or N.n(), op_name=opns.AVG_POOL2D, extra_attrs=extra_attrs or {}, **{'pool_size':pool_size, 'dilation':dilation, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout, 'count_include_pad':count_include_pad}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['pool_size', 'dilation', 'strides', 'padding', 'ceil_mode', 'layout', 'count_include_pad'], **kwargs) + +def avg_pool2d(*args, **kwargs): + return AvgPool2D(*args, **kwargs) + +class MaxPool2D(symbol.Symbol): + op_name = opns.MAX_POOL2D + + @property + def pool_size(self) -> typing.Tuple[int, int]: + assert 'pool_size' in self.attrs + return self.attrs['pool_size'] + @property + def strides(self) -> typing.Tuple[int, int]: + return self.attrs['strides'] + @property + def dilation(self) -> typing.Tuple[int, int]: + return self.attrs['dilation'] + @property + def padding(self) -> typing.Tuple[int, int, int, int]: + return self.attrs['padding'] + @property + def ceil_mode(self) -> bool: + return self.attrs['ceil_mode'] + @property + def layout(self) -> str: + return self.attrs['layout'] + + def __init__(self, X, name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', extra_attrs=None): + assert pool_size != None + super().__init__(X, name=name or N.n(), op_name=opns.MAX_POOL2D, extra_attrs=extra_attrs or {}, 
**{'pool_size':pool_size, 'dilation':dilation, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['pool_size', 'dilation', 'strides', 'padding', 'ceil_mode', 'layout'], **kwargs) + +def max_pool2d(*args, **kwargs): + return MaxPool2D(*args, **kwargs) + + +class Softmax(symbol.Symbol): + op_name = opns.SOFTMAX + + @property + def axis(self) -> typing.Optional[int]: + return self.attrs['axis'] + + def __init__(self, X, name=None, axis=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.SOFTMAX, extra_attrs=extra_attrs or {}, **{'axis':axis}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis'], **kwargs) + +def softmax(*args, **kwargs): + return Softmax(*args, **kwargs) + +class LogSoftmax(symbol.Symbol): + op_name = opns.LOG_SOFTMAX + + @property + def axis(self) -> typing.Optional[int]: + return self.attrs['axis'] + + def __init__(self, X, name=None, axis=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.LOG_SOFTMAX, extra_attrs=extra_attrs or {}, **{'axis':axis}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis'], **kwargs) + +def log_softmax(*args, **kwargs): + return LogSoftmax(*args, **kwargs) + + +def exp(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.EXP, extra_attrs=extra_attrs or {}) + +def sigmoid(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.SIGMOID, extra_attrs=extra_attrs or {}) + +class Sum(symbol.Symbol): + op_name = opns.SUM + + @property + def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: + return self.attrs['dim'] + + @property + def keepdim(self) -> typing.Optional[bool]: + return self.attrs['keepdim'] + + def __init__(self, X, name=None, dim=None, keepdim=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.SUM, extra_attrs=extra_attrs or {}, **{'dim': dim, 'keepdim': keepdim}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) + +def sum(*args, **kwargs): + return Sum(*args, **kwargs) + + +class Mean(symbol.Symbol): + op_name = opns.MEAN + + @property + def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: + return self.attrs['dim'] + + @property + def keepdim(self) -> typing.Optional[bool]: + return self.attrs['keepdim'] + + def __init__(self, X, name=None, dim=None, keepdim=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.MEAN, extra_attrs=extra_attrs or {}, **{'dim': dim, 'keepdim': keepdim}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) + +def mean(*args, **kwargs): + return Mean(*args, **kwargs) + + +class MaxAxis(symbol.Symbol): + op_name = opns.MAX_AXIS + + @property + def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: + return self.attrs['dim'] + + @property + def keepdim(self) -> typing.Optional[bool]: + return self.attrs['keepdim'] + + def __init__(self, X, name=None, dim=None, keepdim=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.MAX_AXIS, extra_attrs=extra_attrs or {}, **{'dim': dim, 'keepdim': keepdim}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim', 
'keepdim'], **kwargs) + +def max_axis(*args, **kwargs): + return MaxAxis(*args, **kwargs) + + +def maximum(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.MAXIMUM, extra_attrs=extra_attrs or {}) + +def minimum(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.MINIMUM, extra_attrs=extra_attrs or {}) + +#def repeat(X, name=None, extra_attrs=None) -> symbol.Symbol: +# return symbol.Symbol(X, name=name or N.n(), op_name=opns.REPEAT, extra_attrs=extra_attrs or {}) +class Repeat(symbol.Symbol): + op_name = opns.REPEAT + + @property + def repeats(self) -> typing.Optional[int]: + return self.attrs['repeats'] + + @property + def axis(self) -> typing.Optional[int]: + return self.attrs['axis'] + + def __init__(self, X, name=None, repeats=None, axis=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.REPEAT, extra_attrs=extra_attrs or {}, **{'repeats': repeats, 'axis': axis}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['repeats', 'axis'], **kwargs) +def repeat(*args, **kwargs): + return Repeat(*args, **kwargs) + +class Squeeze(symbol.Symbol): + op_name = opns.SQUEEZE + + @property + def dim(self) -> typing.Optional[int]: + return self.attrs['dim'] + + def __init__(self, X, name=None, dim=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.SQUEEZE, extra_attrs=extra_attrs or {}, **{'dim': dim}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim'], **kwargs) + +def squeeze(*args, **kwargs): + return Squeeze(*args, **kwargs) + +class Flatten(symbol.Symbol): + op_name = opns.FLATTEN + + @property + def start_dim(self) -> int: + return self.attrs['start_dim'] + + @property + def end_dim(self) -> int: + return self.attrs['end_dim'] + + def __init__(self, X, name=None, start_dim=0, end_dim=-1, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.FLATTEN, extra_attrs=extra_attrs or {}, **{'start_dim': start_dim, 'end_dim':end_dim}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['start_dim', 'end_dim'], **kwargs) + +def flatten(*args, **kwargs): + return Flatten(*args, **kwargs) + + +class Reshape(symbol.Symbol): + op_name = opns.RESHAPE + + @property + def newshape(self) -> typing.Tuple[int,...]: + assert 'newshape' in self.attrs + return self.attrs['newshape'] + + def __init__(self, X, name=None, newshape=None, extra_attrs=None): + assert newshape != None + super().__init__(X, name=name or N.n(), op_name=opns.RESHAPE, extra_attrs=extra_attrs or {}, **{'newshape': newshape}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['newshape'], **kwargs) + +def reshape(*args, **kwargs): + return Reshape(*args, **kwargs) + +class Concat(symbol.Symbol): + op_name = opns.CONCAT + + @property + def axis(self) -> int: + return self.attrs['axis'] + + def __init__(self, X, name=None, axis=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.CONCAT, extra_attrs=extra_attrs or {}, **{'axis': axis}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis'], **kwargs) + +def concat(*args, **kwargs): + return Concat(*args, **kwargs) + +class Split(symbol.Symbol): + op_name = opns.SPLIT + + @property + def split_size(self) -> typing.List[int]: + assert 'split_size' in self.attrs + 
return self.attrs['split_size'] + + @property + def dim(self) -> int: + return self.attrs['dim'] + + def __init__(self, X, name=None, split_size=None, dim=0, extra_attrs=None): + assert split_size != None + super().__init__(X, name=name or N.n(), op_name=opns.SPLIT, extra_attrs=extra_attrs or {}, **{'split_size': split_size, 'dim': dim}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['split_size', 'dim'], **kwargs) + +def split(*args, **kwargs): + return Split(*args, **kwargs) + + +class Transpose(symbol.Symbol): + op_name = opns.TRANSPOSE + + @property + def dim0(self) -> int: + assert 'dim0' in self.attrs + return self.attrs['dim0'] + + @property + def dim1(self) -> int: + assert 'dim1' in self.attrs + return self.attrs['dim1'] + + def __init__(self, X, name=None, dim0=None, dim1=None, extra_attrs=None): + assert dim0 != None + assert dim1 != None + super().__init__(X, name=name or N.n(), op_name=opns.TRANSPOSE, extra_attrs=extra_attrs or {}, **{'dim0': dim0, 'dim1': dim1}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim0', 'dim1'], **kwargs) + +def transpose(*args, **kwargs): + return Transpose(*args, **kwargs) + + +class BroadcastTo(symbol.Symbol): + op_name = opns.BROADCAST_TO + + @property + def newshape(self) -> typing.Tuple[int,...]: + assert 'newshape' in self.attrs + return self.attrs['newshape'] + + def __init__(self, X, name=None, newshape=None, extra_attrs=None): + assert newshape != None + super().__init__(X, name=name or N.n(), op_name=opns.BROADCAST_TO, extra_attrs=extra_attrs or {}, **{'newshape': newshape}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['newshape'], **kwargs) + +def broadcast_to(*args, **kwargs): + return BroadcastTo(*args, **kwargs) + +class ExpandDims(symbol.Symbol): + op_name = opns.EXPAND_DIMS + + @property + def newshape(self) -> typing.Tuple[int,...]: + assert 'newshape' in self.attrs + return self.attrs['newshape'] + + def __init__(self, X, name=None, newshape=None, extra_attrs=None): + assert newshape != None + super().__init__(X, name=name or N.n(), op_name=opns.EXPAND_DIMS, extra_attrs=extra_attrs or {}, **{'newshape': newshape}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['newshape'], **kwargs) + +def expand_dims(*args, **kwargs): + return ExpandDims(*args, **kwargs) + +class Tile(symbol.Symbol): + op_name = opns.TILE + + @property + def dims(self) -> typing.Tuple[int,...]: + assert 'dims' in self.attrs + return self.attrs['dims'] + + def __init__(self, X, name=None, dims=None, extra_attrs=None): + assert dims != None + super().__init__(X, name=name or N.n(), op_name=opns.TILE, extra_attrs=extra_attrs or {}, **{'dims': dims}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dims'], **kwargs) + +def tile(*args, **kwargs): + return Tile(*args, **kwargs) + + +def where(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.WHERE, extra_attrs=extra_attrs or {}) + +def greater(X, Y, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(*[X,Y], name=name or N.n(), op_name=opns.GREATER, extra_attrs=extra_attrs or {}) + +class NonMaxSuppression(symbol.Symbol): + op_name = opns.NON_MAX_SUPRESSION + + @property + def iou_threshold(self) -> float: + return self.attrs['iou_threshold'] + @property + def score_threshold(self) -> typing.Optional[float]: + 
return self.attrs['score_threshold']
+
+    def __init__(self, X, name=None, iou_threshold=0.5, score_threshold=None, extra_attrs=None):
+        super().__init__(X, name=name or N.n(), op_name=opns.NON_MAX_SUPRESSION, extra_attrs=extra_attrs or {}, **{'iou_threshold': iou_threshold,'score_threshold':score_threshold})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['iou_threshold', 'score_threshold'], **kwargs)
+
+def non_max_suppression(*args, **kwargs):
+    return NonMaxSuppression(*args, **kwargs)
+
+
+def ceil(X, name=None, extra_attrs=None) -> symbol.Symbol:
+    return symbol.Symbol(X, name=name or N.n(), op_name=opns.CEIL, extra_attrs=extra_attrs or {})
+
+def right_shift(X, Y, name=None, extra_attrs=None) -> symbol.Symbol:
+    return symbol.Symbol(*[X, Y], name=name or N.n(), op_name=opns.RIGHT_SHIFT, extra_attrs=extra_attrs or {})
+
+class Add(symbol.Symbol):
+    op_name = opns.ADD
+
+    @property
+    def alpha(self) -> int:
+        return self.attrs['alpha']
+
+    def __init__(self, X, Y, name=None, alpha=1, extra_attrs=None):
+        super().__init__(*[X, Y], name=name or N.n(), op_name=opns.ADD, extra_attrs=extra_attrs or {}, **{'alpha': alpha})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['alpha'], **kwargs)
+
+def add(*args, **kwargs):
+    return Add(*args, **kwargs)
+
+class Sub(symbol.Symbol):
+    op_name = opns.SUB
+
+    @property
+    def alpha(self) -> int:
+        return self.attrs['alpha']
+
+    def __init__(self, X, Y, name=None, alpha=1, extra_attrs=None):
+        super().__init__(*[X, Y], name=name or N.n(), op_name=opns.SUB, extra_attrs=extra_attrs or {}, **{'alpha': alpha})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['alpha'], **kwargs)
+
+def sub(*args, **kwargs):
+    return Sub(*args, **kwargs)
+
+def mul(X, Y, name=None, extra_attrs=None) -> symbol.Symbol:
+    return symbol.Symbol(*[X, Y], name=name or N.n(), op_name=opns.MUL, extra_attrs=extra_attrs or {})
+
+def mat_mul(X, Y, name=None, extra_attrs=None) -> symbol.Symbol:
+    return symbol.Symbol(*[X, Y], name=name or N.n(), op_name=opns.MATMUL, extra_attrs=extra_attrs or {})
+
+class Div(symbol.Symbol):
+    op_name = opns.DIV
+
+    @property
+    def rounding_mode(self) -> typing.Optional[str]:
+        return self.attrs['rounding_mode']
+
+    def __init__(self, X, Y, name=None, rounding_mode=None, extra_attrs=None):
+        super().__init__(*[X, Y], name=name or N.n(), op_name=opns.DIV, extra_attrs=extra_attrs or {}, **{'rounding_mode': rounding_mode})
+
+    @classmethod
+    def from_dict(cls, d: dict, **kwargs):
+        return _from_dict_attrs(cls, d, ['rounding_mode'], **kwargs)
+
+def div(*args, **kwargs):
+    return Div(*args, **kwargs)
+
+def negative(X, name=None, extra_attrs=None) -> symbol.Symbol:
+    return symbol.Symbol(X, name=name or N.n(), op_name=opns.NEGATIVE, extra_attrs=extra_attrs or {})
+
+def abs(X, name=None, extra_attrs=None) -> symbol.Symbol:
+    return symbol.Symbol(X, name=name or N.n(), op_name=opns.ABS, extra_attrs=extra_attrs or {})
+
+def log(X, name=None, extra_attrs=None) -> symbol.Symbol:
+    return symbol.Symbol(X, name=name or N.n(), op_name=opns.LOG, extra_attrs=extra_attrs or {})
+
+def sqrt(X, name=None, extra_attrs=None) -> symbol.Symbol:
+    return symbol.Symbol(X, name=name or N.n(), op_name=opns.SQRT, extra_attrs=extra_attrs or {})
+
+def pow(X, Y, name=None, extra_attrs=None) -> symbol.Symbol:
+    return symbol.Symbol(*[X, Y], name=name or N.n(), op_name=opns.POW, extra_attrs=extra_attrs or {})
+
+def identity(X, name=None, extra_attrs=None) -> symbol.Symbol:
+
return symbol.Symbol(X, name=name or N.n(), op_name=opns.IDENTITY, extra_attrs=extra_attrs or {}) + +class Arange(symbol.Symbol): + op_name = opns.ARANGE + + @property + def end(self) -> int: + assert 'end' in self.attrs + return self.attrs['end'] + + @property + def start(self) -> int: + return self.attrs['start'] + + @property + def step(self) -> int: + return self.attrs['step'] + + def __init__(self, name=None, end=None, start=0, step=1, extra_attrs=None): + assert end != None + super().__init__(name=name or N.n(), op_name=opns.ARANGE, extra_attrs=extra_attrs or {}, **{'end': end, 'start': start, 'step': step}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['end', 'start', 'step'], **kwargs) + +def arange(*args, **kwargs): + return Arange(*args, **kwargs) + + +def zeros_like(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.ZEROS_LIKE, extra_attrs=extra_attrs or {}) + +def ones_like(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.ONES_LIKE, extra_attrs=extra_attrs or {}) + + +_register_op_map(opns.VAR)(var) +_register_op_map(opns.RELU)(relu) + +_register_op_map(opns.CONV2D)(Conv2D) +_register_op_map(opns.DROP_OUT)(Dropout) +_register_op_map(opns.CLIP)(Clip) +_register_op_map(opns.BATCH_NORM)(BatchNorm) +_register_op_map(opns.TUPLE_GET_ITEM)(TupleGetItem) + +_register_op_map(opns.LEAKY_RELU)(LeakyRelu) + +_register_op_map(opns.MUL)(mul) +_register_op_map(opns.DENSE)(dense) +_register_op_map(opns.HARDTANH)(Hardtanh) +_register_op_map(opns.SILU)(silu) +_register_op_map(opns.ADAPTIVE_AVG_POOL2D)(AdaptiveAvgPool2D) +_register_op_map(opns.AVG_POOL2D)(AvgPool2D) +_register_op_map(opns.MAX_POOL2D)(MaxPool2D) +_register_op_map(opns.SOFTMAX)(Softmax) +_register_op_map(opns.LOG_SOFTMAX)(LogSoftmax) +_register_op_map(opns.EXP)(exp) +_register_op_map(opns.SIGMOID)(sigmoid) +_register_op_map(opns.SUM)(Sum) +_register_op_map(opns.MEAN)(Mean) +_register_op_map(opns.MAX_AXIS)(MaxAxis) +_register_op_map(opns.MAXIMUM)(maximum) +_register_op_map(opns.MINIMUM)(minimum) + + +_register_op_map(opns.REPEAT)(repeat) +_register_op_map(opns.SQUEEZE)(Squeeze) +_register_op_map(opns.FLATTEN)(Flatten) +_register_op_map(opns.RESHAPE)(Reshape) +_register_op_map(opns.CONCAT)(Concat) +_register_op_map(opns.SPLIT)(Split) +_register_op_map(opns.TRANSPOSE)(Transpose) +_register_op_map(opns.BROADCAST_TO)(BroadcastTo) +_register_op_map(opns.EXPAND_DIMS)(ExpandDims) +_register_op_map(opns.TILE)(Tile) +_register_op_map(opns.WHERE)(where) +_register_op_map(opns.GREATER)(greater) +_register_op_map(opns.NON_MAX_SUPRESSION)(NonMaxSuppression) + +_register_op_map(opns.CEIL)(ceil) +_register_op_map(opns.RIGHT_SHIFT)(right_shift) + +_register_op_map(opns.ADD)(Add) +_register_op_map(opns.SUB)(Sub) +_register_op_map(opns.MATMUL)(mat_mul) +_register_op_map(opns.DIV)(Div) +_register_op_map(opns.NEGATIVE)(negative) +_register_op_map(opns.ABS)(abs) +_register_op_map(opns.LOG)(log) +_register_op_map(opns.SQRT)(sqrt) +_register_op_map(opns.POW)(pow) +_register_op_map(opns.IDENTITY)(identity) +_register_op_map(opns.ARANGE)(Arange) +_register_op_map(opns.ZEROS_LIKE)(zeros_like) +_register_op_map(opns.ONES_LIKE)(ones_like) + + +# Add default register Class for MRT OP Not Implemented! 
+_register_op_map(opns.TUPLE)(extern_opfunc(opns.TUPLE)) +_register_op_map(opns.AS_TYPE)(extern_opfunc(opns.AS_TYPE)) +_register_op_map(opns.ADV_INDEX)(extern_opfunc(opns.ADV_INDEX)) +_register_op_map(opns.CALL_TIR)(extern_opfunc(opns.CALL_TIR)) +_register_op_map(opns.CALL_DPS_PACKED)(extern_opfunc(opns.CALL_DPS_PACKED)) + +_register_op_map(opns.IF)(extern_opfunc(opns.IF)) +_register_op_map(opns.ARGWHERE)(extern_opfunc(opns.ARGWHERE)) +_register_op_map(opns.REQUANT)(extern_opfunc(opns.REQUANT)) +_register_op_map(opns.PCLIP)(extern_opfunc(opns.PCLIP)) +_register_op_map(opns.RS_PCLIP)(extern_opfunc(opns.RS_PCLIP)) +_register_op_map(opns.LUT)(extern_opfunc(opns.LUT)) + +_register_op_map(opns.BATCH_FLATTEN)(extern_opfunc(opns.BATCH_FLATTEN)) +_register_op_map(opns.STRIDED_SLICE)(extern_opfunc(opns.STRIDED_SLICE)) +_register_op_map(opns.SLICE_LIKE)(extern_opfunc(opns.SLICE_LIKE)) +_register_op_map(opns.GET_VALID_COUNT)(extern_opfunc(opns.GET_VALID_COUNT)) diff --git a/python/mrt/mir/opns.py b/python/mrt/mir/opns.py index 5b92822..cec6427 100644 --- a/python/mrt/mir/opns.py +++ b/python/mrt/mir/opns.py @@ -80,7 +80,7 @@ SQRT = "sqrt" POW = "pow" -PASS = "pass" +IDENTITY = "identity" # original PASS # ======= auto generate op ========= ARANGE = "arange" ZEROS_LIKE = "zeros_like" @@ -99,3 +99,6 @@ LUT = "mrt.lut" """ look up table, equals adv_index in tvm """ + +def Opname2Funcname(op_name: str) -> str: + return op_name.replace('.', '_') diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index 5c97cee..cadb48e 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -11,8 +11,7 @@ # from . import config # from .utils import * -# from .types import * -from .opns import * +from . import opns __ALL__ = [ "Symbol", @@ -20,6 +19,8 @@ "filter_operators", ] +SelfSymbol = typing.TypeVar("SelfSymbol", bound="Symbol") + def _format_printer(data): if isinstance(data, dict): data = ["{}={}".format(k, _format_printer(v)) \ @@ -111,8 +112,12 @@ def like(self, other: Symbol, **kwargs) -> Symbol: """ cast current symbol to child class. """ # assert self.shape == other.shape, "%s vs.\n %s" % (self, other) # assert self.dtype == other.dtype , "%s vs.\n %s" % (self, other) + assert isinstance(other, Symbol) data = other.to_dict() - data.update(self.to_dict()) + data_new = self.to_dict() + data.update(data_new) + + data["extra_attrs"] = other.extra_attrs if self.extra_attrs == {} else data["extra_attrs"] # copy extra attrs by default. # data["extra_attrs"] = other.extra_attrs return type(other).from_dict(data, **kwargs) @@ -187,7 +192,6 @@ def _uniform(n: str, max_size: int) -> str: _format_printer(oattrs)) -@dataclass class Symbol(_BaseSymbol): """ Uniform Symbol Representation for RelayExpr @@ -204,6 +208,15 @@ class Symbol(_BaseSymbol): for the user's config about quantization layers. """ + def __init__(self, *args, name:str=None, op_name:str=None, extra_attrs:dict=None, **attrs): + assert name != None + assert op_name != None + self.name = name + self.op_name = op_name + self.args = [arg for arg in args] + self.attrs = attrs + self.extra_attrs = extra_attrs or {} + # Overridable Methods, inheritted from _BaseSymbol # to support multi-inherit design. @classmethod @@ -215,12 +228,43 @@ def set_extra_attrs(self, **kwargs): def base(cls, symbol: Symbol, **kwargs) -> Symbol: return super().base(symbol, **kwargs) def like(self, other: Symbol, **kwargs) -> Symbol: - return super().like(other, **kwargs) + """ cast current symbol to child class. 
""" + assert isinstance(other, Symbol) + data = other.to_dict() + data_new = self.to_dict() + data.update(data_new) + data["extra_attrs"] = other.extra_attrs if self.extra_attrs == {} else data["extra_attrs"] + # copy extra attrs by default. + # data["extra_attrs"] = other.extra_attrs + #return type(other).from_dict(data, **kwargs) + return Symbol.from_dict(data, **kwargs) + def as_variable(self, **kwargs) -> Symbol: + sym = Symbol.from_dict(self.to_dict(), **kwargs) # kwargs override self + sym.op_name = opns.VAR + sym.args = [] + sym.attrs = {} + return sym def copy(self, **kwargs) -> Symbol: return super().copy(**kwargs) @classmethod def from_dict(cls, d: dict, **kwargs): - return super().from_dict(d, **kwargs) + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + fnames = [f.name for f in fields(cls)] + data = {k: data[k] for k in data if k in fnames} + args = data['args'] or [] + attrs = data['attrs'] or {} + try: + out = cls(*args, name=data['name'], op_name=data['op_name'], extra_attrs=data['extra_attrs'], **attrs) + except Exception as _: + raise RuntimeError(( + "Error for type:{} create from dict, " + "expected: {}, but get {}" + ).format(get_class_name(cls), + fnames, data.keys())) + return out @classmethod def default_dict(cls, **kwargs) -> dict: kwargs.setdefault("args", []) @@ -277,34 +321,6 @@ def __hash__(self) -> int: def hash(self) -> int: return hash(str(self)) -# class Convolution2D(Symbol): -# strides: typing.Tuple[int, int] - -# class Dropout(Symbol): -# eps: float = 1e-5 - -# class Pass: -# symbol: Symbol - -# def visit(self, op: Symbol): -# env: typing.Dict[Symbol, Symbol] = {} -# for sym in sym2list(self.symbol): -# out = getattr(self, f"visit_{op.op_name}")(op) or op -# assert isinstance(sym, Symbol) -# env[sym] = out -# return env[op] - -# def _default_visit_op(op): -# return op - -# for op in op_list: -# setattr(Pass, f"visit_{op.op_name}", _default_visit_op) - -# class FuseDropoutPass(Pass): -# def visit_dropout(self, op: Dropout): -# op.eps -# return op.args[0] - def _topo_sort(symbol: Symbol, sym_list: typing.List[Symbol]): assert isinstance(symbol, Symbol), \ f"({type(symbol).__name__}){str(symbol)}" @@ -349,6 +365,7 @@ def load_json(data: _SymbolJsonT, **extra_attrs) -> Symbol: _VisitorT = typing.Callable[[Symbol], None] _TransformerT = typing.Callable[[Symbol], typing.Optional[Symbol]] +_TransformerParamT = typing.Callable[[Symbol, typing.Optional[ParametersT]], Symbol] """ Symbol Transformer Return new symbol to transform old symbol into updated one, @@ -371,6 +388,7 @@ def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: Only the return value indicates mutation, while changing attributes in parameter passed in args does nothing. 
""" + assert isinstance(symbol.args, list), f"Symbol_Args_Wrong_type: {type(symbol.args)}" sym_map: typing.Dict = {} C = config.LogConfig.G() for sym in sym2list(symbol): @@ -382,7 +400,11 @@ def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: if callback.__name__ in C.log_vot_cbs: config.log(callback.__name__, f"<< {sym}") - out = callback(sym) or sym + # Skipping transform output symbol in trace-Exporter + if callback.__name__.find("Exporter")>=0 and sym.name == symbol.name: + out = sym + else: + out = callback(sym) or sym assert isinstance(out, Symbol), out # default const_ prefix symbol means parameters assert sym.name not in sym_map, sym.name @@ -474,27 +496,6 @@ def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: # name: str = "main") -> MultiHeadSymbol: # return MultiHeadSymbol(**{ name: symbol }) -class MultiHeadSymbol(dict): - """ { "main": F(X) } """ - origin: typing.Optional[Symbol] = None - - @classmethod - def from_symbol(cls, symbol: Symbol, name: str = "main"): - return MultiHeadSymbol({ name: symbol }) - - def as_tuple(self) -> typing.Tuple[typing.List[str], Symbol]: - from . import op - # args = list(self.values()) - # sym_type = type(args[0]) if args else Symbol - mhs = self.origin or op.Tuple(*list(self.values())) - return list(self.keys()), mhs - - @classmethod - def from_tuple(cls, tuple_names, symbol): - assert symbol.is_op(TUPLE), symbol - mhs = cls(zip(tuple_names, symbol.args)) - mhs.origin = symbol - return mhs # MultiHeadSymbol = typing.Dict[str, Symbol] @@ -545,11 +546,6 @@ def from_tuple(cls, tuple_names, symbol): # return {k: load_json(v) for k, v in data} # ^^^^^^^^^^^^^^^ MultiHeadSymbol API ^^^^^^^^^^^^^^^^^^ - -Graph = typing.Union[Symbol, MultiHeadSymbol] -""" Notice that Symbol and MultiHeadSymbol can both - be regarded as a model Graph. -""" # def graph_visit(graph: Graph, callback: _VisitorT): # return visit(graph, callback) # # visit_func = visit if isinstance(graph, Symbol) else mhs_visit diff --git a/python/mrt/mir/symbol_pass.py b/python/mrt/mir/symbol_pass.py new file mode 100644 index 0000000..49b349d --- /dev/null +++ b/python/mrt/mir/symbol_pass.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +import typing +from functools import wraps +from dataclasses import field + +import numpy as np + +from mrt.mir.symbol import * +from mrt.mir.mhsymbol import Graph + +from mrt.mir import op, opns, opclass +from mrt.mir.attrs import _BaseAttrs, parse_attrs + +from mrt.common.utils import N + + +class SymbolBridge: # SymbolManipulator / Pass + graph: Symbol + + def __init__(self, symbol: Symbol): + self.graph = symbol + + @classmethod + def base(cls, symbol: Symbol): + return cls(symbol) + + def __repr__(self, **attrs): + return self.graph.__repr__(**attrs) + + def from_symbol(self, sym: Symbol) -> typing.Self: + return type(self)(sym) + + @property + def parsed(self)-> _BaseAttrs: + return parse_attrs(self.graph.op_name, self.graph.attrs) + return self.graph.attrs + + """Member Symbol Start + """ + def is_op(self, *op_names) -> bool: + """ Check current symbol is in the op name list. 
""" + assert len(op_names) > 0 + return self.graph.op_name in op_names + def is_near(self, *names, check_args: bool = True) -> bool: + return self.graph.is_near(*names, check_args) + def to_dict(self) -> dict: + return self.graph.to_dict() + @classmethod + def from_dict(cls, d: dict, **kwargs) -> SymbolParameters: + return cls(Symbol.from_dict(d, **kwargs), {}) + @property + def args(self) -> list: + return self.graph.args + @property + def op_name(self) -> str: + return self.graph.op_name + @property + def name(self) -> str: + return self.graph.name + @property + def shape(self) -> typing.Optional[ShapeT]: + return self.graph.shape + @property + def dtype(self) -> str: + return self.graph.dtype + @property + def attrs(self) -> dict: + return self.graph.attrs + @property + def extra_attrs(self) -> dict: + return self.graph.extra_attrs + def set_extra_attrs(self, **kwargs): + return self.graph.extra_attrs.update(kwargs) + """Member Symbol End + """ + +class SymbolParameters(SymbolBridge): + graph: Symbol + params: ParametersT = field(repr=False) + """ Parameters should not be changed in transformer, + use copy mode instead to avoid possible errors. + + deep copy params in trace `checkpoint_run` api. + """ + + def __init__(self, symbol: Symbol, params: ParametersT): + self.graph = symbol + self.params = params + + @classmethod + def base(cls, symbol: Symbol, params: ParametersT): + return cls(symbol, params) + + def __repr__(self, **attrs): + if self.is_param(): + attrs["absmax"] = np.abs(self.numpy()).max(initial=0) + return super().__repr__(**attrs) + + @property + def parsed(self)-> _BaseAttrs: + return parse_attrs(self.graph.op_name, self.graph.attrs) + attrs = self.graph.attrs + return attrs + + + def numpy(self) -> OpNumpyT: + assert self.is_param(), f"{self.graph.name} is not parameter." + data = self.params[self.graph.name] + assert isinstance(data, (tuple, list, np.ndarray)), \ + f"param:{self.graph.name} not OpNumpyT, get {type(data)}" + return data + + def as_parameter(self, data: OpNumpyT) -> Symbol: + def _f(data, dtype): + if isinstance(data, list): + assert len(data) == len(dtype) + return [_f(d, t) for d, t in zip(data, dtype)] + assert isinstance(data, np.ndarray), type(data) + return data.astype(dtype) + + self.params[self.graph.name] = _f(data, self.graph.dtype) + return op.as_variable(self.graph) + + def from_const_data(self, data: typing.Union[int, float]) -> Symbol: + return self.from_np_data(data) + + def from_symbol(self, sym: Symbol) -> typing.Type[SymbolParameters]: + return type(self)(sym, self.params) + + def from_np_data(self, data: np.ndarray | typing.Union[int, float], prefix="%") -> Symbol: + """ out = Return Symbol + out = op.add(out, B) + self: SymbolParameter + self.graph: Symbol + self.from_symbol(out).from_np_data() + + out = Return Symbol + out.from_np_data() + + op.add(out.graph, B) + + graph: Symbol + """ + name = N.n(prefix=prefix) + # some data is np.float/int type, use np.array to wrap it. + data = np.array(data) + self.params[name] = data.astype(self.graph.dtype) + ## return type(self). # Mark! 
+ return opclass.var(name, data.shape, self.graph.dtype).like(self.graph) + + def is_input(self) -> bool: + return op.is_input(self.graph, self.params) + def is_param(self) -> bool: + return op.is_param(self.graph, self.params) + def is_variable(self) -> bool: + return op.is_variable(self.graph, self.params) + def is_operator(self) -> bool: + return op.is_operator(self.graph, self.params) + +SymTransformerT = typing.Callable[[Graph], Graph] +""" Symbol-Transformer Callback Function Type, + inherited from SymbolParameters. +""" + +class SymbolTransformer(SymbolParameters): + """ Symbol Transformer(Manipulator) """ + + RUN_ONCE: typing.ClassVar[bool] = False + + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + + @classmethod + def get_transformer(cls, name: typing.Optional[str] = None): + name = name or cls.__name__ + def _func(graph: Symbol, params: ParametersT, **kwargs): + def _run(sym: Symbol): + # use current cls to apply transform, this + # may loss some information from origin + # symbol, so record as `origin` in call. + out = cls.base(sym, params) # Type as SymbolTransformer + out = out(origin=sym, **kwargs) or sym # Type as Symbol + assert isinstance(out, Symbol), ( + "transform output type should be {}," + " but get {}" + ).format(cls, type(out)) + return out + _run.__name__ = name + with N(name): + return _run(graph) if cls.RUN_ONCE \ + else transform(graph, _run) + _func.__name__ = name + return _func + + # @classmethod + # def apply(cls, *args, **kw): + # """ Static apply function to generator transformer pass. + + # All the parameters are used to invoke `call` method. + # """ + # def _tfm(sym: Symbol, params: ParametersT): + # ins = cls.base(sym, params=params) + # out = ins(*args, **kw) or ins + # assert isinstance(out, cls), ( + # "expected {}, but get {}" + # ).format(cls, type(out)) + # return out + + # _tfm.__name__ = cls.__name__ + # return _tfm + + def __call__(self, *args, **kw) -> typing.Optional[SymbolTransformer]: + """ + Parameters: + origin: original symbol passed from last transformer. + """ + raise NotImplementedError() + +class RunOnce(SymbolTransformer): + RUN_ONCE: typing.ClassVar[bool] = True + + def __init__(self, *args): # symbol: Symbol, params: ParametersT):#, parsed: _BaseAttrs=None): + super().__init__(*args) + diff --git a/python/mrt/quantization/calibrate.py b/python/mrt/quantization/calibrate.py index 0cf8a0e..278e3c1 100644 --- a/python/mrt/quantization/calibrate.py +++ b/python/mrt/quantization/calibrate.py @@ -4,23 +4,35 @@ import numpy as np -from dataclasses import dataclass, field, InitVar +from dataclasses import field, InitVar from mrt.mir import op, opns from mrt.mir.symbol import * from mrt.runtime import inference -from .transform import Transformer +from mrt.mir.symbol_pass import SymbolTransformer SamplingFuncT = typing.Callable[ [typing.Union[OpNumpyT, float]], typing.Any] -@dataclass(repr=False) -class Calibrator(Transformer): - """ skip dump, and restore from np_data. 
""" - raw_data: OpOutputT | None = field(repr=False, default=None) - """ calibrate may be processed multi-times """ - data: typing.List[OpNumpyT] = field(default_factory=list) +class Calibrator(SymbolTransformer): + @property + def raw_data(self) -> OpOutputT | None: + return self.extra_attrs.get("raw_data", None) + @raw_data.setter + def raw_data(self, val): + self.set_extra_attrs(raw_data=val) + + @property + def data(self) -> typing.List[OpNumpyT]: + return self.extra_attrs.get("data", []) + @data.setter + def data(self, val): + self.set_extra_attrs(data=val) + + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) def _rand_data(self, enabled: bool = False, @@ -51,10 +63,10 @@ def __call__(self, elif self.is_param(): out = self.params[self.name] else: - single_op = op.retrieve_operator(self) + single_op = op.retrieve_operator(self.graph) out = inference.run_single( single_op, - [a.raw_data for a in self.args], + [self.from_symbol(a).raw_data for a in self.args], **kwargs) assert isinstance(out, (np.ndarray, list)), type(out) @@ -90,8 +102,7 @@ def _assert(self, val, expect): assert val == expect, "{} vs. {}".format(val, expect) -@dataclass(repr=False) -class Sampling(Transformer): +class Sampling(SymbolTransformer): @property def data(self) -> typing.Any: return self.extra_attrs.get("data", None) @@ -99,24 +110,31 @@ def data(self) -> typing.Any: def data(self, val): self.set_extra_attrs(data=val) + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + @classmethod def sampling(cls, np_data: np.ndarray) -> typing.Any: raise NotImplementedError() - def __call__(self, origin: Calibrator, **kw): + def __call__(self, origin: Symbol, **kw): print(type(origin), origin) if self.is_op(opns.CLIP): # TODO: remove clip if threshold is less than a_max a_min, a_max = self.parsed.a_min, self.parsed.a_max self.data = max(abs(a_min), abs(a_max)) else: - self.data = self.sampling(origin.data) - return self + self.data = self.sampling(origin.extra_attrs.get("raw_data")) + return self.graph -@dataclass(repr=False) class SymmetricMinMaxSampling(Sampling): threshold: typing.ClassVar[float] = 1e-5 + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + @classmethod def sampling(cls, data: typing.List[OpNumpyT]) -> float: if isinstance(data, list): diff --git a/python/mrt/quantization/discrete.py b/python/mrt/quantization/discrete.py index bcca5a2..a2402b9 100644 --- a/python/mrt/quantization/discrete.py +++ b/python/mrt/quantization/discrete.py @@ -4,7 +4,10 @@ import math from dataclasses import dataclass, field -from mrt.mir import op +from mrt.mir import op, opclass +from mrt.mir.optype import infer_single +from mrt.mir.opclass import MRT_OP_MAP + from mrt.mir.opns import * from mrt.mir.symbol import * @@ -14,7 +17,6 @@ from .scaler import * from .calibrate import Sampling -from .transform import Transformer from .precision import WithPrecision __ALL__ = [ @@ -33,9 +35,16 @@ def undefined(self) -> bool: return self.scale is None and self.precision is None -@dataclass(repr=False) class QuantInfo(WithScale, WithPrecision, Sampling): - requant_ops: typing.Dict[DiscreteInfo, Symbol] = field(repr=False) + requant_ops: typing.Dict[DiscreteInfo, Symbol] = {} #field(default_factory=dict) + + # inherit SymbolParameters __init__ + def __init__(self, *args): + self.requant_ops = {} + super().__init__(*args) + + def from_symbol(self, sym: Symbol) -> typing.Self: + return type(self)(sym, self.params) 
@classmethod def default_dict(cls, **kwargs) -> dict: @@ -62,7 +71,7 @@ def rescale(self, info: DiscreteInfo): """ scale, precision = info.scale, info.precision if info.undefined: - return self + return self.graph elif scale is not None: precision = self.scale_to_precision(scale) elif precision is not None: @@ -71,12 +80,11 @@ def rescale(self, info: DiscreteInfo): if info not in self.requant_ops: curr_scale = self.scale if self.scale_defined else 1 #TODO: add pass to check rescale=1 and duplicate requant - out = op.requant( - self, + out = infer_single(MRT_OP_MAP[REQUANT]( + self.graph, rescale=scale/curr_scale, - precision=precision, - ) - out = out.like(self) + precision=precision) + ).like(self.graph) out.set_extra_attrs( data=self.data, scale=scale, precision=precision) self.requant_ops[info] = out @@ -137,12 +145,13 @@ def _rule(s: QuantInfo): register_rules_with_default(SUM, requant_rule=args_max_prec(10)) def uniform_args_scale(args: typing.List[QuantInfo], + params: ParametersT = {}, std_prec: int =15): # standard max precision for add/sub children. assert len(args) > 0 # raw_print(s) - assert any([c.is_operator() for c in args]), \ + assert any([op.is_operator(c.graph, params) for c in args]), \ "Need fuse constant for uniform_args_scale" scales = [] for arg in args: @@ -173,27 +182,28 @@ def uniform_args_scale(args: typing.List[QuantInfo], # scale = min(scaleA, scaleB) # return [DiscreteInfo(scale=scale) for c in s.args] def scale_like_index(s: WithScale, index: int = 0): - return s.args[index].scale + return s.args[index].extra_attrs.get("scale", -1) + register_rules_with_default( ADD, SUB, # BIAS_ADD, MAXIMUM, MINIMUM, - requant_rule=lambda s: uniform_args_scale(s.args), + requant_rule=lambda s: uniform_args_scale([s.from_symbol(a) for a in s.args], s.params), scale_rule=scale_like_index) def scale_concat(s: WithScale): - fscale = s.args[0].scale - if all([a.scale == fscale for a in s.args]): + fscale = s.args[0].extra_attrs.get("scale", -1) + if all([a.extra_attrs.get("scale", -1) == fscale for a in s.args]): return fscale - return [a.scale for a in s.args] + return [a.extra_attrs.get("scale", -1) for a in s.args] register_rules_with_default( CONCAT, TUPLE, - requant_rule=lambda s: uniform_args_scale(s.args), + requant_rule=lambda s: uniform_args_scale([s.from_symbol(a) for a in s.args], s.params), scale_rule=scale_concat) def uniform_first_scale(s: QuantInfo): - target_scale = s.args[0].scale + target_scale = s.args[0].extra_attrs.get("scale", -1) return [DiscreteInfo(scale=target_scale) for c in s.args] register_rules_with_default( @@ -202,7 +212,7 @@ def uniform_first_scale(s: QuantInfo): # register_rules_with_default( # WHERE, -# requant_rule=lambda s: uniform_args_scale(s.args[1:]), +# requant_rule=lambda s: uniform_args_scale([s.from_symbol(a) for a in s.args[1:]], s.params), # scale_rule=scale_like_index(s, -1), # ) @@ -217,7 +227,7 @@ def uniform_first_scale(s: QuantInfo): register_rules_with_default(NEGATIVE) def scale_tuple_get_item(s: WithScale): - ascale = s.args[0].scale + ascale = s.args[0].extra_attrs.get("scale", -1) if isinstance(ascale, (list, tuple)): return ascale[s.parsed.index] return ascale @@ -226,7 +236,7 @@ def scale_tuple_get_item(s: WithScale): scale_rule=scale_tuple_get_item) def op_clip_rules(s: QuantInfo): - scale = s.args[0].scale + scale = s.args[0].extra_attrs.get("scale", -1) s.set_extra_attrs( a_min=s.parsed.a_min * scale, a_max=s.parsed.a_max * scale) @@ -249,21 +259,21 @@ def op_lut_rules(s: QuantInfo): X = s.args[0] offset = 
s.from_np_data(np.array(alpha, "int")) - indices = op.add(X, offset).like(X) - indices = op.clip(indices, a_min=0, a_max=2*alpha).like(X) #a_max=alpha+1) - indices = op.cast(indices, dtype="int32") + indices = infer_single(opclass.add(X, offset)).like(X) + indices = infer_single(opclass.clip(indices, a_min=0, a_max=2*alpha)).like(X) #a_max=alpha+1) + indices = infer_single(MRT_OP_MAP[AS_TYPE](indices, dtype="int32")) # arg_min, arg_max = -s.data, s.data # if s.is_op(EXP): # arg_max = min(math.log(s.data), arg_max) - op_inp = np.arange(-alpha, alpha+1) / s.args[0].scale + op_inp = np.arange(-alpha, alpha+1) / s.args[0].extra_attrs.get("scale", -1) table = inference.run(s, [ tvm.nd.array(op_inp), ]) table = np.clip(table.numpy(), a_min=-s.data, a_max=s.data) # table = np.reshape(table, (-1, 1)) oscale = s.precision_to_scale(LUT_OUT_PREC) weight = s.from_np_data(table * oscale) - out = op.adv_index(weight, indices).like(s) + out = infer_single(MRT_OP_MAP[ADV_INDEX](weight, indices)).like(s) # out.scale = s.precision_to_scale(LUT_INP_PREC) return out @@ -280,22 +290,22 @@ def softmax_scale_rules(s: QuantInfo): def op_softmax_rules(s: QuantInfo): lambd = 10 X = s.args[0] # get requant rule op - Xp = X.attrs["precision"] - Xs = X.scale #X.attrs["precision"] + Xp = X.extra_attrs["precision"] + Xs = X.extra_attrs["scale"] #X.attrs["precision"] axis = s.attrs["axis"] alpha = int(lambd * Xs) var = s.from_np_data(np.array(alpha, "int")) - max_axis = op.max_axis(X, axis = axis, keepdims=True) - offset = op.sub(max_axis, var) - offset = op.pclip(offset, precision=Xp) + max_axis = infer_single(opclass.max_axis(X, dim=axis, keepdim=True)) + offset = infer_single(opclass.sub(max_axis, var)) + offset = infer_single(MRT_OP_MAP[PCLIP](offset, precision=Xp)) offset.set_extra_attrs(precision=Xp) - norm = op.sub(X, offset) - norm = op.nn_relu(norm) - norm = op.pclip(norm, precision=Xp) + norm = infer_single(opclass.sub(X, offset)) + norm = infer_single(opclass.relu(norm)) + norm = infer_single(MRT_OP_MAP[PCLIP](norm, precision=Xp)) norm.set_extra_attrs(precision=Xp) # TODO: norm = op.cast(norm, dtype="int32") - norm = op.cast(norm, dtype="int32") + norm = infer_single(MRT_OP_MAP[AS_TYPE](norm, dtype="int32")) op_inp = np.arange(0, alpha+1) / Xs table = np.exp(op_inp) @@ -304,21 +314,21 @@ def op_softmax_rules(s: QuantInfo): weight = np.round(table) # weight = np.transpose(weight, (1, 0)) weight = s.from_np_data(weight) - out_lut = op.adv_index(weight, norm).like(s) - sum_lut = op.sum(out_lut, axis=axis, keepdims=True).like(out_lut) + out_lut = infer_single(MRT_OP_MAP[ADV_INDEX](weight, norm)).like(s) + sum_lut = infer_single(opclass.sum(out_lut, dim=axis, keepdim=True)).like(out_lut) oprec = min(SOFTMAX_PREC, 31 - tprec) oscale = bits_to_number(oprec) nd_oscale = s.from_np_data(np.array(oscale, "int")) - prob = op.mul(out_lut, nd_oscale) + prob = infer_single(opclass.mul(out_lut, nd_oscale)) - half_lut = op.rs_pclip(sum_lut, s.from_const_data(1), precision=31) + half_lut = infer_single(MRT_OP_MAP[RS_PCLIP](sum_lut, s.from_const_data(1), precision=31)) half_lut.set_extra_attrs(precision=31) - prob = op.add(prob, half_lut) - out = op.div(prob, sum_lut) - out = op.cast(out, dtype="int32") - out = op.cast(out, dtype="float32") - out = op.pclip(out, precision=oprec) + prob = infer_single(opclass.add(prob, half_lut)) + out = infer_single(opclass.div(prob, sum_lut)) + out = infer_single(MRT_OP_MAP[AS_TYPE](out, dtype="int32")) + out = infer_single(MRT_OP_MAP[AS_TYPE](out, dtype="float32")) + out = 
infer_single(MRT_OP_MAP[PCLIP](out, precision=oprec)) out.set_extra_attrs(scale=oscale, precision=oprec) return out @@ -330,7 +340,6 @@ def op_softmax_rules(s: QuantInfo): scale_rule=softmax_scale_rules ) -@dataclass(repr=False) class Discretor(QuantInfo): """ does operation -> out @@ -357,11 +366,15 @@ class Discretor(QuantInfo): output precision <- precision(target) output scale <- scale """ + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + def __call__(self, **kw): if self.is_variable(): - return + return self.graph elif self.is_op(TUPLE): - return + return self.graph orig_names = [a.name for a in self.args] @@ -378,14 +391,14 @@ def __call__(self, **kw): # requant input to specific precision arg_dts = _DISCRETE_REQUANT_RULES[self.op_name](self) for i, arg in enumerate(self.args): - self.args[i] = arg.rescale(arg_dts[i]) + self.args[i] = self.from_symbol(arg).rescale(arg_dts[i]) # calculate the F function - out = _DISCRETE_OP_RULES[self.op_name](self).like( - self, extra_attrs=self.extra_attrs) + out = _DISCRETE_OP_RULES[self.op_name](self.graph).like( + self.graph, extra_attrs=self.extra_attrs) # calculate the output data's scale - out.scale = INFER_SCALE_RULES[self.op_name](out) + out.set_extra_attrs(scale = INFER_SCALE_RULES[self.op_name](out)) new = op.subgraph(out, inames=[a.name for a in self.args]) # self.is_op(EXP) and raw_print(new) # out.scale = infer_scale(new) @@ -397,7 +410,8 @@ def __call__(self, **kw): # out = op.pclip(out, precision=target_precision).like( # out, extra_attrs=out.extra_attrs) # out.precision = target_precision - out.precision = self.scale_to_precision(out.scale) + out.set_extra_attrs(precision = self.scale_to_precision(out.extra_attrs.get("scale", -1))) + # TODO: add skip for some operators # same_scale = all([a.scale == out.scale for a in self.args]) diff --git a/python/mrt/quantization/fixed_point.py b/python/mrt/quantization/fixed_point.py index 087798a..6dd07df 100644 --- a/python/mrt/quantization/fixed_point.py +++ b/python/mrt/quantization/fixed_point.py @@ -4,7 +4,9 @@ import numpy as np from dataclasses import dataclass -from mrt.mir import op +from mrt.mir import op, opclass +from mrt.mir.optype import infer_single +from mrt.mir.opclass import MRT_OP_MAP from mrt.mir.opns import * from mrt.mir.symbol import filter_operators from mrt.mir.attrs import PClipAttrs, RequantAttrs @@ -15,7 +17,6 @@ from mrt.common.config import _BaseConfig from mrt.common.utils import number_to_bits -from .transform import Transformer logger = logging.getLogger("exporter") @@ -51,8 +52,11 @@ class ExporterConfig(_BaseConfig): use_int_requant=True, use_int_dtype=True) -@dataclass(repr=False) class Exporter(QuantInfo): + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + def map_int_requant(self): """ requant(X, rescale) = X * rescale @@ -65,7 +69,7 @@ def map_int_requant(self): precision, which follows precision max bit limit. 
""" - X: FixPoint = self.args[0] + X: FixPoint = self.from_symbol(self.args[0]) rescale = self.parsed.rescale anno_bit = WithPrecision.MAX_BIT // 2 @@ -82,29 +86,30 @@ def map_int_requant(self): if X.precision > anno_bit: # recalculate exp + exp = exp + (X.precision - anno_bit) rs_bit = X.from_const_data(X.precision - anno_bit) - X = op.right_shift(X, rs_bit).like(self) + X_op = infer_single(opclass.right_shift(X.graph, rs_bit)).like(self.graph) + X = self.from_symbol(X_op) X.precision = anno_bit assert frac >= 1 assert exp <= 0 frac_sym = X.from_const_data(frac) - out = op.mul(X, frac_sym).like(self) + out = infer_single(opclass.mul(X.graph, frac_sym)).like(self.graph) - exp_sym = out.from_const_data(-exp) + exp_sym = self.from_symbol(out).from_const_data(-exp) if ExporterConfig.G().use_clip: if ExporterConfig.G().use_pclip: - out = op.rs_pclip(out, exp_sym, - precision=self.precision) + out = infer_single(MRT_OP_MAP[RS_PCLIP](out, exp_sym, precision=self.precision)) else: pos = self.int_max() - out = op.right_shift(out, exp_sym).like(self) - out = op.clip(out, min=-pos, max=pos).like(self) + out = infer_single(opclass.right_shift(out, exp_sym)).like(self.graph) + out = infer_single(opclass.clip(out, min=-pos, max=pos)).like(self.graph) else: - out = op.right_shift(out, exp_sym).like(self) + out = infer_single(opclass.right_shift(out, exp_sym)).like(self.graph) return out def process(self): @@ -114,7 +119,7 @@ def process(self): if G.use_int_dtype: G.use_round = True - out = self + out = self.graph if self.is_param() and G.use_round: data = np.round(self.numpy()) if G.use_int_dtype: @@ -123,60 +128,64 @@ def process(self): pos = self.int_max() if self.is_op(REQUANT): - if G.use_int_requant and (not self.args[0].is_input()): + if G.use_int_requant and (not self.from_symbol(self.args[0]).is_input()): out = self.map_int_requant() else: # use float multipy to map requant rescale = self.parsed.rescale rescale = self.from_const_data(rescale) - out = op.mul(self.args[0], rescale) + out = infer_single(opclass.mul(self.args[0], rescale)) if G.use_clip: - out = op.clip(out, min=-pos, max=pos) + out = infer_single(opclass.clip(out, min=-pos, max=pos)) if not G.use_int_dtype and G.use_round: orig_dtype = out.dtype - out = op.cast(out, dtype="int32") - out = op.cast(out, dtype=orig_dtype) + out = infer_single(MRT_OP_MAP[AS_TYPE](out, dtype="int32")) + out = infer_single(MRT_OP_MAP[AS_TYPE](out, dtype=orig_dtype)) if not G.use_clip: if self.is_op(PCLIP): out = self.args[0] elif self.is_op(RS_PCLIP): - out = op.right_shift(*self.args) + out = infer_single(opclass.right_shift(*self.args)) elif not G.use_pclip: if self.is_op(PCLIP): out = self.args[0] elif self.is_op(RS_PCLIP): - out = op.right_shift(*self.args) - out = op.clip(out, min=-pos, max=pos) + out = infer_single(opclass.right_shift(*self.args)) + out = infer_single(opclass.clip(out, min=-pos, max=pos)) return out def __call__(self, **kw): if not self.precision_defined: logger.warning(f"symbol: {self.name} is ignored without precision defined.") - return self + return self.graph self.validate_precision() - out = self.process().like(self, extra_attrs=self.extra_attrs) + out = self.process().like(self.graph, extra_attrs=self.extra_attrs) # TODO: add precision int max validate # if self.is_param(): # absmax = np.abs(out.numpy()).max() # assert absmax - 0.01 <= out.int_max() return out -@dataclass(repr=False) +# TODO: deprecated? 
class Simulator(QuantInfo): - def round(self, out: Transformer): + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + + def round(self, out: Symbol): # data_0_5 = self.from_const_data(0.5) # out = op.add(out, data_0_5) # out = op.ceil(out) orig_dtype = out.dtype - out = op.cast(out, dtype="int32") - out = op.cast(out, dtype=orig_dtype) + out = infer_single(MRT_OP_MAP[AS_TYPE](out, dtype="int32")) + out = infer_single(MRT_OP_MAP[AS_TYPE](out, dtype=orig_dtype)) return out def __call__(self, with_clip=False, with_round=False, **kw): - out: Transformer = self + out: Symbol = self.graph if self.is_input(): """ input is the original float data, skip. """ return out @@ -189,20 +198,24 @@ def __call__(self, with_clip=False, with_round=False, **kw): if self.is_op(REQUANT): rescale = self.parsed.rescale rescale = self.from_const_data(rescale) - out = op.mul(out, rescale) + out = infer_single(opclass.mul(out, rescale)) if with_round: out = self.round(out) if with_clip: pos = self.int_max() # relax api from a_min/a_max to min/max - out = op.clip(out, min=-pos, max=pos) + out = infer_single(opclass.clip(out, min=-pos, max=pos)) # print(out) # sys.exit() - return out.like(self) + return out.like(self.graph) -@dataclass(repr=False) +# TODO: deprecated? class FixPoint(QuantInfo): + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + def map_requant(self) -> FixPoint: if (self.args[0]).is_input(): return self @@ -213,18 +226,17 @@ def map_requant(self) -> FixPoint: anno_bit = WithPrecision.MAX_BIT // 2 if X.precision > anno_bit: rs_bit = X.from_const_data(X.precision - anno_bit) - X = op.right_shift(X, rs_bit).like(self) + X = infer_single(opclass.right_shift(X, rs_bit).like(self)) X.precision = anno_bit frac, exp = cvm_float(self.parsed.rescale, anno_bit) assert frac >= 1 assert exp <= 0 frac_sym = X.from_const_data(frac) - out = op.mul(X, frac_sym).like(self) + out = infer_single(opclass.mul(X, frac_sym)).like(self) exp_sym = out.from_const_data(-exp) - out = op.rs_pclip(out, exp_sym, - precision=self.precision) + out = infer_single(MRT_OP_MAP[RS_PCLIP](out, exp_sym, precision=self.precision)) # pos = self.int_max() # out = op.right_shift(out, exp_sym).like(self) # out = op.clip(out, a_min=-pos, a_max=pos).like(self) @@ -235,7 +247,7 @@ def map_pclip(self) -> FixPoint: X: FixPoint = self.args[0] pos = self.int_max() out = X - out = op.pclip(X, precision=self.precision).like(self) + out = infer_single(MRT_OP_MAP[PCLIP](X, precision=self.precision)).like(self) # out = op.clip(X, a_min=-pos, a_max=pos).like(self) return out diff --git a/python/mrt/quantization/fuse.py b/python/mrt/quantization/fuse.py index 31360a5..b664522 100644 --- a/python/mrt/quantization/fuse.py +++ b/python/mrt/quantization/fuse.py @@ -3,7 +3,7 @@ import numpy as np -from mrt.mir import op +from mrt.mir import opclass, optype from mrt.mir.opns import * from mrt.mir.symbol import * from mrt.mir.attrs import * @@ -11,61 +11,66 @@ from mrt.runtime import inference from mrt.common.utils import N, product -from .transform import Transformer +from mrt.mir.symbol_pass import SymbolTransformer -# TODO: add op pass register map. 
-class FuseDropout(Transformer): +class FuseDropout(SymbolTransformer): #out = filter_operators(DROP_OUT)(__call__) # def out(): @filter_operators(DROP_OUT) def __call__(self, **kwargs): return self.args[0] -class FuseConstant(Transformer): +class FuseIdentity(SymbolTransformer): + @filter_operators(IDENTITY) + def __call__(self, **kwargs): + return self.args[0] + +class FuseConstant(SymbolTransformer): threshold: typing.ClassVar[float] = 1e-5 def np_is_zero(self, data) -> float: return np.abs(data).max() < self.threshold - def __call__(self: Transformer, **kw): - if self.is_operator() and all([c.is_param() for c in self.args]): + def __call__(self: SymbolTransformer, **kw): + if self.is_operator() and all([self.from_symbol(c).is_param() for c in self.args]): data = inference.run_single( - self, [a.numpy() for a in self.args]) + self.graph, [self.from_symbol(a).numpy() for a in self.args]) return self.as_parameter(data) elif self.is_op(ADD, SUB): # , BIAS_ADD): strips = [] for arg in self.args: - if arg.is_param() and self.np_is_zero(arg.numpy()): + if self.from_symbol(arg).is_param() and self.np_is_zero(self.from_symbol(arg).numpy()): # if arg.is_param() and np.abs(arg.numpy()).max() == 0: strips.append(arg) args = [a for a in self.args if a not in strips] if len(args) == 1: return args[0] elif self.is_op(SLICE_LIKE): - if not self.args[0].is_param(): + if not self.from_symbol(self.args[0]).is_param(): return a, b = self.args arg1 = np.zeros(b.shape, b.dtype) data = inference.run_single( - self, [a.numpy(), np.zeros(b.shape, b.dtype)]) + self.graph, [self.from_symbol(a).numpy(), np.zeros(b.shape, b.dtype)]) return self.as_parameter(data) elif self.is_op(REQUANT): if self.parsed.rescale == 1: return self.args[0] elif self.is_op(ZEROS_LIKE, ONES_LIKE): - data = inference.run_single(self, []) + data = inference.run_single(self.graph, []) return self.as_parameter(data) -class FuseBatchNorm(Transformer): +class FuseBatchNorm(SymbolTransformer): @filter_operators(BATCH_NORM) def __call__(self, **kw): X, gamma, beta, mean, var = self.args + X = self.from_symbol(X) parsed: BatchNormAttrs = self.parsed - gamma, beta = gamma.numpy(), beta.numpy() - mean, var = mean.numpy(), var.numpy() + gamma, beta = self.from_symbol(gamma).numpy(), self.from_symbol(beta).numpy() + mean, var = self.from_symbol(mean).numpy(), self.from_symbol(var).numpy() # print(gamma.shape, beta.shape, mean.shape, var.shape) assert parsed.axis == 1 @@ -90,42 +95,42 @@ def __call__(self, **kw): # (A * W) * gamma + bias # A * (W * gamma) + bias - W_data = W.numpy() * gamma.reshape(K, 1, 1, 1) - W_sym = W.from_np_data(W_data) - out = op.nn_conv2d(A, W_sym, **X.attrs) + W_data = self.from_symbol(W).numpy() * gamma.reshape(K, 1, 1, 1) + W_sym = self.from_symbol(W).from_np_data(W_data) + out = optype.infer_single(opclass.conv2d(A, W_sym, **X.attrs)) elif X.is_op(DENSE): A, W = X.args dense_parsed: DenseAttrs = X.parsed # (A * W) * gamma + bias # A * (W * gamma) + bias - W_data = W.numpy() * gamma.reshape(K, 1) - W_sym = W.from_np_data(W_data) - out = op.nn_dense(A, W_sym, **X.attrs) + W_data = self.from_symbol(W).numpy() * gamma.reshape(K, 1) + W_sym = self.from_symbol(W).from_np_data(W_data) + out = optype.infer_single(opclass.dense(A, W_sym, **X.attrs)) else: reshp = [s if i == parsed.axis else 1 \ for i, s in enumerate(X.shape)] W = X.from_np_data(gamma.reshape(reshp)) - out = op.mul(X, W) + out = optype.infer_single(opclass.mul(X.graph, W)) bias = bias.reshape([s if i == parsed.axis else 1 \ for i, s in enumerate(out.shape)]) - B = 
out.like(self).from_np_data(bias) - out = op.add(out, B) - # out = op.bias_add(out, B, axis=parsed.axis) - return out.like(self) + B = self.from_symbol(out.like(self.graph)).from_np_data(bias) + out = opclass.add(out, B) + out = optype.infer_single(out) + return out.like(self.graph) -class FuseTupleGetItem(Transformer): +class FuseTupleGetItem(SymbolTransformer): @filter_operators(TUPLE_GET_ITEM) def __call__(self, **kw): X: Symbol = self.args[0] - if X.is_op(BATCH_NORM, DROP_OUT): + if self.from_symbol(X).is_op(BATCH_NORM, DROP_OUT): return X # assert X.is_op(BATCH_NORM, DROP_OUT), X.name # assert self.parsed.index == 0 # return X -class FuseAvgPool2D(Transformer): +class FuseAvgPool2D(SymbolTransformer): def __call__(self, **kw): out = self._fuse_adaptive_avg_pool2d() out = out or self._fuse_avg_pool2d() @@ -133,7 +138,7 @@ def __call__(self, **kw): @filter_operators(AVG_POOL2D) def _fuse_avg_pool2d(self): - X: Transformer = self.args[0] + X: Symbol = self.args[0] parsed: AvgPool2DAttrs = self.parsed assert parsed.layout == "NCHW" # TODO: ignore for unstrict mode @@ -148,15 +153,15 @@ def _fuse_avg_pool2d(self): "channels": X.shape[1], } W_shape = (X.shape[1], 1, *parsed.pool_size) - W = X.from_np_data(np.full( + W = self.from_symbol(X).from_np_data(np.full( W_shape, 1 / product(parsed.pool_size))) - out = op.nn_conv2d(X, W, **attrs) - return out.like(self) + out = optype.infer_single(opclass.conv2d(X, W, **attrs)) + return out.like(self.graph) @filter_operators(ADAPTIVE_AVG_POOL2D) def _fuse_adaptive_avg_pool2d(self): - X: Transformer = self.args[0] + X: Symbol = self.args[0] parsed: AdaptiveAvgPool2DAttrs = self.parsed assert parsed.layout == "NCHW" ins = X.shape[2:] @@ -168,14 +173,15 @@ def _fuse_adaptive_avg_pool2d(self): assert len(X.shape) == 4 if all([s == 1 for s in parsed.output_size]): scale = np.array(1 / np.prod(X.shape[-2:])) - out = op.sum(X, axis=list(range(4))[-2:], keepdims=True) + out = optype.infer_single(opclass.sum(X, dim=list(range(4))[-2:], keepdim=True)) scale = self.from_np_data(scale.astype(X.dtype)) - return op.mul(out, scale).like(self) + out = optype.infer_single(opclass.mul(out, scale)) + return out.like(self.graph) elif ous[0] > ins[0] or ous[1] > ins[1]: assert all([s == 1 for s in ins]) - out = op.repeat(X, repeats=ous[0], axis=-2) - out = op.repeat(out, repeats=ous[1], axis=-1) - return out.like(self) + out = optype.infer_single(opclass.repeat(X, repeats=ous[0], axis=-2)) + out = optype.infer_single(opclass.repeat(out, repeats=ous[1], axis=-1)) + return out.like(self.graph) # calculate the attributes refers to: # https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work @@ -191,23 +197,23 @@ def _fuse_adaptive_avg_pool2d(self): "channels": X.shape[1], } W_shape = (X.shape[1], 1, *kernel) - W = X.from_np_data(np.full(W_shape, 1 / product(kernel))) - out = op.nn_conv2d(X, W, **attrs) - return out.like(self) + W = self.from_symbol(X).from_np_data(np.full(W_shape, 1 / product(kernel))) + out = optype.infer_single(opclass.conv2d(X, W, **attrs)) + return out.like(self.graph) -class FuseNaiveSoftmax(Transformer): +class FuseNaiveSoftmax(SymbolTransformer): def __call__(self, **kw): - return self # not fuse pass + return self.graph # not fuse pass if self.is_op(SOFTMAX, LOG_SOFTMAX): return self.args[0] - assert self.is_variable() or not self.args[0].is_op(SOFTMAX, LOG_SOFTMAX) - return self + assert self.is_variable() or not self.from_symbol(self.args[0]).is_op(SOFTMAX, LOG_SOFTMAX) + return self.graph -class FuseMean(Transformer): 
+class FuseMean(SymbolTransformer): @filter_operators(MEAN) def __call__(self, **kw): - X: Transformer = self.args[0] + X: Symbol = self.args[0] # max_axis = len(X.shape) # axis = X.attrs.get("axis", None) # axis = axis or [i for i in range(max_axis)] @@ -217,13 +223,13 @@ def __call__(self, **kw): # axis = [a for a in range(max_axis) if a not in axis] # axis_len = product([X.shape[a] for a in axis]) - out = op.sum(X, **self.attrs) + out = optype.infer_single(opclass.sum(X, **self.attrs)) scale = self.from_np_data(np.array( 1. * product(out.shape) / product(X.shape))) - out = op.mul(out, scale) - return out.like(self) + out = optype.infer_single(opclass.mul(out, scale)) + return out.like(self.graph) -class FuseLeakyReLU(Transformer): +class FuseLeakyReLU(SymbolTransformer): @filter_operators(LEAKY_RELU) def __call__(self, **kw): """ Customized rewrite pass Introduction. @@ -234,25 +240,27 @@ def __call__(self, **kw): LeakyReLU(X) = relu(X) - slope*relu(-X) """ alpha = self.from_const_data(self.parsed.alpha) - X: Transformer = self.args[0] - out = op.nn_relu(op.negative(X)) - out = op.mul(alpha, out) - out = op.sub(op.nn_relu(X), out) - return out.like(self) + X: Symbol = self.args[0] + out = optype.infer_single(opclass.negative(X)) + out = optype.infer_single(opclass.relu(out)) + out = optype.infer_single(opclass.mul(alpha, out)) + out = optype.infer_single(opclass.sub(optype.infer_single(opclass.relu(X)), out)) + return out.like(self.graph) -class FuseDivide(Transformer): +class FuseDivide(SymbolTransformer): @filter_operators(DIV) def __call__(self, **kw): """ Transform div to mul if possible. """ - A: Transformer = self.args[0] - B: Transformer = self.args[1] - assert B.is_param(), B - B = B.from_np_data(1. / B.numpy()) - return op.mul(A, B).like(self) + A: Symbol = self.args[0] + B: Symbol = self.args[1] + assert self.from_symbol(B).is_param(), B + B = self.from_symbol(B).from_np_data(1. 
/ self.from_symbol(B).numpy()) + out = optype.infer_single(opclass.mul(A, B)) + return out.like(self.graph) # move to fuse constant -# class FuseNaiveMathmatic(Transformer): +# class FuseNaiveMathmatic(SymbolTransformer): # def __call__(self): # if self.is_op(BIAS_ADD): # X, B = self.args diff --git a/python/mrt/quantization/precision.py b/python/mrt/quantization/precision.py index b933b6a..0cbfb47 100644 --- a/python/mrt/quantization/precision.py +++ b/python/mrt/quantization/precision.py @@ -1,12 +1,11 @@ from __future__ import annotations import typing -from dataclasses import dataclass import math import numpy as np -from mrt.mir import op +from mrt.mir import op, optype, opclass, opns from mrt.mir.opns import * from mrt.mir.symbol import Symbol, visit, transform @@ -14,16 +13,19 @@ number_to_bits, count_to_bits, bits_to_number from mrt.common.types import ParametersT -from .transform import Transformer +from mrt.mir.symbol_pass import SymbolBridge, SymbolTransformer __ALL__ = [ "WithPrecision", "InferPrecision", "QuantizedInfo", ] -@dataclass(repr=False) -class WithPrecision(Symbol): +class WithPrecision(SymbolBridge): MAX_BIT: typing.ClassVar[int] = 32 + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + @classmethod def _validate_precision(cls, prec, msg=None): assert isinstance(prec, int), self.precision @@ -106,13 +108,13 @@ def _add_rules(f: RulesFuncT): return f return _add_rules -_infer_mul: RulesFuncT = lambda s: sum([c.precision for c in s.args[:2]]) +_infer_mul: RulesFuncT = lambda s: sum([c.extra_attrs.get("precision", -1) for c in s.args[:2]]) """ conv2d may has 3-args, use prefix-2. """ -_infer_max: RulesFuncT = lambda s: max([c.precision for c in s.args]) +_infer_max: RulesFuncT = lambda s: max([c.extra_attrs.get("precision", -1) for c in s.args]) def _infer_index(s: WithPrecision, index: int): - return s.args[index].precision + return s.args[index].extra_attrs.get("precision", -1) prec_rules(TUPLE)(_infer_max) prec_rules(MAX_AXIS)(_infer_max) @@ -166,19 +168,22 @@ def _infer_right_shift(s: WithPrecision): A, B = s.args[0], s.args[1] assert B.is_param() b_prec = InferPrecision.bind(B) - return A.precision - b_prec + return A.extra_attrs.get("precision", -1) - b_prec @prec_rules(REQUANT, PCLIP, RS_PCLIP) def _infer_attr_prec(s: WithPrecision): assert s.parsed.precision == s.precision return s.parsed.precision -@dataclass(repr=False) -class PrecisionRevisor(WithPrecision, Transformer): +class PrecisionRevisor(WithPrecision, SymbolTransformer): + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + def __call__(self, **kw): out = self if out.is_input(): - return + return out.graph elif out.is_op(REQUANT, PCLIP): assert out.precision == out.parsed.precision, f"{out.name} out_prec:{out.precision}, out_parsed_prec:{out.parsed.precision}" elif out.is_op(RS_PCLIP): @@ -202,12 +207,13 @@ def __call__(self, **kw): # print("infered prec:", oprec) if out.precision_defined and oprec > out.precision: out.precision, oprec = oprec, out.precision - out = op.pclip(out, precision=oprec).like( - out, extra_attrs=out.extra_attrs) + out = out.from_symbol(optype.infer_single(opclass.MRT_OP_MAP[opns.PCLIP]( + out.graph, precision=oprec)).like( + out.graph, extra_attrs=out.extra_attrs)) out.precision = oprec out.validate_precision() - return out + return out.graph # def cvm_infer_single_precision( # symbol: WithPrecision, params: ParametersT) -> int: diff --git a/python/mrt/quantization/scaler.py 
b/python/mrt/quantization/scaler.py index 8cd058d..0dac575 100644 --- a/python/mrt/quantization/scaler.py +++ b/python/mrt/quantization/scaler.py @@ -7,8 +7,13 @@ from mrt.mir.opns import * from mrt.mir.symbol import * -@dataclass(repr=False) -class WithScale(Symbol): +from mrt.mir.symbol_pass import SymbolBridge + +class WithScale(SymbolBridge): + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + @classmethod def _validate_scale(cls, scale, msg=None): if isinstance(scale, (list, tuple)): @@ -55,13 +60,13 @@ def _add_rules(f: ScaleRulesT): return _add_rules def scale_index(s: WithScale, index: int): - return s.args[index].scale + return s.args[index].extra_attrs.get("scale", -1) def scale_nn(s: WithScale): - return s.args[0].scale * s.args[1].scale + return s.args[0].extra_attrs.get("scale", -1) * s.args[1].extra_attrs.get("scale", -1) def scale_identity(s: WithScale): - return s.args[0].scale + return s.args[0].extra_attrs.get("scale", -1) def infer_scale(symbol: WithScale): def _infer(sym: Symbol): diff --git a/python/mrt/quantization/segement.py b/python/mrt/quantization/segement.py index 38d3f5b..f1b3c31 100644 --- a/python/mrt/quantization/segement.py +++ b/python/mrt/quantization/segement.py @@ -3,10 +3,10 @@ from dataclasses import dataclass, field from mrt.mir.symbol import * -from mrt.mir import op, opns, helper +from mrt.mir import op, opns, helper, optype, opclass from .scaler import WithScale -from .transform import RunOnce +from mrt.mir.symbol_pass import RunOnce _SCALE_CONSTANT_OPS = [ opns.VAR, @@ -23,27 +23,30 @@ opns.CLIP, opns.AS_TYPE, ] -@dataclass(repr=False) class Spliter(RunOnce): head: typing.Optional[dict] = None head_params: typing.Optional[typing.Dict[str, OpNumpyT]] = None seg_names: typing.List[str] = field(default_factory=list) + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + def __call__(self, **kwargs): """ Auto split the model. """ refs = { self.name: 1 } # add refs for root symbol def _collect_refs(sym: Spliter): refs.setdefault(sym.name, 0) - if sym.is_variable(): + if self.from_symbol(sym).is_variable(): return for a in sym.args: refs.setdefault(a.name, 0) refs[a.name] += 1 - visit(self, _collect_refs) + visit(self.graph, _collect_refs) sym_map = {} sym_status = {} - heads = [self] + heads = [self.graph] """ status code: 1 means current symbol has been scaned and sub childs have been added into scan list. 
@@ -102,7 +105,7 @@ def _collect_refs(sym: Spliter): def _split(sym: Spliter): return op.as_variable(sym) \ if sym.name in self.seg_names else sym - head = transform(self, _split) + head = transform(self.graph, _split) self.head = dump_json(head) self.head_params = {} @@ -114,21 +117,34 @@ def _update_params(sym: Symbol): # helper.format_print(head, self.head_params) - return op.Tuple(*outs).like(self) + # export to symbol_op Spliter_%N + out = optype.infer_single(opclass.MRT_OP_MAP[opns.TUPLE](*outs)).like(self.graph) + out.set_extra_attrs(seg_names=self.seg_names) + out.set_extra_attrs(head=self.head) + out.set_extra_attrs(head_params=self.head_params) + return out -@dataclass(repr=False) class Merger(WithScale, RunOnce): - def __call__(self, spliter: Spliter, **kw): + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + + def __call__(self, spliter: Symbol, **kwargs): assert self.op_name == opns.TUPLE - tail_outs = dict(zip(spliter.seg_names, self.args)) + + head = kwargs['ptr']["head"] + head_params = kwargs['ptr']["head_params"] + seg_names = kwargs['ptr']["seg_names"] + + tail_outs = dict(zip(seg_names, self.args)) # print(spliter.seg_names) - assert spliter.head is not None + assert head is not None head_params = {k: to_ndarray(v) \ - for k, v in spliter.head_params.items()} + for k, v in head_params.items()} # head_params.update(self.params) - head = load_json(spliter.head, params=head_params) + head = load_json(head, params=head_params) # helper.format_print(head, head_params) @@ -139,6 +155,7 @@ def _merge(sym: Symbol): return sym out = transform(head, _merge) - return out.like(self, params={ **head_params, **self.params }) + self.params = { **head_params, **self.params } + return out.like(self.graph) diff --git a/python/mrt/quantization/transform.py b/python/mrt/quantization/transform.py deleted file mode 100644 index 8fa7e47..0000000 --- a/python/mrt/quantization/transform.py +++ /dev/null @@ -1,150 +0,0 @@ -from __future__ import annotations - -import typing -from functools import wraps -from dataclasses import dataclass, field - -import numpy as np - -from mrt.mir.symbol import * - -from mrt.mir import op, opns -from mrt.mir.attrs import _BaseAttrs, parse_attrs - -from mrt.common.utils import N - -@dataclass(repr=False) -class WithParameters(Symbol): - parsed: _BaseAttrs = field(repr=False) - params: ParametersT = field(repr=False) - """ Parameters should not be changed in transformer, - use copy mode instead to avoid possible errors. - - deep copy params in trace `checkpoint_run` api. - """ - - @classmethod - def update_dict(cls, data_dict: dict, **kwargs) -> dict: - data_dict.update(kwargs) - parsed = parse_attrs( - data_dict["op_name"], data_dict["attrs"]) - return super().update_dict(data_dict, parsed=parsed) - - def __repr__(self, **attrs): - if self.is_param(): - attrs["absmax"] = np.abs(self.numpy()).max(initial=0) - return super().__repr__(**attrs) - - def ndarray(self) -> OpOutputT: - return to_ndarray(self.numpy()) - - def numpy(self) -> OpNumpyT: - assert self.is_param(), f"{self.name} is not parameter." 
- data = self.params[self.name] - assert isinstance(data, (tuple, list, np.ndarray)), \ - f"param:{self.name} not OpNumpyT, get {type(data)}" - return data - - return to_numpy(self.ndarray()) - - def as_parameter(self, data: OpNumpyT): - def _f(data, dtype): - if isinstance(data, list): - assert len(data) == len(dtype) - return [_f(d, t) for d, t in zip(data, dtype)] - assert isinstance(data, np.ndarray), type(data) - return data.astype(dtype) - - self.params[self.name] = _f(data, self.dtype) - return op.as_variable(self) - - def from_const_data(self, data: typing.Union[int, float]) -> WithParameters: - return self.from_np_data(data) - - def from_np_data(self, data: np.ndarray, prefix="%") -> Symbol: - name = N.n(prefix=prefix) - # some data is np.float/int type, use np.array to wrap it. - data = np.array(data) - self.params[name] = data.astype(self.dtype) - return op.variable(name, data.shape, self.dtype).like(self) - - def is_input(self) -> bool: - return op.is_input(self, self.params) - def is_param(self) -> bool: - return op.is_param(self, self.params) - def is_variable(self) -> bool: - return op.is_variable(self, self.params) - def is_operator(self) -> bool: - return op.is_operator(self, self.params) - -TransformerT = typing.Callable[[Graph], Graph] -""" Transformer Callback Function Type, - inherited from WithParameters. -""" - -@dataclass(repr=False) -class Transformer(WithParameters): - """ Symbol Transformer """ - - RUN_ONCE: typing.ClassVar[bool] =False - """ whether to run callback once? """ - - # def to_dict(self, **kwargs): - # """ override to dict, since transformer may want to - # access the previous tfm. Thus, the next - # update_dict has the `origin` key by default. - # """ - # data = super().to_dict(**kwargs) - # data["extra_attrs"]["origin"] = self - # return data - - @classmethod - def get_transformer(cls, name: typing.Optional[str] = None): - name = name or cls.__name__ - def _func(graph: Symbol, params: ParametersT, **kwargs): - def _run(sym: Symbol): - # use current cls to apply transform, this - # may loss some information from origin - # symbol, so record as `origin` in call. - out = cls.base(sym, params=params) - out = out(origin=sym, **kwargs) or out - assert isinstance(out, cls), ( - "transform output type should be {}," - " but get {}" - ).format(cls, type(out)) - return out - _run.__name__ = name - with N(name): - return _run(graph) if cls.RUN_ONCE \ - else transform(graph, _run) - _func.__name__ = name - return _func - - # @classmethod - # def apply(cls, *args, **kw): - # """ Static apply function to generator transformer pass. - - # All the parameters are used to invoke `call` method. - # """ - # def _tfm(sym: Symbol, params: ParametersT): - # ins = cls.base(sym, params=params) - # out = ins(*args, **kw) or ins - # assert isinstance(out, cls), ( - # "expected {}, but get {}" - # ).format(cls, type(out)) - # return out - - # _tfm.__name__ = cls.__name__ - # return _tfm - - def __call__(self, *args, **kw) -> typing.Optional[Transformer]: - """ - Parameters: - origin: original symbol passed from last transformer. 
- """ - raise NotImplementedError() - -@dataclass(repr=False) -class RunOnce(Transformer): - RUN_ONCE: typing.ClassVar[bool] = True - diff --git a/python/mrt/runtime/inference.py b/python/mrt/runtime/inference.py index f520018..e2c76fd 100644 --- a/python/mrt/runtime/inference.py +++ b/python/mrt/runtime/inference.py @@ -2,24 +2,25 @@ import numpy as np from mrt.mir.symbol import * +from mrt.mir.mhsymbol import MultiHeadSymbol from mrt.mir.opns import * from mrt import frontend as ft -from mrt.quantization.transform import WithParameters +from mrt.mir.symbol_pass import SymbolParameters from mrt.mir import op def run_single( - sym: WithParameters, + sym: SymbolParameters, args_data: typing.List[OpNumpyT], **kwargs) -> OpNumpyT: assert op.is_operator(sym), sym sym = op.retrieve_operator(sym) if sym.is_op(TUPLE_GET_ITEM): - return args_data[0][sym.parsed.index] + return args_data[0][sym.attrs['index']] elif sym.is_op(REQUANT): # it's type is np.float32/64, use np.array to wrap. - return np.array(sym.parsed.rescale * args_data[0]) + return np.array(sym.attrs['rescale'] * args_data[0]) elif sym.is_op(ARANGE): args = [a.numpy().item() for a in args_data] return np.arange(*args, **sym.attrs) diff --git a/tests/mir/test.infer_pass.py b/tests/mir/test.infer_pass.py new file mode 100644 index 0000000..3d94e93 --- /dev/null +++ b/tests/mir/test.infer_pass.py @@ -0,0 +1,103 @@ +""" +Test script for MRT InferPass +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_resnet18_model(): + """Get Resnet18 MRT Model""" + + # Load pre-trained ResNet18 + model = models.resnet18(weights='IMAGENET1K_V1') + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Model to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + + +def test_InferPass_FuseBatchNorm(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + print('\n=== Before FuseBatchNorm Pass ===') + symlist = sx.sym2list(symbol) + return True + + +def test_InferPass_FuseAdaptiveAvgPool2D(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + print('\n=== Before FuseAdaptiveAvgPool2D Pass ===') + symlist = sx.sym2list(symbol) + return True + + +def test_InferPass_FuseTupleGetItem(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseTuple Pass ===') + symlist = sx.sym2list(symbol) + #for x in symlist: + #print(x) + + op_cnt = 0 + for sym in symlist: + op_cnt += 1 if sym.op_name == opns.TUPLE_GET_ITEM else 0 + assert op_cnt > 0, f'ori model TupleGetItem op cnt {op_cnt} == zero!' 
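+    # The pass below is expected to fold away tuple_get_item nodes that merely
+    # index into multi-output ops (e.g. batch_norm), so for this model the
+    # post-pass count should drop to zero (asserted after the pass).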
+ + # init Passer and execute visit + tfs : simple_pass.FuseTupleGetItemPass = simple_pass.FuseTupleGetItemPass(symbol, mrt_params) + symbol_passed = tfs.custom_visits_with_params(tfs.get_run()) + + print('\n=== After FuseTuple Pass ===') + rlts = sx.sym2list(symbol_passed) + op_cnt_af = 0 + for sym in rlts: + # print(sym) + op_cnt_af += 1 if sym.op_name == opns.TUPLE_GET_ITEM else 0 + assert op_cnt_af==0, f'passed model op cnt {op_cnt_af} != zero!' + + return True + + +if __name__ == "__main__": + print("=== Testing InferPass ===") + mrt_graph, mrt_params = _get_resnet18_model() + + test_id = 0 + passed_cnt = 0 + test_funcs = [test_InferPass_FuseBatchNorm, test_InferPass_FuseAdaptiveAvgPool2D, test_InferPass_FuseTupleGetItem] + for func_ in test_funcs: + rltflag = func_(mrt_graph, mrt_params) + test_id += 1 + passed_cnt += rltflag + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') + print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') + diff --git a/tests/mir/test.infer_pass_div.py b/tests/mir/test.infer_pass_div.py new file mode 100644 index 0000000..547363c --- /dev/null +++ b/tests/mir/test.infer_pass_div.py @@ -0,0 +1,88 @@ +""" +Test script for MRT InferPass +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_fasterrcnn_resnet50_fpn_model(): + """Get Fasterrcnn_resnet50_fpn MRT Model""" + + # Load pre-trained model + model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Model to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + + +def test_InferPass_FuseDivide(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseDivide Pass ===') + symlist = sx.sym2list(symbol) + + divide_op_cnt = 0 + for sym in symlist: + divide_op_cnt += 1 if sym.op_name == opns.DIV else 0 + assert divide_op_cnt > 0, f'ori model divide op cnt {divide_op_cnt} == zero!' + + # init FuseDivide Passer and execute visit + tfs : simple_pass.FuseDividePass = simple_pass.FuseDividePass(symbol, mrt_params) + symbol_passed = tfs.custom_visits_with_params(tfs.get_run()) + + print('\n=== After FuseDivide Pass ===') + rlts = sx.sym2list(symbol_passed) + divide_op_cnt_af = 0 + for sym in rlts: + # print(sym) + divide_op_cnt_af += 1 if sym.op_name == opns.DIV else 0 + assert divide_op_cnt_af==0, f'passed model divide op cnt {divide_op_cnt_af} != zero!' 
+ + return True + + +if __name__ == "__main__": + print("=== Testing InferPass Divide ===") + mrt_graph, mrt_params = _get_fasterrcnn_resnet50_fpn_model() + + test_id = 0 + passed_cnt = 0 + test_funcs = [test_InferPass_FuseDivide] + for func_ in test_funcs: + rltflag = func_(mrt_graph, mrt_params) + test_id += 1 + passed_cnt += rltflag + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') + print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') + diff --git a/tests/mir/test.infer_pass_mean.py b/tests/mir/test.infer_pass_mean.py new file mode 100644 index 0000000..d8586ec --- /dev/null +++ b/tests/mir/test.infer_pass_mean.py @@ -0,0 +1,89 @@ +""" +Test script for MRT InferPass +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_shufflenet_model(): + """Get Shufflenet MRT Model""" + + # Load pre-trained + model = models.shufflenet_v2_x1_0(pretrained=True) + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Model to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + + +def test_InferPass_FuseMean(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseMean Pass ===') + symlist = sx.sym2list(symbol) + #for x in symlist: + #print(x) + + op_cnt = 0 + for sym in symlist: + op_cnt += 1 if sym.op_name == opns.MEAN else 0 + assert op_cnt > 0, f'ori model mean op cnt {op_cnt} == zero!' + + # init Passer and execute visit + tfs : simple_pass.FuseMeanPass = simple_pass.FuseMeanPass(symbol, mrt_params) + symbol_passed = tfs.custom_visits_with_params(tfs.get_run()) + + print('\n=== After FuseMean Pass ===') + rlts = sx.sym2list(symbol_passed) + op_cnt_af = 0 + for sym in rlts: + # print(sym) + op_cnt_af += 1 if sym.op_name == opns.MEAN else 0 + assert op_cnt_af==0, f'passed model op cnt {op_cnt_af} != zero!' + + return True + + +if __name__ == "__main__": + print("=== Testing InferPass Mean ===") + mrt_graph, mrt_params = _get_shufflenet_model() + + test_id = 0 + passed_cnt = 0 + test_funcs = [test_InferPass_FuseMean] + for func_ in test_funcs: + rltflag = func_(mrt_graph, mrt_params) + test_id += 1 + passed_cnt += rltflag + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! 
Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') + print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') + diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py new file mode 100644 index 0000000..004cb5a --- /dev/null +++ b/tests/mir/test.op_create.py @@ -0,0 +1,183 @@ +""" +Test script for Alexnet PyTorch to MRT conversion. +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass + + +def test_op_func(): + X = opclass.var(name="var2", shape=(16, 128, 128), dtype="float") + ceil0 = opclass.ceil(X) + assert isinstance(ceil0, sx.Symbol), 'ceil0 isnot a symbol' + assert ceil0.op_name == opns.CEIL + assert len(ceil0.name) > 0 + + ceil1 = opclass.ceil(X, 'ceil_1') + assert ceil1.op_name == opns.CEIL + assert ceil1.name == 'ceil_1' + + return True + + +def test_create_conv2d_op(): + + X = opclass.var(name="x", shape=(1, 3, 224, 224,), dtype="float") + W = opclass.var(name="w", shape=(32, 3, 10, 10,), dtype="float") + assert [shp for shp in X.shape] == [shp for shp in (1, 3, 224, 224)], f'Wrong X shape {X.shape}' + assert X.dtype == "float", f'Wrong X dtype {X.dtype}' + + # Symbol Init using opclass OP + conv2d_a = opclass.Conv2D(X, W, name='conv2d_a', strides=(2,2)) + assert isinstance(conv2d_a, sx.Symbol), 'conv2d_a isnot a symbol' + assert isinstance(conv2d_a, opclass.Conv2D), 'conv2d_a isnot a Conv2D' + + # attrs hint + assert conv2d_a.args != None + assert conv2d_a.attrs != None + assert conv2d_a.strides != None + + args = [X, W] + attrs = {'strides':(3,3)} + conv2d_f = opclass.conv2d(*args, **attrs) + assert isinstance(conv2d_f, opclass.Conv2D), 'conv2d_f isnot a Conv2D' + + print(f'Got {conv2d_a.name} strides: {conv2d_a.strides}') + print(f'Got {conv2d_a.name} padding: {conv2d_a.padding}') + print(f'Show {conv2d_a.name} {conv2d_a}') + + # test Conv2D clone mode + conv2d_b = conv2d_a.copy() + assert isinstance(conv2d_b, sx.Symbol), 'conv2d_b isnot a symbol' + assert isinstance(conv2d_b, opclass.Conv2D), 'conv2d_b isnot a Conv2D' + + assert conv2d_b.attrs == conv2d_a.attrs, f'a: {conv2d_b.attrs} != b: {conv2d_a.attrs}' + + # test Dict to Find Class and Init + conv2d_c = opclass.MRT_OP_MAP[opns.CONV2D](X, W, strides=(2,2)) + assert isinstance(conv2d_c, opclass.Conv2D), 'conv2d_c isnot a Conv2D' + + # test Variable clone mode + X1 = X.copy() + assert X1.shape == X.shape + assert X1.dtype == X.dtype + + # test: Symbol Compatible Mode + args = [X1, W] + attrs = {'strides':(3,3)} + + + # Symbol Compatible Init + conv2d_d = opclass.Conv2D(*args, name='conv2d_d', **attrs) + conv2d_e = opclass.Conv2D(*args, **attrs) + assert isinstance(conv2d_d, opclass.Conv2D), 'conv2d_d isnot a Conv2D' + assert isinstance(conv2d_e, opclass.Conv2D), 'conv2d_e isnot a Conv2D' + + # alias function Init + conv2d_f = opclass.conv2d(*args, **attrs) + assert isinstance(conv2d_f, opclass.Conv2D), 'conv2d_f isnot a Conv2D' + + return True + + +def test_create_symbol_graph(): + X0 = opclass.var(name="x", shape=(1, 3, 224, 224,), dtype="float") + W0 = opclass.var(name="w", shape=(32, 3, 10, 10,), dtype="float") + conv2d_a 
= opclass.Conv2D(X0, W0, name='conv2d_a', strides=(1,1)) + + W1 = opclass.var(shape=(16, 3, 12, 12,), dtype="float") + conv2d_b = opclass.Conv2D(conv2d_a, W1, name='conv2d_b', strides=(1,1)) + symlist = sx.sym2list(conv2d_b) + + assert symlist[0] == X0 + assert symlist[1] == W0 + + for id_ in range(len(symlist)): + print(id_, symlist[id_]) + + return True + + +def test_create_batch_norm_op(): + X = opclass.var(name="x", shape=(1, 32, 128, 128,), dtype="float") + Gamma = opclass.var(name="gamma", shape=(32,), dtype="float") + Beta = opclass.var(name="beta", shape=(32,), dtype="float") + Mean = opclass.var(name="mean", shape=(32,), dtype="float") + Var = opclass.var(name="var", shape=(32,), dtype="float") + batch_norm_a = opclass.BatchNorm(X, Gamma, Beta, Mean, Var, axis=1, epsilon=1e-4) + + # attrs hint + assert batch_norm_a.args != None + assert batch_norm_a.attrs != None + assert batch_norm_a.axis != 0 + + # test clone mode + batch_norm_b = batch_norm_a.copy() + assert isinstance(batch_norm_b, opclass.BatchNorm) + + assert batch_norm_a.attrs == batch_norm_b.attrs, f'a: {batch_norm_a.attrs} != b: {batch_norm_b.attrs}' + assert len(batch_norm_a.args) == len(batch_norm_b.args), f'a: {len(batch_norm_a.args)} != b: {len(batch_norm_b.args)}' + + return True + + +def test_create_reshape_op(): + X = opclass.var(name="x", shape=(16, 32, 64, 64,), dtype="float") + try: + reshape0 = opclass.Reshape(X, name="reshape_0") + assert False, "Reshape Must have attr 'newshape', Should already Fail!" + except: + pass + + reshape1 = opclass.Reshape(X, name="reshape_1", newshape=(16, 8, 128, 128)) + assert isinstance(reshape1, opclass.Reshape) + + return True + + +def test_op_extern_func(): + + # extern_func Do not need to fill 'op_name' + args = [opclass.var(name="var2", shape=(16, 128, 128), dtype="float")] + attrs = {} + extra_attrs = {} + call_dps_packed = opclass.MRT_OP_MAP[opns.CALL_DPS_PACKED]('packed_0', args, attrs, extra_attrs) + assert isinstance(call_dps_packed, sx.Symbol), 'call_dps_packed isnot a symbol' + assert call_dps_packed.op_name == opns.CALL_DPS_PACKED + return True + + +if __name__ == "__main__": + print('MRT_OP_SET as:', opclass.MRT_OP_MAP.keys()) + assert len(opclass.MRT_OP_MAP.keys()) > 0 + + assert opns.CONV2D in opclass.MRT_OP_MAP + print('MRT_OP_MAP Conv2D Class as:', opclass.MRT_OP_MAP[opns.CONV2D]) + + test_id = 0 + passed_cnt = 0 + test_funcs = [test_op_func, test_create_conv2d_op, test_create_symbol_graph, test_create_batch_norm_op, test_create_reshape_op, test_op_extern_func] + for func_ in test_funcs: + rltflag = func_() + test_id += 1 + passed_cnt += rltflag + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') + print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') + diff --git a/tests/mir/test.simple_pass.py b/tests/mir/test.simple_pass.py new file mode 100644 index 0000000..33139d4 --- /dev/null +++ b/tests/mir/test.simple_pass.py @@ -0,0 +1,149 @@ +""" +Test script for MRT Alexnet FuseDropoutPass. 
+""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_alexnet_model(): + """Get Alexnet MRT Model""" + + # Load pre-trained Alexnet + model = models.alexnet(pretrained=True) + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Alexnet to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + +def test_SimplePass_FuseDropout(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseDropout Pass ===') + symlist = sx.sym2list(symbol) + dropout_op_cnt = 0 + for sym in symlist: + # print(sym) + dropout_op_cnt += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt>0, f'original model dropout op cnt {dropout_op_cnt} == zero!' + + # init FuseDropout Passer and execute visit + tfs : simple_pass.FuseDropoutPass = simple_pass.FuseDropoutPass(symbol) + #print(getattr(tfs, f"visit_{opns.Opname2Funcname(opns.DROP_OUT)}")) + symbol_passed = tfs.graph_visits() + + print('\n=== After FuseDropout Pass ===') + rlts = sx.sym2list(symbol_passed) + dropout_op_cnt_af = 0 + for sym in rlts: + # print(sym) + dropout_op_cnt_af += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt_af==0, f'passed model dropout op cnt {dropout_op_cnt_af} != zero!' + + #for sym in symdict: + # print(sym, symdict[sym]) + + #print('\n=== Back To SymList ===') + #rltlist = sx.sym2list(symdict[symbol.name]) + + return True + + +def test_SimplePass_CustomFunc(mrt_graph): + symbol = mrt_graph['main'] + + print('\n=== Before CustomFunc Pass ===') + symlist = sx.sym2list(symbol) + + tfs : simple_pass.SimplePass = simple_pass.SimplePass(symbol) + conv2d_name_list = [] + def _filter_op(sym: sx.Symbol, params=None) -> sx.Symbol: + if sym.op_name == opns.CONV2D: + conv2d_name_list.append(sym.name) + return sym + + symbol_passed = tfs.custom_visits(_filter_op) + + print('\n=== After CustomFunc Pass ===') + assert len(conv2d_name_list) > 0 + print(conv2d_name_list) + rlts = sx.sym2list(symbol_passed) + + return True + + +def test_SimplePass_FuseDropout_CustomFunc(mrt_graph): + symbol = mrt_graph['main'] + + print('\n=== Before FuseDropout CustomFunc Pass ===') + symlist = sx.sym2list(symbol) + dropout_op_cnt = 0 + for sym in symlist: + dropout_op_cnt += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt > 0, f'ori model dropout op cnt {dropout_op_cnt} == zero!' 
+    tfs: simple_pass.SimplePass = simple_pass.SimplePass(symbol)
+    def _nn_dropout(sym: sx.Symbol) -> sx.Symbol:
+        if sym.op_name == opns.DROP_OUT:
+            return sym.args[0]
+        return sym
+    symbol_passed = tfs.custom_visits(_nn_dropout)
+
+    print('\n=== After FuseDropout CustomFunc Pass ===')
+    rlts = sx.sym2list(symbol_passed)
+    dropout_op_cnt_af = 0
+    for sym in rlts:
+        dropout_op_cnt_af += 1 if sym.op_name == opns.DROP_OUT else 0
+    assert dropout_op_cnt_af == 0, f'passed model dropout op cnt {dropout_op_cnt_af} != zero!'
+
+    return True
+
+
+if __name__ == "__main__":
+
+    print("=== Testing SymbolPass ===")
+    mrt_graph, mrt_params = _get_alexnet_model()
+
+    print("Testing FuseDropoutPass for Model AlexNet")
+    rltflag = test_SimplePass_FuseDropout(mrt_graph, mrt_params)
+    print("\n" + "="*60 + "\n")
+    print('Passed Test1!' if rltflag else 'Test1 Failed!')
+    print("\n" + "="*60 + "\n")
+
+    rltflag = test_SimplePass_CustomFunc(mrt_graph)
+    print("\n" + "="*60 + "\n")
+    print('Passed Test2!' if rltflag else 'Test2 Failed!')
+    print("\n" + "="*60 + "\n")
+
+    print("Testing FuseDropout CustomFunc for Model AlexNet")
+    rltflag = test_SimplePass_FuseDropout_CustomFunc(mrt_graph)
+    print("\n" + "="*60 + "\n")
+    print('Passed Test3!' if rltflag else 'Test3 Failed!')
+    print("\n" + "="*60 + "\n")
+
diff --git a/tests/test.pytorch.py b/tests/test.pytorch.py
index 499f66b..8af997b 100644
--- a/tests/test.pytorch.py
+++ b/tests/test.pytorch.py
@@ -40,8 +40,9 @@
 # model inference context, like cpu, gpu, etc.
 config = {
-    "device": "cuda:0",
-    "target": "" }
+    #"device": "cuda:0",
+    "device": "cpu",
+    "target": ""}
 
 # TODO: load the model from torchvision
 model_name = "resnet18" # passed
@@ -95,7 +96,7 @@
         calibrate_repeats=16,
         force_run_from_trcb="Discretor",
         log_after_all=True,
-        # log_before_tr_or_cbs=[ "calibrate_run_0", ],
+        log_before_tr_or_cbs=[ "PrecisionRevisor", ],
     ):
     dis_tr = tr.discrete()