Skip to content

Commit 421f0a9

Browse files
committed
Fix legalizer and lower: it now does not assuming the graph is quantized and lower will take care of it once if is quantized
1 parent 93536ba commit 421f0a9

File tree

3 files changed

+22
-18
lines changed

3 files changed

+22
-18
lines changed

utensor_cgen/backend/utensor/_graph_lower/_op_lower.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,6 @@ def default_config(cls):
3434
class uTensorRearchGraphLower(uTensorGraphLowerBase):
3535
PART = 'rearch_graph_lower'
3636

37-
def __init__(self, config):
38-
final_config = Configuration(self.default_config, config)
39-
self.tflite_use_quant_dws_conv = final_config['tflite_use_quant_dws_conv']
40-
4137
class OptypeRenameManager(object):
4238
NAME_MAP = {
4339
'Add': 'AddOperator',
@@ -49,12 +45,23 @@ class OptypeRenameManager(object):
4945
def get_new_optype(cls, op_type):
5046
return cls.NAME_MAP.get(op_type, op_type)
5147

52-
class AddCodegenAttributes(object):
48+
class CheckQuantization(object):
5349

5450
@classmethod
55-
def add_attributes(cls, ugraph):
56-
for op_info in ugraph.get_ops_by_type('DepthwiseSeparableConvOperator'):
57-
op_info.code_gen_attributes['namespaces'] = ('TFLM',)
51+
def apply(cls, ugraph):
52+
if cls._check_quantized(ugraph):
53+
for op_info in ugraph.get_ops_by_type('DepthwiseSeparableConvOperator'):
54+
op_info.op_type = 'QuantizedDepthwiseSeparableConvOperator'
55+
for op_info in ugraph.get_ops_by_type('FullyConnectedOperator'):
56+
op_info.op_type = 'QuantizedFullyConnectedOperator'
57+
58+
@classmethod
59+
def _check_quantized(cls, ugraph):
60+
for op_info in ugraph.ops_info.values():
61+
for tensor_info in op_info.output_tensors:
62+
# FIXME: better way to check quantization
63+
if 'quantization_zeros' in tensor_info.attributes:
64+
return True
5865

5966
@classmethod
6067
def add_name_map(cls, generic_name, target_specific_name):
@@ -65,11 +72,8 @@ def handle_tensorflow(self, ugraph):
6572
op_info.op_type = self.OptypeRenameManager.get_new_optype(op_info.op_type)
6673

6774
def handle_tflite(self, ugraph):
68-
if self.tflite_use_quant_dws_conv:
69-
self.AddCodegenAttributes.add_attributes(ugraph)
75+
self.CheckQuantization.apply(ugraph)
7076

7177
@class_property
7278
def default_config(cls):
73-
return {
74-
'tflite_use_quant_dws_conv': True,
75-
}
79+
return {}

utensor_cgen/backend/utensor/code_generator/rearch/_operators/_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ class OperatorFactory(object):
1919
@classmethod
2020
def get_opertor(cls, op_info):
2121
op_type = op_info.op_type
22-
namespaces = op_info.code_gen_attributes.get('namespaces', tuple())
23-
op_cls = cls._operators.get((namespaces, op_type))
22+
codegen_namespaces = op_info.code_gen_attributes.get('namespaces', tuple())
23+
op_cls = cls._operators.get((codegen_namespaces, op_type))
2424
if op_cls is None:
2525
raise OpNotSupportedError(
26-
"{} not supported in utensor_cgen".format("::".join(list(namespaces) + [op_type]))
26+
"{} not supported in utensor_cgen".format("::".join(list(codegen_namespaces) + [op_type]))
2727
)
2828
return op_cls(op_info)
2929

utensor_cgen/legalizer/tflite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def legalize_dtype(self, ugraph):
2424

2525
class _OpTypeRename(object):
2626
_OPTYPE_RENAME_MAP = {
27-
"FullyConnected": "QuantizedFullyConnectedOperator",
27+
"FullyConnected": "FullyConnectedOperator",
2828
"Quantize": "QuantizeOperator",
29-
"DepthwiseConv2d": "QuantizedDepthwiseSeparableConvOperator",
29+
"DepthwiseConv2d": "DepthwiseSeparableConvOperator",
3030
"MaxPool2d": "MaxPoolOperator",
3131
"Dequantize": "DequantizeOperator",
3232
"Reshape": "ReshapeOperator",

0 commit comments

Comments
 (0)