Skip to content

Commit 373d1a7

Browse files
committed
Add constructor parameters: qFC
1 parent 3673d0f commit 373d1a7

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

utensor_cgen/backend/utensor/code_generator/rearch/_operators/_impls.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
312312
)
313313

314314

315-
class _DWSConvOps(_Operator):
315+
class _CommonParams(_Operator):
316316
_PADDING_MAP = {
317317
0: "UNKNOWN",
318318
1: "VALID",
@@ -331,7 +331,7 @@ class _DWSConvOps(_Operator):
331331

332332

333333
@OperatorFactory.register
334-
class _QuantDWSConvOperator(_DWSConvOps):
334+
class _QuantDWSConvOperator(_CommonParams):
335335
op_type = "QuantizedDepthwiseSeparableConvOperator"
336336

337337
@classmethod
@@ -374,7 +374,7 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
374374

375375

376376
@OperatorFactory.register
377-
class _DWSConvOperator(_DWSConvOps):
377+
class _DWSConvOperator(_CommonParams):
378378
op_type = "DepthwiseSeparableConvOperator"
379379

380380
@classmethod
@@ -407,9 +407,18 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
407407

408408

409409
@OperatorFactory.register
410-
class _QuantizedFullyConnectedOperator(_Operator):
410+
class _QuantizedFullyConnectedOperator(_CommonParams):
411411
op_type = "QuantizedFullyConnectedOperator"
412412

413+
@classmethod
414+
@must_return_type(Hashable)
415+
def get_constructor_parameters(cls, op_info):
416+
activation_idx = cls._ACTIVATION_STR_PATTERN.match(
417+
op_info.op_attr['FusedActivationFunction']
418+
).group(1)
419+
activation = cls._ACTIVATION_MAP[activation_idx]
420+
return (activation,)
421+
413422
def get_declare_snippet(self, op_var_name, tensor_var_map):
414423
return DeclareOpSnippet(
415424
op=self,

utensor_cgen/frontend/tflite.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,9 @@ def _build_param_ops(self, fb_model, ugraph):
290290
buffer_array = fb_model.Buffers(buffer_index).DataAsNumpy()
291291
if isinstance(buffer_array, int):
292292
continue # somehow, sometimes, the buffer contains no data, likely to be an intermediate tensor
293-
buffer_content = fb_model.Buffers(buffer_index).DataAsNumpy().view(dtype)
293+
buffer_content = fb_model.Buffers(buffer_index).DataAsNumpy().view(dtype).reshape(
294+
self.tensor_names_map[idx].shape
295+
)
294296

295297
OperationInfo(
296298
name=node_name,
@@ -343,8 +345,8 @@ def _build_intermediate_ops(self, fb_model, ugraph):
343345
op = subgraph.Operators(i)
344346
local_op_code = op.OpcodeIndex()
345347
global_op_code = fb_model.OperatorCodes(local_op_code)
346-
builtinOperator_code = global_op_code.BuiltinCode()
347-
op_type = self._BUILTIN_OPS[builtinOperator_code]
348+
builtin_op_code = global_op_code.BuiltinCode()
349+
op_type = self._BUILTIN_OPS[builtin_op_code]
348350

349351
node_name = str(i) + "_" + op_type
350352

utensor_cgen/legalizer/tflite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def apply(cls, ugraph):
5050
for op_info in ugraph.get_ops_by_type('FullyConnected'):
5151
filter_tensor = op_info.input_tensors[1]
5252
filter_op = filter_tensor.op
53-
np_arr = filter_op.op_attr['value'].value.np_array.reshape(filter_tensor.shape)
53+
np_arr = filter_op.op_attr['value'].value.np_array
5454
filter_op.op_attr['value'].value.np_array = np_arr.T
5555
filter_tensor.shape = list(np_arr.T.shape)
5656
filter_op.output_tensors[0].shape = list(np_arr.T.shape)

0 commit comments

Comments
 (0)