-
Notifications
You must be signed in to change notification settings - Fork 0
Sym pass2 #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
| from . import symbol | ||
|
|
||
| #from mrt.mir.mhsymbol import MultiHeadSymbol, Graph | ||
| class MultiHeadSymbol(dict): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MultiHeadSymbol should inherit from Symbol class, and the Graph is exactly MultiHeadSymbol.
| @property | ||
| def args(self): | ||
| return self.graph.args | ||
| @property | ||
| def op_name(self): | ||
| return self.graph.op_name | ||
| @property | ||
| def name(self): | ||
| return self.graph.name | ||
| @property | ||
| def shape(self): | ||
| return self.graph.shape | ||
| @property | ||
| def dtype(self): | ||
| return self.graph.dtype | ||
| @property | ||
| def attrs(self): | ||
| return self.graph.attrs | ||
| @property | ||
| def extra_attrs(self): | ||
| return self.graph.extra_attrs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add return type hint
| class SymbolBridge: # SymbolManipulator / Pass | ||
| graph: Symbol | ||
|
|
||
| def __init__(self, symbol: Symbol): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dataclass will create __init__ function automatically, which may lead to conflict and unknown bugs.
don't use dataclass for SymbolBridge and subclass.
| args = data['args'] or [] | ||
| attrs = data['attrs'] or {} | ||
| try: | ||
| out = cls(*args, name=data['name'], op_name=data['op_name'], extra_attrs=data['extra_attrs'], **attrs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
subclass like Conv2D don't have op_name parameter, may lead to error.
Since we should have uniform construct API, make subclass like Conv2D to add a op_name parameter, and that sounds more reasonable. But subclass has intrinsic property for op_name, so use assert to check op_name is None or correct name, like:
def __init__(..., op_name=None, ...):
assert op_name is None or op_name == opns.CONV2D
| _format_printer(oattrs)) | ||
|
|
||
|
|
||
| @dataclass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since symbol is not the dataclass, the below serialization function: dump_json and load_json may not work. Check these serial functions carefully.
|
|
||
| rs_bit = X.from_const_data(X.precision - anno_bit) | ||
| X = op.right_shift(X, rs_bit).like(self) | ||
| X_op = infer_single(opclass.right_shift(X.graph, rs_bit)).like(self.graph) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since all transformers have been moved to SymbolBridge, the like method in Symbol should also move to transformer. Note that like method is used to create a similar symbol bridge for new symbol, which is identical to from_symbol in SymbolBridge.
|
Does this PR passed for |
With 16 calibrate repeats, and 200 eval round: |
[mrt ir]: opclass and WithParameters