@@ -405,6 +405,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
405405 tensor_var_map = tensor_var_map ,
406406 )
407407
408+ @OperatorFactory .register
409+ class _ConvOperator (_CommonParams ):
410+ op_type = "ConvOperator"
411+
412+ @classmethod
413+ @must_return_type (Hashable )
414+ def get_constructor_parameters (cls , op_info ):
415+
416+ strides = [
417+ 1 ,
418+ op_info .op_attr ['StrideW' ],
419+ op_info .op_attr ['StrideH' ],
420+ 1 ,
421+ ]
422+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
423+ strides_str = ',' .join (map (str , strides ))
424+ return ("{{ {} }}" .format (strides_str ), padding )
425+
426+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
427+ return DeclareOpSnippet (
428+ op = self ,
429+ templ_dtypes = [self .out_dtypes [0 ]],
430+ op_var_name = op_var_name ,
431+ )
432+
433+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
434+ return ConvOpEvalSnippet (
435+ op_info = op_info ,
436+ templ_dtypes = [self .out_dtypes [0 ]],
437+ op_name = op_var_name ,
438+ tensor_var_map = tensor_var_map ,
439+ )
440+
408441
409442@OperatorFactory .register
410443class _QuantizedFullyConnectedOperator (_CommonParams ):
@@ -433,3 +466,123 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
433466 op_name = op_var_name ,
434467 tensor_var_map = tensor_var_map ,
435468 )
469+
470+ @OperatorFactory .register
471+ class _BatchNormOperator (_CommonParams ):
472+ op_type = "BatchNormOperator"
473+
474+ @classmethod
475+ @must_return_type (Hashable )
476+ def get_constructor_parameters (cls , op_info ):
477+ strides = [
478+ 1 ,
479+ op_info .op_attr ['StrideW' ],
480+ op_info .op_attr ['StrideH' ],
481+ 1 ,
482+ ]
483+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
484+ strides_str = ',' .join (map (str , strides ))
485+ return ("{{ {} }}" .format (strides_str ), padding )
486+
487+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
488+ return DeclareOpSnippet (
489+ op = self ,
490+ templ_dtypes = [self .out_dtypes [0 ]],
491+ op_var_name = op_var_name ,
492+ )
493+
494+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
495+ return BatchNormSnippet (
496+ op_info = op_info ,
497+ templ_dtypes = [self .out_dtypes [0 ]],
498+ op_name = op_var_name ,
499+ tensor_var_map = tensor_var_map ,
500+ )
501+
502+ @OperatorFactory .register
503+ class _MeanOperator (_CommonParams ):
504+ op_type = "MeanOperator"
505+
506+ @classmethod
507+ @must_return_type (Hashable )
508+ def get_constructor_parameters (cls , op_info ):
509+ keep_dims = str (op_info .op_attr ["keep_dims" ])
510+ return (" {} " .format (keep_dims ), )
511+
512+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
513+ return DeclareOpSnippet (
514+ op = self ,
515+ templ_dtypes = [self .out_dtypes [0 ]],
516+ op_var_name = op_var_name ,
517+ )
518+
519+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
520+ return BatchNormSnippet (
521+ op_info = op_info ,
522+ templ_dtypes = [self .out_dtypes [0 ]],
523+ op_name = op_var_name ,
524+ tensor_var_map = tensor_var_map ,
525+ )
526+
527+ @OperatorFactory .register
528+ class _SoftmaxOperator (_CommonParams ):
529+ op_type = "SoftmaxOperator"
530+
531+ @classmethod
532+ @must_return_type (Hashable )
533+ def get_constructor_parameters (cls , op_info ):
534+ Beta = op_info .op_attr ["Beta" ]
535+ return (" %f " % Beta ,)
536+
537+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
538+ return DeclareOpSnippet (
539+ op = self ,
540+ templ_dtypes = [self .out_dtypes [0 ]],
541+ op_var_name = op_var_name ,
542+ )
543+
544+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
545+ return BatchNormSnippet (
546+ op_info = op_info ,
547+ templ_dtypes = [self .out_dtypes [0 ]],
548+ op_name = op_var_name ,
549+ tensor_var_map = tensor_var_map ,
550+ )
551+
552+ @OperatorFactory .register
553+ class _MulOperator (_Operator ):
554+ op_type = 'MulOperator'
555+
556+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
557+ return DeclareOpSnippet (
558+ op = self ,
559+ templ_dtypes = [self .in_dtypes [0 ]],
560+ op_var_name = op_var_name ,
561+ )
562+
563+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
564+ return MulOpEvalSnippet (
565+ op_info = op_info ,
566+ templ_dtypes = [self .in_dtypes [0 ]],
567+ op_name = op_var_name ,
568+ tensor_var_map = tensor_var_map ,
569+ )
570+
571+ @OperatorFactory .register
572+ class _SubOperator (_Operator ):
573+ op_type = 'SubOperator'
574+
575+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
576+ return DeclareOpSnippet (
577+ op = self ,
578+ templ_dtypes = [self .in_dtypes [0 ]],
579+ op_var_name = op_var_name ,
580+ )
581+
582+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
583+ return SubOpEvalSnippet (
584+ op_info = op_info ,
585+ templ_dtypes = [self .in_dtypes [0 ]],
586+ op_name = op_var_name ,
587+ tensor_var_map = tensor_var_map ,
588+ )
0 commit comments