@@ -511,6 +511,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
511511 tensor_var_map = tensor_var_map ,
512512 )
513513
514+ @OperatorFactory .register
515+ class _ConvOperator (_CommonParams ):
516+ op_type = "ConvOperator"
517+
518+ @classmethod
519+ @must_return_type (Hashable )
520+ def get_constructor_parameters (cls , op_info ):
521+
522+ strides = [
523+ 1 ,
524+ op_info .op_attr ['StrideW' ],
525+ op_info .op_attr ['StrideH' ],
526+ 1 ,
527+ ]
528+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
529+ strides_str = ',' .join (map (str , strides ))
530+ return ("{{ {} }}" .format (strides_str ), padding )
531+
532+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
533+ return DeclareOpSnippet (
534+ op = self ,
535+ templ_dtypes = [self .out_dtypes [0 ]],
536+ op_var_name = op_var_name ,
537+ )
538+
539+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
540+ return ConvOpEvalSnippet (
541+ op_info = op_info ,
542+ templ_dtypes = [self .out_dtypes [0 ]],
543+ op_name = op_var_name ,
544+ tensor_var_map = tensor_var_map ,
545+ )
546+
514547
515548@OperatorFactory .register
516549class _QuantizedFullyConnectedOperator (_CommonParams ):
@@ -689,3 +722,123 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
689722 tensor_var_map = tensor_var_map ,
690723 nested_namespaces = type (self ).namespaces ,
691724 )
725+
726+ @OperatorFactory .register
727+ class _BatchNormOperator (_CommonParams ):
728+ op_type = "BatchNormOperator"
729+
730+ @classmethod
731+ @must_return_type (Hashable )
732+ def get_constructor_parameters (cls , op_info ):
733+ strides = [
734+ 1 ,
735+ op_info .op_attr ['StrideW' ],
736+ op_info .op_attr ['StrideH' ],
737+ 1 ,
738+ ]
739+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
740+ strides_str = ',' .join (map (str , strides ))
741+ return ("{{ {} }}" .format (strides_str ), padding )
742+
743+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
744+ return DeclareOpSnippet (
745+ op = self ,
746+ templ_dtypes = [self .out_dtypes [0 ]],
747+ op_var_name = op_var_name ,
748+ )
749+
750+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
751+ return BatchNormSnippet (
752+ op_info = op_info ,
753+ templ_dtypes = [self .out_dtypes [0 ]],
754+ op_name = op_var_name ,
755+ tensor_var_map = tensor_var_map ,
756+ )
757+
758+ @OperatorFactory .register
759+ class _MeanOperator (_CommonParams ):
760+ op_type = "MeanOperator"
761+
762+ @classmethod
763+ @must_return_type (Hashable )
764+ def get_constructor_parameters (cls , op_info ):
765+ keep_dims = str (op_info .op_attr ["keep_dims" ])
766+ return (" {} " .format (keep_dims ), )
767+
768+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
769+ return DeclareOpSnippet (
770+ op = self ,
771+ templ_dtypes = [self .out_dtypes [0 ]],
772+ op_var_name = op_var_name ,
773+ )
774+
775+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
776+ return BatchNormSnippet (
777+ op_info = op_info ,
778+ templ_dtypes = [self .out_dtypes [0 ]],
779+ op_name = op_var_name ,
780+ tensor_var_map = tensor_var_map ,
781+ )
782+
783+ @OperatorFactory .register
784+ class _SoftmaxOperator (_CommonParams ):
785+ op_type = "SoftmaxOperator"
786+
787+ @classmethod
788+ @must_return_type (Hashable )
789+ def get_constructor_parameters (cls , op_info ):
790+ Beta = op_info .op_attr ["Beta" ]
791+ return (" %f " % Beta ,)
792+
793+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
794+ return DeclareOpSnippet (
795+ op = self ,
796+ templ_dtypes = [self .out_dtypes [0 ]],
797+ op_var_name = op_var_name ,
798+ )
799+
800+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
801+ return BatchNormSnippet (
802+ op_info = op_info ,
803+ templ_dtypes = [self .out_dtypes [0 ]],
804+ op_name = op_var_name ,
805+ tensor_var_map = tensor_var_map ,
806+ )
807+
808+ @OperatorFactory .register
809+ class _MulOperator (_Operator ):
810+ op_type = 'MulOperator'
811+
812+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
813+ return DeclareOpSnippet (
814+ op = self ,
815+ templ_dtypes = [self .in_dtypes [0 ]],
816+ op_var_name = op_var_name ,
817+ )
818+
819+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
820+ return MulOpEvalSnippet (
821+ op_info = op_info ,
822+ templ_dtypes = [self .in_dtypes [0 ]],
823+ op_name = op_var_name ,
824+ tensor_var_map = tensor_var_map ,
825+ )
826+
827+ @OperatorFactory .register
828+ class _SubOperator (_Operator ):
829+ op_type = 'SubOperator'
830+
831+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
832+ return DeclareOpSnippet (
833+ op = self ,
834+ templ_dtypes = [self .in_dtypes [0 ]],
835+ op_var_name = op_var_name ,
836+ )
837+
838+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
839+ return SubOpEvalSnippet (
840+ op_info = op_info ,
841+ templ_dtypes = [self .in_dtypes [0 ]],
842+ op_name = op_var_name ,
843+ tensor_var_map = tensor_var_map ,
844+ )
0 commit comments