@@ -1619,18 +1619,31 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.tensorResultType(resultType)) {
           return failure();
         }
+        if (!resultType.hasSizes() || !resultType.hasDtype()) {
+          return failure();
+        }
+        auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
+        if (!inputTy || !inputTy.hasSizes()) {
+          return failure();
+        }
+        int64_t inputRank = inputTy.getSizes().size();
+
         Location loc = binder.getLoc();
         Value keepDim = rewriter.create<Torch::ConstantBoolOp>(loc, true);
         Value unBiased = rewriter.create<Torch::ConstantBoolOp>(loc, false);
         Value none = rewriter.create<Torch::ConstantNoneOp>(loc);

-        ArrayRef<int64_t> input_shape = resultType.getSizes();
-        SmallVector<int64_t> reduced_shape(input_shape);
+        ArrayRef<int64_t> output_shape = resultType.getSizes();
+        SmallVector<int64_t> reduced_shape(output_shape);
+
         for (int64_t i : axes) {
+          int64_t dim = Torch::toPositiveDim(i, inputRank);
+          if (!Torch::isValidDim(dim, inputRank)) {
+            return failure();
+          }
-          reduced_shape[i] = 1;
+          reduced_shape[dim] = 1;
         }
-
-        Torch::ValueTensorType meanOutTy = Torch::ValueTensorType::get(
+        Torch::ValueTensorType reducedOutTy = Torch::ValueTensorType::get(
             resultType.getContext(), reduced_shape, resultType.getDtype());
         SmallVector<Value> cstAxes;
         for (int64_t i : axes) {
@@ -1642,29 +1655,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
             cstAxes);
         Value mean = rewriter.create<Torch::AtenMeanDimOp>(
-            loc, meanOutTy, input, axes_list, keepDim, none);
-
+            loc, reducedOutTy, input, axes_list, keepDim, none);
         Value variance = rewriter.create<Torch::AtenVarDimOp>(
-            loc, meanOutTy, input, axes_list, unBiased, keepDim);
-
+            loc, reducedOutTy, input, axes_list, unBiased, keepDim);
         Value cstOne = rewriter.create<Torch::ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(1));
         Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
             loc, rewriter.getF64FloatAttr(1e-9));
         variance = rewriter.create<Torch::AtenAddScalarOp>(
-            loc, meanOutTy, variance, cstEps, cstOne);
-
-        Value sqrt =
-            rewriter.create<Torch::AtenSqrtOp>(loc, meanOutTy, variance);
-
-        Value subValue = rewriter.create<Torch::AtenSubTensorOp>(
+            loc, reducedOutTy, variance, cstEps, cstOne);
+        Value sqrtVar =
+            rewriter.create<Torch::AtenSqrtOp>(loc, reducedOutTy, variance);
+        Value inputMinusMean = rewriter.create<Torch::AtenSubTensorOp>(
             loc, resultType, input, mean, cstOne);
-
         Value meanVarNorm = rewriter.create<Torch::AtenDivTensorOp>(
-            loc, resultType, subValue, sqrt);
+            loc, resultType, inputMinusMean, sqrtVar);

         rewriter.replaceOp(binder.op, meanVarNorm);
-
         return success();
       });
   patterns.onOp(
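
For reference, the op sequence above lowers ONNX MeanVarianceNormalization: Y = (X - mean(X, axes)) / sqrt(var(X, axes) + eps), with the biased variance estimate (unBiased = false) and eps hard-coded to 1e-9. Below is a minimal standalone sketch of that scalar math for a single group of values collapsed by the axes reduction; the function name and use of std::vector are illustrative assumptions, not code from this commit.

#include <cmath>
#include <vector>

// Illustrative only: y[i] = (x[i] - mean) / sqrt(var + eps), applied to one
// group of elements that the `axes` reduction collapses together.
std::vector<double> meanVarianceNormalize(const std::vector<double> &x,
                                          double eps = 1e-9) {
  if (x.empty())
    return {};

  double mean = 0.0;
  for (double v : x)
    mean += v;
  mean /= x.size();

  double var = 0.0;
  for (double v : x)
    var += (v - mean) * (v - mean);
  var /= x.size(); // biased estimator, matching unBiased = false above

  std::vector<double> y;
  y.reserve(x.size());
  for (double v : x)
    y.push_back((v - mean) / std::sqrt(var + eps));
  return y;
}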