Skip to content

Commit b3d2acd

Browse files
committed
adding type information to missing ops
1 parent 8f55c17 commit b3d2acd

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

utensor_cgen/backend/utensor/snippets/rearch/_snippets.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,17 +237,17 @@ class MissingOpEvalSnippet(OpEvalSnippet):
237237

238238
def __init__(self, op_info, tensor_var_map):
239239
Snippet.__init__(self)
240-
input_var_names = [tensor_var_map[tensor.name] for tensor in op_info.input_tensors]
241-
out_tensor_names = [tensor.name for tensor in op_info.output_tensors]
242-
out_var_names = [tensor_var_map[tensor.name] for tensor in op_info.output_tensors]
240+
243241
quant_params_map = {}
244242
for out_tensor in op_info.output_tensors:
245243
quant_params = self.get_quant_param(out_tensor)
246244
quant_params_map[out_tensor.name] = quant_params
247245
self.template_vars['op_type'] = op_info.op_type
248-
self.template_vars['input_var_names'] = input_var_names
249-
self.template_vars['out_var_names'] = out_var_names
250-
self.template_vars['out_tensor_names'] = out_tensor_names
246+
self.template_vars['input_tensors'] = op_info.input_tensors[:]
247+
self.template_vars['out_var_names'] = [
248+
tensor_var_map[tensor.name] for tensor in op_info.output_tensors
249+
]
250+
self.template_vars['output_tensors'] = op_info.output_tensors[:]
251251
self.template_vars['quant_params_map'] = quant_params_map
252252

253253

utensor_cgen/backend/utensor/snippets/templates/snippets/rearch/op_missing.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,18 @@
22
FIXME: {{op_type}} currently not supported, you have to fill up this section or it won't compile
33
44
Input Tensors:
5-
{%for name in input_var_names%}
6-
{{name}}
5+
{%for tensor in input_tensors%}
6+
{{tensor.name}}, of type {{tensor.dtype}}
77
{%endfor%}
88
99
Output Tensors:
10-
{%for name, var_name in zip(out_tensor_names, out_var_names)%}
11-
{{name}} should be named as {{var_name}}
12-
{%if quant_params_map[name]%}
10+
{%for tensor, var_name in zip(output_tensors, out_var_names)%}
11+
{{tensor.name}} is of type {{tensor.dtype}} and should be named as {{var_name}}
12+
{%if quant_params_map[tensor.name]%}
1313
quantization parameters:
14-
- zero point: {{quant_params_map[name]["zero_point"]["value"]}}, {{quant_params_map[name]["zero_point"]["type_str"]}}
15-
- scale: {{quant_params_map[name]["scale"]["value"]}}, {{quant_params_map[name]["scale"]["type_str"]}}
16-
- is per tensor quantization: {{quant_params_map[name]["is_per_tensor"]}}
14+
- zero point: {{quant_params_map[tensor.name]["zero_point"]["value"]}}, {{quant_params_map[tensor.name]["zero_point"]["type_str"]}}
15+
- scale: {{quant_params_map[tensor.name]["scale"]["value"]}}, {{quant_params_map[tensor.name]["scale"]["type_str"]}}
16+
- is per tensor quantization: {{quant_params_map[tensor.name]["is_per_tensor"]}}
1717
{%endif%}
1818
{%endfor%}
1919
*/

0 commit comments

Comments
 (0)