Skip to content

Commit b6975f9

Browse files
committed
Add BatchNorm, Mul, Add, Sub, Mean, and more ops
1 parent 373d1a7 commit b6975f9

File tree

5 files changed

+395
-10
lines changed

5 files changed

+395
-10
lines changed

utensor_cgen/backend/utensor/_graph_lower/_op_lower.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def __init__(self, config):
4141
class OptypeRenameManager(object):
4242
NAME_MAP = {
4343
'Add': 'AddOperator',
44+
'Mul': 'MulOperator',
45+
'Sub': 'SubOperator',
4446
'Conv2D': 'ConvOperator',
4547
'MatMul': 'MatrixMultOperator'
4648
}

utensor_cgen/backend/utensor/code_generator/rearch/_operators/_impls.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
405405
tensor_var_map=tensor_var_map,
406406
)
407407

408+
@OperatorFactory.register
409+
class _ConvOperator(_CommonParams):
410+
op_type = "ConvOperator"
411+
412+
@classmethod
413+
@must_return_type(Hashable)
414+
def get_constructor_parameters(cls, op_info):
415+
416+
strides = [
417+
1,
418+
op_info.op_attr['StrideW'],
419+
op_info.op_attr['StrideH'],
420+
1,
421+
]
422+
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
423+
strides_str = ','.join(map(str, strides))
424+
return ("{{ {} }}".format(strides_str), padding)
425+
426+
def get_declare_snippet(self, op_var_name, tensor_var_map):
427+
return DeclareOpSnippet(
428+
op=self,
429+
templ_dtypes=[self.out_dtypes[0]],
430+
op_var_name=op_var_name,
431+
)
432+
433+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
434+
return ConvOpEvalSnippet(
435+
op_info=op_info,
436+
templ_dtypes=[self.out_dtypes[0]],
437+
op_name=op_var_name,
438+
tensor_var_map=tensor_var_map,
439+
)
440+
408441

409442
@OperatorFactory.register
410443
class _QuantizedFullyConnectedOperator(_CommonParams):
@@ -433,3 +466,123 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
433466
op_name=op_var_name,
434467
tensor_var_map=tensor_var_map,
435468
)
469+
470+
@OperatorFactory.register
471+
class _BatchNormOperator(_CommonParams):
472+
op_type = "BatchNormOperator"
473+
474+
@classmethod
475+
@must_return_type(Hashable)
476+
def get_constructor_parameters(cls, op_info):
477+
strides = [
478+
1,
479+
op_info.op_attr['StrideW'],
480+
op_info.op_attr['StrideH'],
481+
1,
482+
]
483+
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
484+
strides_str = ','.join(map(str, strides))
485+
return ("{{ {} }}".format(strides_str), padding)
486+
487+
def get_declare_snippet(self, op_var_name, tensor_var_map):
488+
return DeclareOpSnippet(
489+
op=self,
490+
templ_dtypes=[self.out_dtypes[0]],
491+
op_var_name=op_var_name,
492+
)
493+
494+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
495+
return BatchNormSnippet(
496+
op_info=op_info,
497+
templ_dtypes=[self.out_dtypes[0]],
498+
op_name=op_var_name,
499+
tensor_var_map=tensor_var_map,
500+
)
501+
502+
@OperatorFactory.register
503+
class _MeanOperator(_CommonParams):
504+
op_type = "MeanOperator"
505+
506+
@classmethod
507+
@must_return_type(Hashable)
508+
def get_constructor_parameters(cls, op_info):
509+
keep_dims = str(op_info.op_attr["keep_dims"])
510+
return (" {} ".format(keep_dims), )
511+
512+
def get_declare_snippet(self, op_var_name, tensor_var_map):
513+
return DeclareOpSnippet(
514+
op=self,
515+
templ_dtypes=[self.out_dtypes[0]],
516+
op_var_name=op_var_name,
517+
)
518+
519+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
520+
return BatchNormSnippet(
521+
op_info=op_info,
522+
templ_dtypes=[self.out_dtypes[0]],
523+
op_name=op_var_name,
524+
tensor_var_map=tensor_var_map,
525+
)
526+
527+
@OperatorFactory.register
528+
class _SoftmaxOperator(_CommonParams):
529+
op_type = "SoftmaxOperator"
530+
531+
@classmethod
532+
@must_return_type(Hashable)
533+
def get_constructor_parameters(cls, op_info):
534+
Beta = op_info.op_attr["Beta"]
535+
return (" %f " % Beta,)
536+
537+
def get_declare_snippet(self, op_var_name, tensor_var_map):
538+
return DeclareOpSnippet(
539+
op=self,
540+
templ_dtypes=[self.out_dtypes[0]],
541+
op_var_name=op_var_name,
542+
)
543+
544+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
545+
return BatchNormSnippet(
546+
op_info=op_info,
547+
templ_dtypes=[self.out_dtypes[0]],
548+
op_name=op_var_name,
549+
tensor_var_map=tensor_var_map,
550+
)
551+
552+
@OperatorFactory.register
553+
class _MulOperator(_Operator):
554+
op_type = 'MulOperator'
555+
556+
def get_declare_snippet(self, op_var_name, tensor_var_map):
557+
return DeclareOpSnippet(
558+
op=self,
559+
templ_dtypes=[self.in_dtypes[0]],
560+
op_var_name=op_var_name,
561+
)
562+
563+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
564+
return MulOpEvalSnippet(
565+
op_info=op_info,
566+
templ_dtypes=[self.in_dtypes[0]],
567+
op_name=op_var_name,
568+
tensor_var_map=tensor_var_map,
569+
)
570+
571+
@OperatorFactory.register
572+
class _SubOperator(_Operator):
573+
op_type = 'SubOperator'
574+
575+
def get_declare_snippet(self, op_var_name, tensor_var_map):
576+
return DeclareOpSnippet(
577+
op=self,
578+
templ_dtypes=[self.in_dtypes[0]],
579+
op_var_name=op_var_name,
580+
)
581+
582+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
583+
return SubOpEvalSnippet(
584+
op_info=op_info,
585+
templ_dtypes=[self.in_dtypes[0]],
586+
op_name=op_var_name,
587+
tensor_var_map=tensor_var_map,
588+
)

utensor_cgen/backend/utensor/snippets/rearch/_snippets.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@
2626
"MinPoolEvalSnippet",
2727
"MaxPoolEvalSnippet",
2828
"QuantizedFullyConnectedSnippet",
29+
"BatchNormSnippet",
30+
"MulOpEvalSnippet",
31+
"SubOpEvalSnippet",
32+
"ConvOpEvalSnippet",
33+
"MeanOpEvalSnippet",
34+
"SoftmaxOpEvalSnippet",
2935
"SimpleContainer",
3036
]
3137

@@ -135,6 +141,9 @@ class DepthwiseSeperateConvOpEvalSnippet(OpEvalSnippet):
135141
__inputs__ = ["in", "depthwise_filter", "pointwise_filter"]
136142
__outputs__ = ["out"]
137143

144+
class ConvOpEvalSnippet(OpEvalSnippet):
145+
__inputs__ = ["in", "filter"]
146+
__outputs__ = ["out"]
138147

139148
class QuantDepthwiseSeperateConvOpEvalSnippet(OpEvalSnippet):
140149
__inputs__ = ["in", "filter", "bias"]
@@ -210,6 +219,23 @@ class QuantizedFullyConnectedSnippet(OpEvalSnippet):
210219
__inputs__ = ["input", "filter", "bias"]
211220
__outputs__ = ["output"]
212221

222+
class BatchNormSnippet(OpEvalSnippet):
223+
__inputs__ = ["x", "mean", "variance", "offset", "scale"]
224+
__outputs__ = ["output"]
225+
226+
class MulOpEvalSnippet(OpEvalSnippet):
227+
__inputs__ = ['a', 'b']
228+
__outputs__ = ['c']
229+
class SubOpEvalSnippet(OpEvalSnippet):
230+
__inputs__ = ['a', 'b']
231+
__outputs__ = ['c']
232+
class MeanOpEvalSnippet(OpEvalSnippet):
233+
__inputs__ = ['input', 'axis']
234+
__outputs__ = ['output']
235+
class SoftmaxOpEvalSnippet(OpEvalSnippet):
236+
__inputs__ = ['input']
237+
__outputs__ = ['output']
238+
213239

214240
class SimpleContainer(SnippetBase):
215241
__headers__ = set(['"uTensor.h"', "<vector>"])

0 commit comments

Comments
 (0)