@@ -478,6 +478,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
478478 nested_namespaces = type (self ).namespaces ,
479479 )
480480
481+ @OperatorFactory .register
482+ class _ConvOperator (_CommonParams ):
483+ op_type = "ConvOperator"
484+
485+ @classmethod
486+ @must_return_type (Hashable )
487+ def get_constructor_parameters (cls , op_info ):
488+
489+ strides = [
490+ 1 ,
491+ op_info .op_attr ['StrideW' ],
492+ op_info .op_attr ['StrideH' ],
493+ 1 ,
494+ ]
495+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
496+ strides_str = ',' .join (map (str , strides ))
497+ return ("{{ {} }}" .format (strides_str ), padding )
498+
499+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
500+ return DeclareOpSnippet (
501+ op = self ,
502+ templ_dtypes = [self .out_dtypes [0 ]],
503+ op_var_name = op_var_name ,
504+ )
505+
506+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
507+ return ConvOpEvalSnippet (
508+ op_info = op_info ,
509+ templ_dtypes = [self .out_dtypes [0 ]],
510+ op_name = op_var_name ,
511+ tensor_var_map = tensor_var_map ,
512+ )
513+
481514
482515@OperatorFactory .register
483516class _QuantizedFullyConnectedOperator (_CommonParams ):
@@ -521,3 +554,138 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
521554 return MissingOpEvalSnippet (op_info , tensor_var_map )
522555
523556OperatorFactory ._operators [_MissingOperator .op_type ] = _MissingOperator
557+
558+ @OperatorFactory .register
559+ class _BatchNormOperator (_Operator ):
560+ namespaces = ('ReferenceOperators' ,)
561+ op_type = "BatchNormOperator"
562+
563+ @classmethod
564+ @must_return_type (Hashable )
565+ def get_constructor_parameters (cls , op_info ):
566+ strides = [
567+ 1 ,
568+ op_info .op_attr ['StrideW' ],
569+ op_info .op_attr ['StrideH' ],
570+ 1 ,
571+ ]
572+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
573+ strides_str = ',' .join (map (str , strides ))
574+ return ("{{ {} }}" .format (strides_str ), padding )
575+
576+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
577+ return DeclareOpSnippet (
578+ op = self ,
579+ templ_dtypes = [self .out_dtypes [0 ]],
580+ op_var_name = op_var_name ,
581+ nested_namespaces = type (self ).namespaces ,
582+ )
583+
584+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
585+ return BatchNormSnippet (
586+ op_info = op_info ,
587+ templ_dtypes = [self .out_dtypes [0 ]],
588+ op_name = op_var_name ,
589+ tensor_var_map = tensor_var_map ,
590+ nested_namespaces = type (self ).namespaces ,
591+ )
592+
593+ @OperatorFactory .register
594+ class _MeanOperator (_Operator ):
595+ namespaces = ('ReferenceOperators' ,)
596+ op_type = "MeanOperator"
597+
598+ @classmethod
599+ @must_return_type (Hashable )
600+ def get_constructor_parameters (cls , op_info ):
601+ keep_dims = str (op_info .op_attr ["keep_dims" ])
602+ return (" {} " .format (keep_dims ), )
603+
604+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
605+ return DeclareOpSnippet (
606+ op = self ,
607+ templ_dtypes = [self .out_dtypes [0 ]],
608+ op_var_name = op_var_name ,
609+ nested_namespaces = type (self ).namespaces ,
610+ )
611+
612+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
613+ return BatchNormSnippet (
614+ op_info = op_info ,
615+ templ_dtypes = [self .out_dtypes [0 ]],
616+ op_name = op_var_name ,
617+ tensor_var_map = tensor_var_map ,
618+ nested_namespaces = type (self ).namespaces ,
619+ )
620+
621+ @OperatorFactory .register
622+ class _SoftmaxOperator (_CommonParams ):
623+ namespaces = ('ReferenceOperators' ,)
624+ op_type = "SoftmaxOperator"
625+
626+ @classmethod
627+ @must_return_type (Hashable )
628+ def get_constructor_parameters (cls , op_info ):
629+ Beta = op_info .op_attr ["Beta" ]
630+ return (" %f " % Beta ,)
631+
632+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
633+ return DeclareOpSnippet (
634+ op = self ,
635+ templ_dtypes = [self .out_dtypes [0 ]],
636+ op_var_name = op_var_name ,
637+ nested_namespaces = type (self ).namespaces ,
638+ )
639+
640+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
641+ return BatchNormSnippet (
642+ op_info = op_info ,
643+ templ_dtypes = [self .out_dtypes [0 ]],
644+ op_name = op_var_name ,
645+ tensor_var_map = tensor_var_map ,
646+ nested_namespaces = type (self ).namespaces ,
647+ )
648+
649+ @OperatorFactory .register
650+ class _MulOperator (_Operator ):
651+ namespaces = ('ReferenceOperators' ,)
652+ op_type = 'MulOperator'
653+
654+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
655+ return DeclareOpSnippet (
656+ op = self ,
657+ templ_dtypes = [self .in_dtypes [0 ]],
658+ op_var_name = op_var_name ,
659+ nested_namespaces = type (self ).namespaces ,
660+ )
661+
662+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
663+ return MulOpEvalSnippet (
664+ op_info = op_info ,
665+ templ_dtypes = [self .in_dtypes [0 ]],
666+ op_name = op_var_name ,
667+ tensor_var_map = tensor_var_map ,
668+ nested_namespaces = type (self ).namespaces ,
669+ )
670+
671+ @OperatorFactory .register
672+ class _SubOperator (_Operator ):
673+ namespaces = ('ReferenceOperators' ,)
674+ op_type = 'SubOperator'
675+
676+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
677+ return DeclareOpSnippet (
678+ op = self ,
679+ templ_dtypes = [self .in_dtypes [0 ]],
680+ op_var_name = op_var_name ,
681+ nested_namespaces = type (self ).namespaces ,
682+ )
683+
684+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
685+ return SubOpEvalSnippet (
686+ op_info = op_info ,
687+ templ_dtypes = [self .in_dtypes [0 ]],
688+ op_name = op_var_name ,
689+ tensor_var_map = tensor_var_map ,
690+ nested_namespaces = type (self ).namespaces ,
691+ )
0 commit comments