21 changes: 15 additions & 6 deletions python/mrt/api.py
@@ -12,7 +12,7 @@
from .runtime.analysis import *

from .mir import op, helper
# from .mir.model import MultiHeadSymbol
from .mir.mhsymbol import MultiHeadSymbol
from .mir.symbol import *

from .dataset.base import Dataset
@@ -26,7 +26,7 @@

from .quantization.discrete import Discretor
from .quantization.precision import PrecisionRevisor
from .quantization.transform import TransformerT
from .mir.symbol_pass import SymTransformerT

@dataclass
class TraceConfig(config._BaseConfig):
@@ -174,7 +174,7 @@ def _new(self, tr_name: str,
_stat_type = self._stat_type)

def checkpoint_run(self,
*callbacks: typing.List[TransformerT],
*callbacks: typing.List[SymTransformerT],
tr_name: typing.Optional[str] = None,
**kwargs) -> Trace:
C = TraceConfig.G()
@@ -200,7 +200,7 @@ def checkpoint_run(self,
for cb in callbacks:
# deep copy params to avoid conflict status
params = {k: v for k, v in out.params.items()}
print("Apply Trace: {:25} Transformer: {}".format(
print("Apply Trace: {:25} SymbolTransformer: {}".format(
tr_name, cb.__name__))

if cb.__name__ in C.log_before_tr_or_cbs:
@@ -223,7 +223,14 @@

def discrete(self) -> Trace:
fuse_tr = self.fuse()

"""Must pass params inside a dict,
Cause it will be unfolded separately
"""
seg_tr = fuse_tr.checkpoint_run(seg.Spliter.get_transformer())
kwargs_seg = {"ptr": {"head": seg_tr.symbol.extra_attrs.get("head"),
"head_params": seg_tr.symbol.extra_attrs.get("head_params"),
"seg_names": seg_tr.symbol.extra_attrs.get("seg_names")}}

C = TraceConfig.G()
calib_tr = seg_tr.calibrate(
@@ -232,7 +239,8 @@ def discrete(self) -> Trace:
quant_tr = calib_tr.quantize()
quant_tr = quant_tr.checkpoint_run(
seg.Merger.get_transformer(),
spliter=seg_tr.symbol)
spliter=seg_tr.symbol,
**kwargs_seg)
return quant_tr
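# A minimal sketch of how the nested dict above travels, assuming checkpoint_run
# forwards its **kwargs unchanged to each transformer callback:
#
#   kwargs_seg = {"ptr": {"head": ..., "head_params": ..., "seg_names": ...}}
#   quant_tr.checkpoint_run(seg.Merger.get_transformer(),
#                           spliter=seg_tr.symbol, **kwargs_seg)
#   # the Merger callback then receives ptr={...} as a single keyword argument,
#   # rather than three separate keywords unfolded from the dict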

def fuse(self, **kwargs) -> Trace:
@@ -247,20 +255,21 @@ def fuse(self, **kwargs) -> Trace:
fuse.FuseDropout.get_transformer(),
fuse.FuseMean.get_transformer(),
fuse.FuseNaiveSoftmax.get_transformer(),
fuse.FuseIdentity.get_transformer(),
fuse.FuseConstant.get_transformer(),
**kwargs,
)

def calibrate(self, repeats: int = 1, **kwargs) -> Trace:
assert self._dataset is not None
tr_name = kwargs.pop("tr_name", "calibrate")

out = self
for i in range(repeats):
data, _ = self._dataset.next()
out = out.checkpoint_run(
calib.Calibrator.get_transformer(),
data = data,
# tr_name = tr_name,
tr_name = f"{tr_name}_run_{i}",
**kwargs)
out = out.checkpoint_run(
1 change: 1 addition & 0 deletions python/mrt/frontend/api.py
@@ -4,6 +4,7 @@
from functools import wraps

from mrt.mir.symbol import *
from mrt.mir.mhsymbol import MultiHeadSymbol, Graph
from mrt.common.types import *
from mrt.common.config import MRTConfig

14 changes: 7 additions & 7 deletions python/mrt/frontend/expr.py
@@ -14,6 +14,7 @@
from ..symbol import *
from ..types import *
from .. import op
from .. import opclass

__ALL__ = [ "expr2symbol", "symbol2expr", "tvm_type_infer" ]

@@ -62,7 +63,7 @@ def _cast_expr(node: RelayExpr):
if isinstance(node, relay.expr.Constant):
name = N.n("const_")
params[name] = node.data
symbol_map[node] = op.variable(name,
symbol_map[node] = opclass.var(name,
node.data.shape, node.data.dtype)
return

@@ -85,11 +86,11 @@ def _cast_expr(node: RelayExpr):

if isinstance(node, relay.expr.Var):
name = node.name_hint or N.n(prefix="input_")
symbol_map[node] = op.variable(name, shape, dtype)
symbol_map[node] = opclass.var(name, shape, dtype)
elif isinstance(node, relay.expr.If):
args = [ node.cond, node.true_branch, node.false_branch ]
args = [symbol_map[i] for i in args]
symbol_map[node] = op._new_op(IF, *args, **attrs)
symbol_map[node] = opclass.extern_opfunc(IF)(*args, **attrs)
elif isinstance(node, relay.expr.Call):
op_name = node.op.name
if op_name in [CONCAT, ADV_INDEX]:
@@ -108,15 +109,14 @@
attrs.pop("dtype")
elif op_name == GET_VALID_COUNT:
attrs.pop("score_threshold")
symbol_map[node] = op._new_op(op_name, *args, **attrs)
symbol_map[node] = opclass.extern_opfunc(op_name)(*args, **attrs)
elif isinstance(node, relay.TupleGetItem):
args = [ symbol_map[node.tuple_value], ]
attrs['index'] = node.index
symbol_map[node] = op._new_op(
TUPLE_GET_ITEM, *args, **attrs)
symbol_map[node] = opclass.extern_opfunc(TUPLE_GET_ITEM)(*args, **attrs)
elif isinstance(node, relay.Tuple):
args = [ symbol_map[f] for f in node.fields ]
symbol_map[node] = op._new_op(TUPLE, *args, **attrs)
symbol_map[node] = opclass.extern_opfunc(TUPLE)(*args, **attrs)
else:
raise RuntimeError(
"MRT not support expr type:{}".format(type(node)))
25 changes: 14 additions & 11 deletions python/mrt/frontend/pytorch/converter.py
@@ -9,8 +9,9 @@
import torch.nn.functional as F
import sys

from mrt.mir.symbol import Symbol, MultiHeadSymbol, sym2list, transform
from mrt.mir import op
from mrt.mir.symbol import Symbol, sym2list, transform
from mrt.mir.mhsymbol import MultiHeadSymbol
from mrt.mir import op, opclass
from mrt.mir.opns import *
from mrt.common.types import ParametersT
from mrt.common.utils import N
@@ -46,7 +47,7 @@ class _T:

"adaptive_avg_pool2d.default": _T(ADAPTIVE_AVG_POOL2D, 1, [ Attr("output_size", (1,1)) ]),
"max_pool2d.default": _T(MAX_POOL2D, 1, [
Attr("kernel_size", (1,1)), Attr("strides", (1,1)), Attr("padding", (0,0)) ]),
Attr("kernel_size", (1,1)), Attr("strides", (1,1)), Attr("padding", (0,0)), Attr("dilation", (1,1)), Attr("ceil_mode", False) ]),
"mean.dim": _T(MEAN, 1, [ Attr("dim", None), Attr("keepdim", False) ]),

"add.Tensor": _T(ADD, 2), "add_.Tensor": _T(ADD, 2),
@@ -60,7 +61,7 @@ class _T:
"cat.default": _T(CONCAT, 1, [ Attr("dim", 0) ]),
"view.default": _T(RESHAPE, 1, [ Attr("shape", ()) ]),
"transpose.int": _T(TRANSPOSE, 1, [ Attr("dim0", 0), Attr("dim1", 0) ]),
"contiguous.default": _T(PASS, 1),
"contiguous.default": _T(IDENTITY, 1),

"chunk.default": _T(SPLIT, 1, [ Attr("chunks", 1), Attr("dim", 0) ]),
"getitem": _T(TUPLE_GET_ITEM, 1, [ Attr("index", 0) ]),
@@ -100,7 +101,7 @@ class _T:
),
RESHAPE: torch.reshape,
TRANSPOSE: torch.transpose,
PASS: lambda x: x,
IDENTITY: lambda x: x,
SPLIT: torch.chunk,

ADD: torch.add,
@@ -156,7 +157,7 @@ def create_parameters(ep: torch.export.ExportedProgram):
dshape = data_to_mrt(torch_shape)
dtype = data_to_mrt(torch_dtype)

out = op.variable(name_hint, dshape, dtype)
out = opclass.var(name=name_hint, shape=dshape, dtype=dtype)
params[name_hint] = to_bind_parameters[spec.target].detach().numpy().astype(dtype)
assert dshape == list(params[name_hint].shape)
# print(">> vars: ", out)
@@ -207,7 +208,7 @@ def _retrieve_args(node):
continue

if node.name not in param_vars: # input
env[node] = op.variable(node.name, shape, dtype)
env[node] = opclass.var(name=node.name, shape=shape, dtype=dtype)
else:
env[node] = param_vars[node.name]
elif node.op == "output": # [[ out1, out2, out3 ]]
@@ -234,13 +235,15 @@ def _retrieve_args(node):
if mapper.op_name == CONCAT:
args = args[0]

if mapper.op_name == SPLIT:
shape = data_to_mrt([ t.shape for t in node.meta['val']])
dtype = data_to_mrt([ t.dtype for t in node.meta['val']])

if mapper.op_name == TUPLE_GET_ITEM and args[0].op_name == BATCH_NORM:
out = args[0]
else:
out = op._new_op(
mapper.op_name, *args,
name=node.name, extra_attrs={ "shape": shape, "dtype": dtype },
**attrs)
out = opclass.extern_opfunc(mapper.op_name)(*args, name=node.name,
extra_attrs={"shape": shape, "dtype": dtype}, **attrs)
env[node] = out
else:
raise ValueError(f"Unsupported op {node.op}")
1 change: 1 addition & 0 deletions python/mrt/frontend/pytorch/vm.py
@@ -6,6 +6,7 @@
from .types import *

from mrt.mir.symbol import *
from mrt.mir.mhsymbol import MultiHeadSymbol
from mrt.common.types import *

Executor = namedtuple("Executor", ["vm", "device"])
35 changes: 35 additions & 0 deletions python/mrt/mir/mhsymbol.py
@@ -0,0 +1,35 @@
import typing

from mrt.common.utils import *
from mrt.common.types import *

from . import opns, opclass, optype
from . import symbol

#from mrt.mir.mhsymbol import MultiHeadSymbol, Graph
class MultiHeadSymbol(dict):
Review comment (Member): MultiHeadSymbol should inherit from the Symbol class, and Graph is exactly MultiHeadSymbol.

""" { "main": F(X) } """
origin: typing.Optional[symbol.Symbol] = None

@classmethod
def from_symbol(cls, symbol: symbol.Symbol, name: str = "main"):
return MultiHeadSymbol({ name: symbol })

def as_tuple(self) -> typing.Tuple[typing.List[str], symbol.Symbol]:
# args = list(self.values())
# sym_type = type(args[0]) if args else Symbol
mhs = self.origin or optype.infer_single(opclass.MRT_OP_MAP[opns.TUPLE](*list(self.values())))
return list(self.keys()), mhs

@classmethod
def from_tuple(cls, tuple_names, symbol):
assert symbol.is_op(opns.TUPLE), symbol
mhs = cls(zip(tuple_names, symbol.args))
mhs.origin = symbol
return mhs

Graph = typing.Union[symbol.Symbol, MultiHeadSymbol]
""" Notice that Symbol and MultiHeadSymbol can both
be regarded as a model Graph.
"""
