@@ -544,6 +544,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
544544 tensor_var_map = tensor_var_map ,
545545 )
546546
547+ @OperatorFactory .register
548+ class _ConvOperator (_CommonParams ):
549+ op_type = "ConvOperator"
550+
551+ @classmethod
552+ @must_return_type (Hashable )
553+ def get_constructor_parameters (cls , op_info ):
554+
555+ strides = [
556+ 1 ,
557+ op_info .op_attr ['StrideW' ],
558+ op_info .op_attr ['StrideH' ],
559+ 1 ,
560+ ]
561+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
562+ strides_str = ',' .join (map (str , strides ))
563+ return ("{{ {} }}" .format (strides_str ), padding )
564+
565+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
566+ return DeclareOpSnippet (
567+ op = self ,
568+ templ_dtypes = [self .out_dtypes [0 ]],
569+ op_var_name = op_var_name ,
570+ )
571+
572+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
573+ return ConvOpEvalSnippet (
574+ op_info = op_info ,
575+ templ_dtypes = [self .out_dtypes [0 ]],
576+ op_name = op_var_name ,
577+ tensor_var_map = tensor_var_map ,
578+ )
579+
547580
548581@OperatorFactory .register
549582class _QuantizedFullyConnectedOperator (_CommonParams ):
@@ -842,3 +875,142 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
842875 op_name = op_var_name ,
843876 tensor_var_map = tensor_var_map ,
844877 )
878+
879+ @OperatorFactory .register
880+ class _BatchNormOperator (_CommonParams ):
881+ op_type = "BatchNormOperator"
882+
883+ @classmethod
884+ @must_return_type (Hashable )
885+ def get_constructor_parameters (cls , op_info ):
886+ strides = [
887+ 1 ,
888+ op_info .op_attr ['StrideW' ],
889+ op_info .op_attr ['StrideH' ],
890+ 1 ,
891+ ]
892+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
893+ strides_str = ',' .join (map (str , strides ))
894+ return ("{{ {} }}" .format (strides_str ), padding )
895+
896+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
897+ return DeclareOpSnippet (
898+ op = self ,
899+ templ_dtypes = [self .out_dtypes [0 ]],
900+ op_var_name = op_var_name ,
901+ )
902+
903+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
904+ return BatchNormSnippet (
905+ op_info = op_info ,
906+ templ_dtypes = [self .out_dtypes [0 ]],
907+ op_name = op_var_name ,
908+ tensor_var_map = tensor_var_map ,
909+ )
910+
911+ @OperatorFactory .register
912+ class _MeanOperator (_CommonParams ):
913+ op_type = "MeanOperator"
914+
915+ @classmethod
916+ @must_return_type (Hashable )
917+ def get_constructor_parameters (cls , op_info ):
918+ keep_dims = str (op_info .op_attr ["keep_dims" ])
919+ return (" {} " .format (keep_dims ), )
920+
921+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
922+ return DeclareOpSnippet (
923+ op = self ,
924+ templ_dtypes = [self .out_dtypes [0 ]],
925+ op_var_name = op_var_name ,
926+ )
927+
928+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
929+ return BatchNormSnippet (
930+ op_info = op_info ,
931+ templ_dtypes = [self .out_dtypes [0 ]],
932+ op_name = op_var_name ,
933+ tensor_var_map = tensor_var_map ,
934+ )
935+
936+ @OperatorFactory .register
937+ class _SoftmaxOperator (_CommonParams ):
938+ op_type = "SoftmaxOperator"
939+
940+ @classmethod
941+ @must_return_type (Hashable )
942+ def get_constructor_parameters (cls , op_info ):
943+ Beta = op_info .op_attr ["Beta" ]
944+ return (" %f " % Beta ,)
945+
946+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
947+ return DeclareOpSnippet (
948+ op = self ,
949+ templ_dtypes = [self .out_dtypes [0 ]],
950+ op_var_name = op_var_name ,
951+ )
952+
953+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
954+ return BatchNormSnippet (
955+ op_info = op_info ,
956+ templ_dtypes = [self .out_dtypes [0 ]],
957+ op_name = op_var_name ,
958+ tensor_var_map = tensor_var_map ,
959+ )
960+
961+ @OperatorFactory .register
962+ class _MulOperator (_Operator ):
963+ op_type = 'MulOperator'
964+
965+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
966+ return DeclareOpSnippet (
967+ op = self ,
968+ templ_dtypes = [self .in_dtypes [0 ]],
969+ op_var_name = op_var_name ,
970+ )
971+
972+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
973+ return MulOpEvalSnippet (
974+ op_info = op_info ,
975+ templ_dtypes = [self .in_dtypes [0 ]],
976+ op_name = op_var_name ,
977+ tensor_var_map = tensor_var_map ,
978+ )
979+
980+ @OperatorFactory .register
981+ class _SubOperator (_Operator ):
982+ op_type = 'SubOperator'
983+
984+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
985+ return DeclareOpSnippet (
986+ op = self ,
987+ templ_dtypes = [self .in_dtypes [0 ]],
988+ op_var_name = op_var_name ,
989+ )
990+
991+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
992+ return SubOpEvalSnippet (
993+ op_info = op_info ,
994+ templ_dtypes = [self .in_dtypes [0 ]],
995+ op_name = op_var_name ,
996+ tensor_var_map = tensor_var_map ,
997+ )
998+
999+ @OperatorFactory .register
1000+ class _SigmoidOperator (_Operator ):
1001+ op_type = 'SigmoidOperator'
1002+
1003+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
1004+ return DeclareOpSnippet (
1005+ op = self ,
1006+ templ_dtypes = [self .in_dtypes [0 ]],
1007+ op_var_name = op_var_name ,
1008+ )
1009+
1010+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
1011+ return SigmoidOpEvalSnippet (
1012+ op_info = op_info ,
1013+ templ_dtypes = [self .in_dtypes [0 ]],
1014+ op_name = op_var_name ,
1015+ tensor_var_map = tensor_var_map ,
1016+ )
0 commit comments