Skip to content

Commit 199247b

Browse files
committed
Merge branch 're-arch-support' into tutorials
2 parents dbc4858 + 421f0a9 commit 199247b

File tree

3 files changed

+22
-18
lines changed

3 files changed

+22
-18
lines changed

utensor_cgen/backend/utensor/_graph_lower/_op_lower.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@ def default_config(cls):
3636
class uTensorRearchGraphLower(uTensorGraphLowerBase):
3737
PART = 'rearch_graph_lower'
3838

39-
def __init__(self, config):
40-
final_config = Configuration(self.default_config, config)
41-
self.tflite_use_quant_dws_conv = final_config['tflite_use_quant_dws_conv']
42-
4339
class OptypeRenameManager(object):
4440
NAME_MAP = {
4541
'Add': 'AddOperator',
@@ -51,12 +47,23 @@ class OptypeRenameManager(object):
5147
def get_new_optype(cls, op_type):
5248
return cls.NAME_MAP.get(op_type, op_type)
5349

54-
class AddCodegenAttributes(object):
50+
class CheckQuantization(object):
5551

5652
@classmethod
57-
def add_attributes(cls, ugraph):
58-
for op_info in ugraph.get_ops_by_type('DepthwiseSeparableConvOperator'):
59-
op_info.code_gen_attributes['namespaces'] = ('TFLM',)
53+
def apply(cls, ugraph):
54+
if cls._check_quantized(ugraph):
55+
for op_info in ugraph.get_ops_by_type('DepthwiseSeparableConvOperator'):
56+
op_info.op_type = 'QuantizedDepthwiseSeparableConvOperator'
57+
for op_info in ugraph.get_ops_by_type('FullyConnectedOperator'):
58+
op_info.op_type = 'QuantizedFullyConnectedOperator'
59+
60+
@classmethod
61+
def _check_quantized(cls, ugraph):
62+
for op_info in ugraph.ops_info.values():
63+
for tensor_info in op_info.output_tensors:
64+
# FIXME: better way to check quantization
65+
if 'quantization_zeros' in tensor_info.attributes:
66+
return True
6067

6168
@classmethod
6269
def add_name_map(cls, generic_name, target_specific_name):
@@ -67,11 +74,8 @@ def handle_tensorflow(self, ugraph):
6774
op_info.op_type = self.OptypeRenameManager.get_new_optype(op_info.op_type)
6875

6976
def handle_tflite(self, ugraph):
70-
if self.tflite_use_quant_dws_conv:
71-
self.AddCodegenAttributes.add_attributes(ugraph)
77+
self.CheckQuantization.apply(ugraph)
7278

7379
@class_property
7480
def default_config(cls):
75-
return {
76-
'tflite_use_quant_dws_conv': True,
77-
}
81+
return {}

utensor_cgen/backend/utensor/code_generator/rearch/_operators/_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ class OperatorFactory(object):
1919
@classmethod
2020
def get_opertor(cls, op_info):
2121
op_type = op_info.op_type
22-
namespaces = op_info.code_gen_attributes.get('namespaces', tuple())
23-
op_cls = cls._operators.get((namespaces, op_type))
22+
codegen_namespaces = op_info.code_gen_attributes.get('namespaces', tuple())
23+
op_cls = cls._operators.get((codegen_namespaces, op_type))
2424
if op_cls is None:
2525
raise OpNotSupportedError(
26-
"{} not supported in utensor_cgen".format("::".join(list(namespaces) + [op_type]))
26+
"{} not supported in utensor_cgen".format("::".join(list(codegen_namespaces) + [op_type]))
2727
)
2828
return op_cls(op_info)
2929

utensor_cgen/legalizer/tflite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def legalize_dtype(self, ugraph):
2424

2525
class _OpTypeRename(object):
2626
_OPTYPE_RENAME_MAP = {
27-
"FullyConnected": "QuantizedFullyConnectedOperator",
27+
"FullyConnected": "FullyConnectedOperator",
2828
"Quantize": "QuantizeOperator",
29-
"DepthwiseConv2d": "QuantizedDepthwiseSeparableConvOperator",
29+
"DepthwiseConv2d": "DepthwiseSeparableConvOperator",
3030
"MaxPool2d": "MaxPoolOperator",
3131
"Dequantize": "DequantizeOperator",
3232
"Reshape": "ReshapeOperator",

0 commit comments

Comments
 (0)