Skip to content

Commit e8a7ddf

Browse files
committed
Lower to torch dialect without expansion
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent 386bba4 commit e8a7ddf

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,6 +1606,67 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
16061606
/* cudnn enabled */ boolFalse);
16071607
return success();
16081608
});
1609+
patterns.onOp(
1610+
"MeanVarianceNormalization", 13,
1611+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
1612+
Torch::ValueTensorType resultType;
1613+
Value input;
1614+
SmallVector<int64_t> axes;
1615+
1616+
if (binder.tensorOperand(input) ||
1617+
binder.s64IntegerArrayAttr(axes, "axes",
1618+
llvm::SmallVector<int64_t>({0, 2, 3})) ||
1619+
binder.tensorResultType(resultType)) {
1620+
return failure();
1621+
}
1622+
Location loc = binder.getLoc();
1623+
Value keepDim = rewriter.create<Torch::ConstantBoolOp>(loc, true);
1624+
Value unBiased = rewriter.create<Torch::ConstantBoolOp>(loc, false);
1625+
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
1626+
1627+
ArrayRef<int64_t> input_shape = resultType.getSizes();
1628+
SmallVector<int64_t> reduced_shape(input_shape);
1629+
for (int64_t i : axes) {
1630+
reduced_shape[i] = 1;
1631+
}
1632+
1633+
Torch::ValueTensorType meanOutTy = Torch::ValueTensorType::get(
1634+
resultType.getContext(), reduced_shape, resultType.getDtype());
1635+
SmallVector<Value> cstAxes;
1636+
for (int64_t i : axes) {
1637+
cstAxes.push_back(rewriter.create<Torch::ConstantIntOp>(
1638+
loc, rewriter.getI64IntegerAttr(i)));
1639+
}
1640+
Value axes_list = rewriter.create<Torch::PrimListConstructOp>(
1641+
loc,
1642+
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
1643+
cstAxes);
1644+
Value mean = rewriter.create<Torch::AtenMeanDimOp>(
1645+
loc, meanOutTy, input, axes_list, keepDim, none);
1646+
1647+
Value variance = rewriter.create<Torch::AtenVarDimOp>(
1648+
loc, meanOutTy, input, axes_list, unBiased, keepDim);
1649+
1650+
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
1651+
loc, rewriter.getI64IntegerAttr(1));
1652+
Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
1653+
loc, rewriter.getF64FloatAttr(1e-9));
1654+
variance = rewriter.create<Torch::AtenAddScalarOp>(
1655+
loc, meanOutTy, variance, cstEps, cstOne);
1656+
1657+
Value sqrt =
1658+
rewriter.create<Torch::AtenSqrtOp>(loc, meanOutTy, variance);
1659+
1660+
Value subValue = rewriter.create<Torch::AtenSubTensorOp>(
1661+
loc, resultType, input, mean, cstOne);
1662+
1663+
Value meanVarNorm = rewriter.create<Torch::AtenDivTensorOp>(
1664+
loc, resultType, subValue, sqrt);
1665+
1666+
rewriter.replaceOp(binder.op, meanVarNorm);
1667+
1668+
return success();
1669+
});
16091670
patterns.onOp(
16101671
"Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
16111672
Torch::ValueTensorType resultType;

0 commit comments

Comments
 (0)