@@ -1606,6 +1606,67 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
1606
1606
/* cudnn enabled */ boolFalse);
1607
1607
return success ();
1608
1608
});
1609
+ patterns.onOp (
1610
+ " MeanVarianceNormalization" , 13 ,
1611
+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
1612
+ Torch::ValueTensorType resultType;
1613
+ Value input;
1614
+ SmallVector<int64_t > axes;
1615
+
1616
+ if (binder.tensorOperand (input) ||
1617
+ binder.s64IntegerArrayAttr (axes, " axes" ,
1618
+ llvm::SmallVector<int64_t >({0 , 2 , 3 })) ||
1619
+ binder.tensorResultType (resultType)) {
1620
+ return failure ();
1621
+ }
1622
+ Location loc = binder.getLoc ();
1623
+ Value keepDim = rewriter.create <Torch::ConstantBoolOp>(loc, true );
1624
+ Value unBiased = rewriter.create <Torch::ConstantBoolOp>(loc, false );
1625
+ Value none = rewriter.create <Torch::ConstantNoneOp>(loc);
1626
+
1627
+ ArrayRef<int64_t > input_shape = resultType.getSizes ();
1628
+ SmallVector<int64_t > reduced_shape (input_shape);
1629
+ for (int64_t i : axes) {
1630
+ reduced_shape[i] = 1 ;
1631
+ }
1632
+
1633
+ Torch::ValueTensorType meanOutTy = Torch::ValueTensorType::get (
1634
+ resultType.getContext (), reduced_shape, resultType.getDtype ());
1635
+ SmallVector<Value> cstAxes;
1636
+ for (int64_t i : axes) {
1637
+ cstAxes.push_back (rewriter.create <Torch::ConstantIntOp>(
1638
+ loc, rewriter.getI64IntegerAttr (i)));
1639
+ }
1640
+ Value axes_list = rewriter.create <Torch::PrimListConstructOp>(
1641
+ loc,
1642
+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
1643
+ cstAxes);
1644
+ Value mean = rewriter.create <Torch::AtenMeanDimOp>(
1645
+ loc, meanOutTy, input, axes_list, keepDim, none);
1646
+
1647
+ Value variance = rewriter.create <Torch::AtenVarDimOp>(
1648
+ loc, meanOutTy, input, axes_list, unBiased, keepDim);
1649
+
1650
+ Value cstOne = rewriter.create <Torch::ConstantIntOp>(
1651
+ loc, rewriter.getI64IntegerAttr (1 ));
1652
+ Value cstEps = rewriter.create <Torch::ConstantFloatOp>(
1653
+ loc, rewriter.getF64FloatAttr (1e-9 ));
1654
+ variance = rewriter.create <Torch::AtenAddScalarOp>(
1655
+ loc, meanOutTy, variance, cstEps, cstOne);
1656
+
1657
+ Value sqrt =
1658
+ rewriter.create <Torch::AtenSqrtOp>(loc, meanOutTy, variance);
1659
+
1660
+ Value subValue = rewriter.create <Torch::AtenSubTensorOp>(
1661
+ loc, resultType, input, mean, cstOne);
1662
+
1663
+ Value meanVarNorm = rewriter.create <Torch::AtenDivTensorOp>(
1664
+ loc, resultType, subValue, sqrt);
1665
+
1666
+ rewriter.replaceOp (binder.op , meanVarNorm);
1667
+
1668
+ return success ();
1669
+ });
1609
1670
patterns.onOp (
1610
1671
" Max" , 1 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
1611
1672
Torch::ValueTensorType resultType;
0 commit comments