Skip to content

Commit 93b4e1f

Browse files
committed
Add BatchNorm, Mul, Add, Sub, Mean, and more ops
1 parent 9c8cc41 commit 93b4e1f

File tree

5 files changed

+345
-13
lines changed

5 files changed

+345
-13
lines changed

utensor_cgen/backend/utensor/_graph_lower/_op_lower.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def __init__(self, config):
4141
class OptypeRenameManager(object):
4242
NAME_MAP = {
4343
'Add': 'AddOperator',
44+
'Mul': 'MulOperator',
45+
'Sub': 'SubOperator',
4446
'Conv2D': 'ConvOperator',
4547
'MatMul': 'MatrixMultOperator'
4648
}

utensor_cgen/backend/utensor/code_generator/rearch/_operators/_impls.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
436436
tensor_var_map=tensor_var_map,
437437
)
438438

439+
@OperatorFactory.register
440+
class _ConvOperator(_CommonParams):
441+
op_type = "ConvOperator"
442+
443+
@classmethod
444+
@must_return_type(Hashable)
445+
def get_constructor_parameters(cls, op_info):
446+
447+
strides = [
448+
1,
449+
op_info.op_attr['StrideW'],
450+
op_info.op_attr['StrideH'],
451+
1,
452+
]
453+
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
454+
strides_str = ','.join(map(str, strides))
455+
return ("{{ {} }}".format(strides_str), padding)
456+
457+
def get_declare_snippet(self, op_var_name, tensor_var_map):
458+
return DeclareOpSnippet(
459+
op=self,
460+
templ_dtypes=[self.out_dtypes[0]],
461+
op_var_name=op_var_name,
462+
)
463+
464+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
465+
return ConvOpEvalSnippet(
466+
op_info=op_info,
467+
templ_dtypes=[self.out_dtypes[0]],
468+
op_name=op_var_name,
469+
tensor_var_map=tensor_var_map,
470+
)
471+
439472

440473
@OperatorFactory.register
441474
class _QuantizedFullyConnectedOperator(_CommonParams):
@@ -464,3 +497,123 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
464497
op_name=op_var_name,
465498
tensor_var_map=tensor_var_map,
466499
)
500+
501+
@OperatorFactory.register
502+
class _BatchNormOperator(_CommonParams):
503+
op_type = "BatchNormOperator"
504+
505+
@classmethod
506+
@must_return_type(Hashable)
507+
def get_constructor_parameters(cls, op_info):
508+
strides = [
509+
1,
510+
op_info.op_attr['StrideW'],
511+
op_info.op_attr['StrideH'],
512+
1,
513+
]
514+
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
515+
strides_str = ','.join(map(str, strides))
516+
return ("{{ {} }}".format(strides_str), padding)
517+
518+
def get_declare_snippet(self, op_var_name, tensor_var_map):
519+
return DeclareOpSnippet(
520+
op=self,
521+
templ_dtypes=[self.out_dtypes[0]],
522+
op_var_name=op_var_name,
523+
)
524+
525+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
526+
return BatchNormSnippet(
527+
op_info=op_info,
528+
templ_dtypes=[self.out_dtypes[0]],
529+
op_name=op_var_name,
530+
tensor_var_map=tensor_var_map,
531+
)
532+
533+
@OperatorFactory.register
534+
class _MeanOperator(_CommonParams):
535+
op_type = "MeanOperator"
536+
537+
@classmethod
538+
@must_return_type(Hashable)
539+
def get_constructor_parameters(cls, op_info):
540+
keep_dims = str(op_info.op_attr["keep_dims"])
541+
return (" {} ".format(keep_dims), )
542+
543+
def get_declare_snippet(self, op_var_name, tensor_var_map):
544+
return DeclareOpSnippet(
545+
op=self,
546+
templ_dtypes=[self.out_dtypes[0]],
547+
op_var_name=op_var_name,
548+
)
549+
550+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
551+
return BatchNormSnippet(
552+
op_info=op_info,
553+
templ_dtypes=[self.out_dtypes[0]],
554+
op_name=op_var_name,
555+
tensor_var_map=tensor_var_map,
556+
)
557+
558+
@OperatorFactory.register
559+
class _SoftmaxOperator(_CommonParams):
560+
op_type = "SoftmaxOperator"
561+
562+
@classmethod
563+
@must_return_type(Hashable)
564+
def get_constructor_parameters(cls, op_info):
565+
Beta = op_info.op_attr["Beta"]
566+
return (" %f " % Beta,)
567+
568+
def get_declare_snippet(self, op_var_name, tensor_var_map):
569+
return DeclareOpSnippet(
570+
op=self,
571+
templ_dtypes=[self.out_dtypes[0]],
572+
op_var_name=op_var_name,
573+
)
574+
575+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
576+
return BatchNormSnippet(
577+
op_info=op_info,
578+
templ_dtypes=[self.out_dtypes[0]],
579+
op_name=op_var_name,
580+
tensor_var_map=tensor_var_map,
581+
)
582+
583+
@OperatorFactory.register
584+
class _MulOperator(_Operator):
585+
op_type = 'MulOperator'
586+
587+
def get_declare_snippet(self, op_var_name, tensor_var_map):
588+
return DeclareOpSnippet(
589+
op=self,
590+
templ_dtypes=[self.in_dtypes[0]],
591+
op_var_name=op_var_name,
592+
)
593+
594+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
595+
return MulOpEvalSnippet(
596+
op_info=op_info,
597+
templ_dtypes=[self.in_dtypes[0]],
598+
op_name=op_var_name,
599+
tensor_var_map=tensor_var_map,
600+
)
601+
602+
@OperatorFactory.register
603+
class _SubOperator(_Operator):
604+
op_type = 'SubOperator'
605+
606+
def get_declare_snippet(self, op_var_name, tensor_var_map):
607+
return DeclareOpSnippet(
608+
op=self,
609+
templ_dtypes=[self.in_dtypes[0]],
610+
op_var_name=op_var_name,
611+
)
612+
613+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
614+
return SubOpEvalSnippet(
615+
op_info=op_info,
616+
templ_dtypes=[self.in_dtypes[0]],
617+
op_name=op_var_name,
618+
tensor_var_map=tensor_var_map,
619+
)

utensor_cgen/backend/utensor/snippets/rearch/_snippets.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
"MinPoolEvalSnippet",
2828
"MaxPoolEvalSnippet",
2929
"QuantizedFullyConnectedSnippet",
30+
"BatchNormSnippet",
31+
"MulOpEvalSnippet",
32+
"SubOpEvalSnippet",
33+
"ConvOpEvalSnippet",
34+
"MeanOpEvalSnippet",
35+
"SoftmaxOpEvalSnippet",
3036
"SimpleContainer",
3137
]
3238

@@ -143,6 +149,9 @@ class DepthwiseSeperateConvOpEvalSnippet(OpEvalSnippet):
143149
__inputs__ = ["in", "depthwise_filter", "pointwise_filter"]
144150
__outputs__ = ["out"]
145151

152+
class ConvOpEvalSnippet(OpEvalSnippet):
153+
__inputs__ = ["in", "filter"]
154+
__outputs__ = ["out"]
146155

147156
class QuantDepthwiseSeperateConvOpEvalSnippet(OpEvalSnippet):
148157
__inputs__ = ["in", "filter", "bias"]
@@ -218,6 +227,23 @@ class QuantizedFullyConnectedSnippet(OpEvalSnippet):
218227
__inputs__ = ["input", "filter", "bias"]
219228
__outputs__ = ["output"]
220229

230+
class BatchNormSnippet(OpEvalSnippet):
231+
__inputs__ = ["x", "mean", "variance", "offset", "scale"]
232+
__outputs__ = ["output"]
233+
234+
class MulOpEvalSnippet(OpEvalSnippet):
235+
__inputs__ = ['a', 'b']
236+
__outputs__ = ['c']
237+
class SubOpEvalSnippet(OpEvalSnippet):
238+
__inputs__ = ['a', 'b']
239+
__outputs__ = ['c']
240+
class MeanOpEvalSnippet(OpEvalSnippet):
241+
__inputs__ = ['input', 'axis']
242+
__outputs__ = ['output']
243+
class SoftmaxOpEvalSnippet(OpEvalSnippet):
244+
__inputs__ = ['input']
245+
__outputs__ = ['output']
246+
221247

222248
class SimpleContainer(SnippetBase):
223249
__headers__ = set(['"uTensor.h"', "<vector>"])

0 commit comments

Comments
 (0)