Skip to content

Commit 2ac0ee4

Browse files
authored
Merge pull request #112 from uTensor/re-arch-namespaces
update supported ops namespaces
2 parents 42d5885 + 29c60a3 commit 2ac0ee4

File tree

4 files changed

+99
-28
lines changed

4 files changed

+99
-28
lines changed

utensor_cgen/backend/utensor/_graph_lower/_op_lower.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,21 +48,48 @@ class OptypeRenameManager(object):
4848
def get_new_optype(cls, op_type):
4949
return cls.NAME_MAP.get(op_type, op_type)
5050

51-
class CheckQuantization(object):
51+
class CodgenAttributes(object):
5252

5353
@classmethod
5454
def apply(cls, ugraph):
55+
# TODO: better abstraction, sth like lowering strategy
56+
for op_info in ugraph.get_ops_by_type("AddOperator"):
57+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
58+
for op_info in ugraph.get_ops_by_type("ReshapeOperator"):
59+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
60+
for op_info in ugraph.get_ops_by_type("MatrixMultOperator"):
61+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
62+
for op_info in ugraph.get_ops_by_type('ArgMinOperator'):
63+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
64+
for op_info in ugraph.get_ops_by_type('ArgMaxOperator'):
65+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
66+
for op_info in ugraph.get_ops_by_type('QuantizeOperator'):
67+
op_info.code_gen_attributes['namespaces'] = ('TflmSymQuantOps',)
68+
for op_info in ugraph.get_ops_by_type('DequantizeOperator'):
69+
op_info.code_gen_attributes['namespaces'] = ('TflmSymQuantOps',)
70+
for op_info in ugraph.get_ops_by_type('ReLUOperator'):
71+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
72+
for op_info in ugraph.get_ops_by_type('ReLU6Operator'):
73+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
74+
for op_info in ugraph.get_ops_by_type('MinOperator'):
75+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
76+
for op_info in ugraph.get_ops_by_type('MaxOperator'):
77+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
78+
for op_info in ugraph.get_ops_by_type('MaxPoolOperator'):
79+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
80+
for op_info in ugraph.get_ops_by_type('MinPoolOperator'):
81+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
82+
for op_info in ugraph.get_ops_by_type('Conv2dOperator'):
83+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
5584
for op_info in ugraph.get_ops_by_type('DepthwiseSeparableConvOperator'):
5685
if cls._check_quantized(op_info):
57-
op_info.op_type = 'QuantizedDepthwiseSeparableConvOperator'
86+
op_info.code_gen_attributes['namespaces'] = ('TflmSymQuantOps',)
87+
else:
88+
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
5889
for op_info in ugraph.get_ops_by_type('FullyConnectedOperator'):
5990
if cls._check_quantized(op_info):
60-
op_info.op_type = 'QuantizedFullyConnectedOperator'
61-
for op_info in ugraph.get_ops_by_type('DequantizeOperator'):
62-
op_info.code_gen_attributes['namespaces'] = ('TFLM',)
63-
for op_info in ugraph.get_ops_by_type('QuantizeOperator'):
64-
op_info.code_gen_attributes['namespaces'] = ('TFLM',)
65-
91+
op_info.code_gen_attributes['namespaces'] = ('TflmSymQuantOps',)
92+
6693
@classmethod
6794
def _check_quantized(cls, op_info):
6895
for tensor_info in chain(
@@ -83,7 +110,7 @@ def handle_tensorflow(self, ugraph):
83110
op_info.op_type = self.OptypeRenameManager.get_new_optype(op_info.op_type)
84111

85112
def handle_tflite(self, ugraph):
86-
self.CheckQuantization.apply(ugraph)
113+
self.CodgenAttributes.apply(ugraph)
87114

88115
@class_property
89116
def default_config(cls):

0 commit comments

Comments
 (0)