Skip to content

Commit f4ea637

Browse files
committed
Add cli command: list-support-ops
1 parent 4883654 commit f4ea637

File tree

6 files changed

+75
-3
lines changed

6 files changed

+75
-3
lines changed

utensor_cgen/backend/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,7 @@ def from_config(cls, config, *args, **kwargs):
118118
def from_file(cls, file_or_path, *args, **kwargs):
119119
config = parse_toml(file_or_path)[cls.TARGET][cls.COMPONENT][cls.PART]
120120
return cls(config, *args, **kwargs)
121+
122+
@property
123+
def support_ops(self):
124+
return []

utensor_cgen/backend/utensor/_backend_impl.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
graph_op_lower = graph_op_lower or uTensorRearchGraphLower(config=config[uTensorRearchGraphLower.PART].to_dict())
3838
graph_alloc_lower = TensorAllocationPlanner(config=config[TensorAllocationPlanner.PART].to_dict())
3939
graph_transformer = graph_transformer or PipelineTransformer(config=config[PipelineTransformer.PART].to_dict())
40+
self._legacy_api = config['legacy-api']
4041
self._graph_op_lower = graph_op_lower
4142
self._graph_transformer = graph_transformer
4243
self._graph_alloc_lower = graph_alloc_lower
@@ -73,3 +74,11 @@ def __call__(self, ugraph):
7374
def from_file(cls, path_or_file):
7475
config = parse_toml(path_or_file)
7576
return cls(config=config)
77+
78+
@property
79+
def support_ops(self):
80+
if self._legacy_api:
81+
from .code_generator.legacy._operators import OperatorFactory
82+
else:
83+
from .code_generator.rearch._operators import OperatorFactory
84+
return OperatorFactory.support_op_types()

utensor_cgen/backend/utensor/code_generator/rearch/_operators/_base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def get_opertor(cls, op_info):
2323
op_cls = cls._operators.get((namespaces, op_type))
2424
if op_cls is None:
2525
raise OpNotSupportedError(
26-
"{}::{} not supported in utensor_cgen".format("::".join(namespaces), op_type)
26+
"{} not supported in utensor_cgen".format("::".join(list(namespaces) + [op_type]))
2727
)
2828
return op_cls(op_info)
2929

@@ -37,8 +37,11 @@ def register(cls, op_cls):
3737
@classmethod
3838
def support_op_types(cls):
3939
"""Return the set of all supported ops
40-
"""
41-
return set(cls._operators.keys())
40+
"""
41+
return set([
42+
"::".join(list(namespaces) + [op_type])
43+
for namespaces, op_type in cls._operators.keys()
44+
])
4245

4346
@classmethod
4447
def is_supported(cls, op_type):

utensor_cgen/backend/utensor/code_generator/rearch/_operators/_impls.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,37 @@ class _CommonParams(_Operator):
329329
}
330330
_ACTIVATION_STR_PATTERN = re.compile(r'^(\d+) \(\w+\)$')
331331

332+
@OperatorFactory.register
333+
class _Conv2dOperator(_CommonParams):
334+
op_type = 'Conv2dOperator'
335+
336+
@classmethod
337+
@must_return_type(Hashable)
338+
def get_constructor_parameters(cls, op_info):
339+
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
340+
stride_width = op_info.op_attr['StrideW']
341+
stride_hight = op_info.op_attr['StrideH']
342+
return (
343+
_c_arr_str([1, stride_hight, stride_width, 1]),
344+
padding,
345+
)
346+
347+
def get_declare_snippet(self, op_var_name, tensor_var_map):
348+
return DeclareOpSnippet(
349+
op=self,
350+
templ_dtypes=[self.out_dtypes[0]],
351+
op_var_name=op_var_name,
352+
nested_namespaces=self.namespaces,
353+
)
354+
355+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
356+
return Conv2dOpEvalSnippet(
357+
op_info=op_info,
358+
templ_dtypes=[self.out_dtypes[0]],
359+
op_name=op_var_name,
360+
tensor_var_map=tensor_var_map,
361+
nested_namespaces=self.namespaces,
362+
)
332363

333364
@OperatorFactory.register
334365
class _QuantDWSConvOperator(_CommonParams):

utensor_cgen/backend/utensor/snippets/rearch/_snippets.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"DeclareRomTensorSnippet",
1111
"DeclareRamTensorSnippet",
1212
"DeclareOpSnippet",
13+
"Conv2dOpEvalSnippet",
1314
"DepthwiseSeperateConvOpEvalSnippet",
1415
"QuantDepthwiseSeperateConvOpEvalSnippet",
1516
"AddOpEvalSnippet",
@@ -133,6 +134,11 @@ def __init__(self, op_info, templ_dtypes, op_name, tensor_var_map, nested_namesp
133134
self.template_vars['output_map'] = output_map
134135

135136

137+
class Conv2dOpEvalSnippet(OpEvalSnippet):
138+
__inputs__ = ["in", "filter"]
139+
__outputs__ = ["out"]
140+
141+
136142
class DepthwiseSeperateConvOpEvalSnippet(OpEvalSnippet):
137143
__inputs__ = ["in", "depthwise_filter", "pointwise_filter"]
138144
__outputs__ = ["out"]

utensor_cgen/cli/backend.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
from pprint import pformat
3+
14
import click
25

36
from utensor_cgen import __version__
@@ -38,6 +41,22 @@ def list_trans_methods(verbose):
3841
)
3942
return 0
4043

44+
@cli.command(name='list-support-ops', help='list all supported op in the backend')
45+
@click.help_option('-h', '--help')
46+
@click.option('--target', default='utensor')
47+
@click.option('--config', default='utensor_cli.toml')
48+
def list_support_ops(target, config):
49+
from utensor_cgen.backend.api import BackendManager
50+
if os.path.exists(config):
51+
backend = BackendManager.get_backend(target).from_file(config)
52+
else:
53+
backend = BackendManager.get_backend(target)({})
54+
click.secho(
55+
pformat(backend.support_ops),
56+
fg='white',
57+
bold=True
58+
)
59+
4160
@cli.command(name='generate-config', help='generate config toml file')
4261
@click.help_option('-h', '--help')
4362
@click.option('--target', required=True, help='target framework/platform')

0 commit comments

Comments
 (0)