@@ -48,21 +48,48 @@ class OptypeRenameManager(object):
4848 def get_new_optype (cls , op_type ):
4949 return cls .NAME_MAP .get (op_type , op_type )
5050
51- class CheckQuantization (object ):
51+ class CodgenAttributes (object ):
5252
5353 @classmethod
5454 def apply (cls , ugraph ):
55+ # TODO: better abstraction, sth like lowering strategy
56+ for op_info in ugraph .get_ops_by_type ("AddOperator" ):
57+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
58+ for op_info in ugraph .get_ops_by_type ("ReshapeOperator" ):
59+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
60+ for op_info in ugraph .get_ops_by_type ("MatrixMultOperator" ):
61+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
62+ for op_info in ugraph .get_ops_by_type ('ArgMinOperator' ):
63+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
64+ for op_info in ugraph .get_ops_by_type ('ArgMaxOperator' ):
65+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
66+ for op_info in ugraph .get_ops_by_type ('QuantizeOperator' ):
67+ op_info .code_gen_attributes ['namespaces' ] = ('TflmSymQuantOps' ,)
68+ for op_info in ugraph .get_ops_by_type ('DequantizeOperator' ):
69+ op_info .code_gen_attributes ['namespaces' ] = ('TflmSymQuantOps' ,)
70+ for op_info in ugraph .get_ops_by_type ('ReLUOperator' ):
71+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
72+ for op_info in ugraph .get_ops_by_type ('ReLU6Operator' ):
73+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
74+ for op_info in ugraph .get_ops_by_type ('MinOperator' ):
75+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
76+ for op_info in ugraph .get_ops_by_type ('MaxOperator' ):
77+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
78+ for op_info in ugraph .get_ops_by_type ('MaxPoolOperator' ):
79+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
80+ for op_info in ugraph .get_ops_by_type ('MinPoolOperator' ):
81+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
82+ for op_info in ugraph .get_ops_by_type ('Conv2dOperator' ):
83+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
5584 for op_info in ugraph .get_ops_by_type ('DepthwiseSeparableConvOperator' ):
5685 if cls ._check_quantized (op_info ):
57- op_info .op_type = 'QuantizedDepthwiseSeparableConvOperator'
86+ op_info .code_gen_attributes ['namespaces' ] = ('TflmSymQuantOps' ,)
87+ else :
88+ op_info .code_gen_attributes ['namespaces' ] = ('ReferenceOperators' ,)
5889 for op_info in ugraph .get_ops_by_type ('FullyConnectedOperator' ):
5990 if cls ._check_quantized (op_info ):
60- op_info .op_type = 'QuantizedFullyConnectedOperator'
61- for op_info in ugraph .get_ops_by_type ('DequantizeOperator' ):
62- op_info .code_gen_attributes ['namespaces' ] = ('TFLM' ,)
63- for op_info in ugraph .get_ops_by_type ('QuantizeOperator' ):
64- op_info .code_gen_attributes ['namespaces' ] = ('TFLM' ,)
65-
91+ op_info .code_gen_attributes ['namespaces' ] = ('TflmSymQuantOps' ,)
92+
6693 @classmethod
6794 def _check_quantized (cls , op_info ):
6895 for tensor_info in chain (
@@ -83,7 +110,7 @@ def handle_tensorflow(self, ugraph):
83110 op_info .op_type = self .OptypeRenameManager .get_new_optype (op_info .op_type )
84111
85112 def handle_tflite (self , ugraph ):
86- self .CheckQuantization .apply (ugraph )
113+ self .CodgenAttributes .apply (ugraph )
87114
88115 @class_property
89116 def default_config (cls ):
0 commit comments