@@ -436,6 +436,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
436436 tensor_var_map = tensor_var_map ,
437437 )
438438
439+ @OperatorFactory .register
440+ class _ConvOperator (_CommonParams ):
441+ op_type = "ConvOperator"
442+
443+ @classmethod
444+ @must_return_type (Hashable )
445+ def get_constructor_parameters (cls , op_info ):
446+
447+ strides = [
448+ 1 ,
449+ op_info .op_attr ['StrideW' ],
450+ op_info .op_attr ['StrideH' ],
451+ 1 ,
452+ ]
453+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
454+ strides_str = ',' .join (map (str , strides ))
455+ return ("{{ {} }}" .format (strides_str ), padding )
456+
457+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
458+ return DeclareOpSnippet (
459+ op = self ,
460+ templ_dtypes = [self .out_dtypes [0 ]],
461+ op_var_name = op_var_name ,
462+ )
463+
464+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
465+ return ConvOpEvalSnippet (
466+ op_info = op_info ,
467+ templ_dtypes = [self .out_dtypes [0 ]],
468+ op_name = op_var_name ,
469+ tensor_var_map = tensor_var_map ,
470+ )
471+
439472
440473@OperatorFactory .register
441474class _QuantizedFullyConnectedOperator (_CommonParams ):
@@ -464,3 +497,123 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
464497 op_name = op_var_name ,
465498 tensor_var_map = tensor_var_map ,
466499 )
500+
501+ @OperatorFactory .register
502+ class _BatchNormOperator (_CommonParams ):
503+ op_type = "BatchNormOperator"
504+
505+ @classmethod
506+ @must_return_type (Hashable )
507+ def get_constructor_parameters (cls , op_info ):
508+ strides = [
509+ 1 ,
510+ op_info .op_attr ['StrideW' ],
511+ op_info .op_attr ['StrideH' ],
512+ 1 ,
513+ ]
514+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
515+ strides_str = ',' .join (map (str , strides ))
516+ return ("{{ {} }}" .format (strides_str ), padding )
517+
518+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
519+ return DeclareOpSnippet (
520+ op = self ,
521+ templ_dtypes = [self .out_dtypes [0 ]],
522+ op_var_name = op_var_name ,
523+ )
524+
525+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
526+ return BatchNormSnippet (
527+ op_info = op_info ,
528+ templ_dtypes = [self .out_dtypes [0 ]],
529+ op_name = op_var_name ,
530+ tensor_var_map = tensor_var_map ,
531+ )
532+
533+ @OperatorFactory .register
534+ class _MeanOperator (_CommonParams ):
535+ op_type = "MeanOperator"
536+
537+ @classmethod
538+ @must_return_type (Hashable )
539+ def get_constructor_parameters (cls , op_info ):
540+ keep_dims = str (op_info .op_attr ["keep_dims" ])
541+ return (" {} " .format (keep_dims ), )
542+
543+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
544+ return DeclareOpSnippet (
545+ op = self ,
546+ templ_dtypes = [self .out_dtypes [0 ]],
547+ op_var_name = op_var_name ,
548+ )
549+
550+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
551+ return BatchNormSnippet (
552+ op_info = op_info ,
553+ templ_dtypes = [self .out_dtypes [0 ]],
554+ op_name = op_var_name ,
555+ tensor_var_map = tensor_var_map ,
556+ )
557+
558+ @OperatorFactory .register
559+ class _SoftmaxOperator (_CommonParams ):
560+ op_type = "SoftmaxOperator"
561+
562+ @classmethod
563+ @must_return_type (Hashable )
564+ def get_constructor_parameters (cls , op_info ):
565+ Beta = op_info .op_attr ["Beta" ]
566+ return (" %f " % Beta ,)
567+
568+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
569+ return DeclareOpSnippet (
570+ op = self ,
571+ templ_dtypes = [self .out_dtypes [0 ]],
572+ op_var_name = op_var_name ,
573+ )
574+
575+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
576+ return BatchNormSnippet (
577+ op_info = op_info ,
578+ templ_dtypes = [self .out_dtypes [0 ]],
579+ op_name = op_var_name ,
580+ tensor_var_map = tensor_var_map ,
581+ )
582+
583+ @OperatorFactory .register
584+ class _MulOperator (_Operator ):
585+ op_type = 'MulOperator'
586+
587+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
588+ return DeclareOpSnippet (
589+ op = self ,
590+ templ_dtypes = [self .in_dtypes [0 ]],
591+ op_var_name = op_var_name ,
592+ )
593+
594+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
595+ return MulOpEvalSnippet (
596+ op_info = op_info ,
597+ templ_dtypes = [self .in_dtypes [0 ]],
598+ op_name = op_var_name ,
599+ tensor_var_map = tensor_var_map ,
600+ )
601+
602+ @OperatorFactory .register
603+ class _SubOperator (_Operator ):
604+ op_type = 'SubOperator'
605+
606+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
607+ return DeclareOpSnippet (
608+ op = self ,
609+ templ_dtypes = [self .in_dtypes [0 ]],
610+ op_var_name = op_var_name ,
611+ )
612+
613+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
614+ return SubOpEvalSnippet (
615+ op_info = op_info ,
616+ templ_dtypes = [self .in_dtypes [0 ]],
617+ op_name = op_var_name ,
618+ tensor_var_map = tensor_var_map ,
619+ )
0 commit comments