Skip to content

Commit 9a04967

Browse files
committed
Add BatchNorm, Mul, Add, Sub, Mean, and more ops
1 parent 2ac0ee4 commit 9a04967

File tree

5 files changed

+360
-13
lines changed

5 files changed

+360
-13
lines changed

utensor_cgen/backend/utensor/_graph_lower/_op_lower.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class uTensorRearchGraphLower(uTensorGraphLowerBase):
4040
class OptypeRenameManager(object):
4141
NAME_MAP = {
4242
'Add': 'AddOperator',
43+
'Mul': 'MulOperator',
44+
'Sub': 'SubOperator',
4345
'Conv2D': 'ConvOperator',
4446
'MatMul': 'MatrixMultOperator'
4547
}

utensor_cgen/backend/utensor/code_generator/rearch/_operators/_impls.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
478478
nested_namespaces=type(self).namespaces,
479479
)
480480

481+
@OperatorFactory.register
482+
class _ConvOperator(_CommonParams):
483+
op_type = "ConvOperator"
484+
485+
@classmethod
486+
@must_return_type(Hashable)
487+
def get_constructor_parameters(cls, op_info):
488+
489+
strides = [
490+
1,
491+
op_info.op_attr['StrideW'],
492+
op_info.op_attr['StrideH'],
493+
1,
494+
]
495+
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
496+
strides_str = ','.join(map(str, strides))
497+
return ("{{ {} }}".format(strides_str), padding)
498+
499+
def get_declare_snippet(self, op_var_name, tensor_var_map):
500+
return DeclareOpSnippet(
501+
op=self,
502+
templ_dtypes=[self.out_dtypes[0]],
503+
op_var_name=op_var_name,
504+
)
505+
506+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
507+
return ConvOpEvalSnippet(
508+
op_info=op_info,
509+
templ_dtypes=[self.out_dtypes[0]],
510+
op_name=op_var_name,
511+
tensor_var_map=tensor_var_map,
512+
)
513+
481514

482515
@OperatorFactory.register
483516
class _QuantizedFullyConnectedOperator(_CommonParams):
@@ -521,3 +554,138 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
521554
return MissingOpEvalSnippet(op_info, tensor_var_map)
522555

523556
OperatorFactory._operators[_MissingOperator.op_type] = _MissingOperator
557+
558+
@OperatorFactory.register
559+
class _BatchNormOperator(_Operator):
560+
namespaces = ('ReferenceOperators',)
561+
op_type = "BatchNormOperator"
562+
563+
@classmethod
564+
@must_return_type(Hashable)
565+
def get_constructor_parameters(cls, op_info):
566+
strides = [
567+
1,
568+
op_info.op_attr['StrideW'],
569+
op_info.op_attr['StrideH'],
570+
1,
571+
]
572+
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
573+
strides_str = ','.join(map(str, strides))
574+
return ("{{ {} }}".format(strides_str), padding)
575+
576+
def get_declare_snippet(self, op_var_name, tensor_var_map):
577+
return DeclareOpSnippet(
578+
op=self,
579+
templ_dtypes=[self.out_dtypes[0]],
580+
op_var_name=op_var_name,
581+
nested_namespaces=type(self).namespaces,
582+
)
583+
584+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
585+
return BatchNormSnippet(
586+
op_info=op_info,
587+
templ_dtypes=[self.out_dtypes[0]],
588+
op_name=op_var_name,
589+
tensor_var_map=tensor_var_map,
590+
nested_namespaces=type(self).namespaces,
591+
)
592+
593+
@OperatorFactory.register
594+
class _MeanOperator(_Operator):
595+
namespaces = ('ReferenceOperators',)
596+
op_type = "MeanOperator"
597+
598+
@classmethod
599+
@must_return_type(Hashable)
600+
def get_constructor_parameters(cls, op_info):
601+
keep_dims = str(op_info.op_attr["keep_dims"])
602+
return (" {} ".format(keep_dims), )
603+
604+
def get_declare_snippet(self, op_var_name, tensor_var_map):
605+
return DeclareOpSnippet(
606+
op=self,
607+
templ_dtypes=[self.out_dtypes[0]],
608+
op_var_name=op_var_name,
609+
nested_namespaces=type(self).namespaces,
610+
)
611+
612+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
613+
return BatchNormSnippet(
614+
op_info=op_info,
615+
templ_dtypes=[self.out_dtypes[0]],
616+
op_name=op_var_name,
617+
tensor_var_map=tensor_var_map,
618+
nested_namespaces=type(self).namespaces,
619+
)
620+
621+
@OperatorFactory.register
622+
class _SoftmaxOperator(_CommonParams):
623+
namespaces = ('ReferenceOperators',)
624+
op_type = "SoftmaxOperator"
625+
626+
@classmethod
627+
@must_return_type(Hashable)
628+
def get_constructor_parameters(cls, op_info):
629+
Beta = op_info.op_attr["Beta"]
630+
return (" %f " % Beta,)
631+
632+
def get_declare_snippet(self, op_var_name, tensor_var_map):
633+
return DeclareOpSnippet(
634+
op=self,
635+
templ_dtypes=[self.out_dtypes[0]],
636+
op_var_name=op_var_name,
637+
nested_namespaces=type(self).namespaces,
638+
)
639+
640+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
641+
return BatchNormSnippet(
642+
op_info=op_info,
643+
templ_dtypes=[self.out_dtypes[0]],
644+
op_name=op_var_name,
645+
tensor_var_map=tensor_var_map,
646+
nested_namespaces=type(self).namespaces,
647+
)
648+
649+
@OperatorFactory.register
650+
class _MulOperator(_Operator):
651+
namespaces = ('ReferenceOperators',)
652+
op_type = 'MulOperator'
653+
654+
def get_declare_snippet(self, op_var_name, tensor_var_map):
655+
return DeclareOpSnippet(
656+
op=self,
657+
templ_dtypes=[self.in_dtypes[0]],
658+
op_var_name=op_var_name,
659+
nested_namespaces=type(self).namespaces,
660+
)
661+
662+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
663+
return MulOpEvalSnippet(
664+
op_info=op_info,
665+
templ_dtypes=[self.in_dtypes[0]],
666+
op_name=op_var_name,
667+
tensor_var_map=tensor_var_map,
668+
nested_namespaces=type(self).namespaces,
669+
)
670+
671+
@OperatorFactory.register
672+
class _SubOperator(_Operator):
673+
namespaces = ('ReferenceOperators',)
674+
op_type = 'SubOperator'
675+
676+
def get_declare_snippet(self, op_var_name, tensor_var_map):
677+
return DeclareOpSnippet(
678+
op=self,
679+
templ_dtypes=[self.in_dtypes[0]],
680+
op_var_name=op_var_name,
681+
nested_namespaces=type(self).namespaces,
682+
)
683+
684+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
685+
return SubOpEvalSnippet(
686+
op_info=op_info,
687+
templ_dtypes=[self.in_dtypes[0]],
688+
op_name=op_var_name,
689+
tensor_var_map=tensor_var_map,
690+
nested_namespaces=type(self).namespaces,
691+
)

utensor_cgen/backend/utensor/snippets/rearch/_snippets.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030
"QuantizedFullyConnectedSnippet",
3131
"MissingOpEvalSnippet",
3232
"TimeSlotContainer",
33+
"BatchNormSnippet",
34+
"MulOpEvalSnippet",
35+
"SubOpEvalSnippet",
36+
"ConvOpEvalSnippet",
37+
"MeanOpEvalSnippet",
38+
"SoftmaxOpEvalSnippet",
3339
"SimpleContainer",
3440
]
3541

@@ -156,6 +162,9 @@ class DepthwiseSeperateConvOpEvalSnippet(OpEvalSnippet):
156162
__inputs__ = ["in", "depthwise_filter", "pointwise_filter"]
157163
__outputs__ = ["out"]
158164

165+
class ConvOpEvalSnippet(OpEvalSnippet):
166+
__inputs__ = ["in", "filter"]
167+
__outputs__ = ["out"]
159168

160169
class QuantDepthwiseSeperateConvOpEvalSnippet(OpEvalSnippet):
161170
__inputs__ = ["in", "filter", "bias"]
@@ -231,6 +240,23 @@ class QuantizedFullyConnectedSnippet(OpEvalSnippet):
231240
__inputs__ = ["input", "filter", "bias"]
232241
__outputs__ = ["output"]
233242

243+
class BatchNormSnippet(OpEvalSnippet):
244+
__inputs__ = ["x", "mean", "variance", "offset", "scale"]
245+
__outputs__ = ["output"]
246+
247+
class MulOpEvalSnippet(OpEvalSnippet):
248+
__inputs__ = ['a', 'b']
249+
__outputs__ = ['c']
250+
class SubOpEvalSnippet(OpEvalSnippet):
251+
__inputs__ = ['a', 'b']
252+
__outputs__ = ['c']
253+
class MeanOpEvalSnippet(OpEvalSnippet):
254+
__inputs__ = ['input', 'axis']
255+
__outputs__ = ['output']
256+
class SoftmaxOpEvalSnippet(OpEvalSnippet):
257+
__inputs__ = ['input']
258+
__outputs__ = ['output']
259+
234260

235261
class MissingOpEvalSnippet(OpEvalSnippet):
236262
__template_name__ = "snippets/rearch/op_missing.cpp"

0 commit comments

Comments
 (0)