|
1 | 1 | from copy import deepcopy |
| 2 | +from itertools import chain |
2 | 3 |
|
3 | 4 | from utensor_cgen.backend.base import BackendPart |
4 | 5 | from utensor_cgen.logger import logger |
@@ -49,23 +50,27 @@ class CheckQuantization(object): |
49 | 50 |
|
50 | 51 | @classmethod |
51 | 52 | def apply(cls, ugraph): |
52 | | - if cls._check_quantized(ugraph): |
53 | | - for op_info in ugraph.get_ops_by_type('DepthwiseSeparableConvOperator'): |
| 53 | + for op_info in ugraph.get_ops_by_type('DepthwiseSeparableConvOperator'): |
| 54 | + if cls._check_quantized(op_info): |
54 | 55 | op_info.op_type = 'QuantizedDepthwiseSeparableConvOperator' |
55 | | - for op_info in ugraph.get_ops_by_type('FullyConnectedOperator'): |
| 56 | + for op_info in ugraph.get_ops_by_type('FullyConnectedOperator'): |
| 57 | + if cls._check_quantized(op_info): |
56 | 58 | op_info.op_type = 'QuantizedFullyConnectedOperator' |
57 | 59 | for op_info in ugraph.get_ops_by_type('DequantizeOperator'): |
58 | 60 | op_info.code_gen_attributes['namespaces'] = ('TFLM',) |
59 | 61 | for op_info in ugraph.get_ops_by_type('QuantizeOperator'): |
60 | 62 | op_info.code_gen_attributes['namespaces'] = ('TFLM',) |
61 | 63 |
|
62 | 64 | @classmethod |
63 | | - def _check_quantized(cls, ugraph): |
64 | | - for op_info in ugraph.ops_info.values(): |
65 | | - for tensor_info in op_info.output_tensors: |
66 | | - # FIXME: better way to check quantization |
67 | | - if 'quantization_zeros' in tensor_info.attributes: |
68 | | - return True |
| 65 | + def _check_quantized(cls, op_info): |
| 66 | + for tensor_info in chain( |
| 67 | + op_info.output_tensors, |
| 68 | + op_info.input_tensors |
| 69 | + ): |
| 70 | + # FIXME: better way to check quantization |
| 71 | + if 'quantization_zeros' in tensor_info.attributes: |
| 72 | + return True |
| 73 | + return False |
69 | 74 |
|
70 | 75 | @classmethod |
71 | 76 | def add_name_map(cls, generic_name, target_specific_name): |
|
0 commit comments