Skip to content

Commit c9fe057

Browse files
committed
Add type checks & allow MVN expansion by default
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent fb80dc1 commit c9fe057

File tree

2 files changed

+27
-18
lines changed

2 files changed

+27
-18
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,18 +1619,31 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
16191619
binder.tensorResultType(resultType)) {
16201620
return failure();
16211621
}
1622+
if (!resultType.hasSizes() || !resultType.hasDtype()) {
1623+
return failure();
1624+
}
1625+
auto inputTy = cast<Torch::ValueTensorType>(input.getType());
1626+
if (!inputTy || !inputTy.hasSizes()) {
1627+
return failure();
1628+
}
1629+
int64_t inputRank = inputTy.getSizes().size();
1630+
16221631
Location loc = binder.getLoc();
16231632
Value keepDim = rewriter.create<Torch::ConstantBoolOp>(loc, true);
16241633
Value unBiased = rewriter.create<Torch::ConstantBoolOp>(loc, false);
16251634
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
16261635

1627-
ArrayRef<int64_t> input_shape = resultType.getSizes();
1628-
SmallVector<int64_t> reduced_shape(input_shape);
1636+
ArrayRef<int64_t> output_shape = resultType.getSizes();
1637+
SmallVector<int64_t> reduced_shape(output_shape);
1638+
16291639
for (int64_t i : axes) {
1640+
int64_t dim = Torch::toPositiveDim(i, inputRank);
1641+
if (!Torch::isValidDim(dim, inputRank)) {
1642+
return failure();
1643+
}
16301644
reduced_shape[i] = 1;
16311645
}
1632-
1633-
Torch::ValueTensorType meanOutTy = Torch::ValueTensorType::get(
1646+
Torch::ValueTensorType reducedOutTy = Torch::ValueTensorType::get(
16341647
resultType.getContext(), reduced_shape, resultType.getDtype());
16351648
SmallVector<Value> cstAxes;
16361649
for (int64_t i : axes) {
@@ -1642,29 +1655,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
16421655
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
16431656
cstAxes);
16441657
Value mean = rewriter.create<Torch::AtenMeanDimOp>(
1645-
loc, meanOutTy, input, axes_list, keepDim, none);
1646-
1658+
loc, reducedOutTy, input, axes_list, keepDim, none);
16471659
Value variance = rewriter.create<Torch::AtenVarDimOp>(
1648-
loc, meanOutTy, input, axes_list, unBiased, keepDim);
1649-
1660+
loc, reducedOutTy, input, axes_list, unBiased, keepDim);
16501661
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
16511662
loc, rewriter.getI64IntegerAttr(1));
16521663
Value cstEps = rewriter.create<Torch::ConstantFloatOp>(
16531664
loc, rewriter.getF64FloatAttr(1e-9));
16541665
variance = rewriter.create<Torch::AtenAddScalarOp>(
1655-
loc, meanOutTy, variance, cstEps, cstOne);
1656-
1657-
Value sqrt =
1658-
rewriter.create<Torch::AtenSqrtOp>(loc, meanOutTy, variance);
1659-
1660-
Value subValue = rewriter.create<Torch::AtenSubTensorOp>(
1666+
loc, reducedOutTy, variance, cstEps, cstOne);
1667+
Value sqrtVar =
1668+
rewriter.create<Torch::AtenSqrtOp>(loc, reducedOutTy, variance);
1669+
Value inputMinusMean = rewriter.create<Torch::AtenSubTensorOp>(
16611670
loc, resultType, input, mean, cstOne);
1662-
16631671
Value meanVarNorm = rewriter.create<Torch::AtenDivTensorOp>(
1664-
loc, resultType, subValue, sqrt);
1672+
loc, resultType, inputMinusMean, sqrtVar);
16651673

16661674
rewriter.replaceOp(binder.op, meanVarNorm);
1667-
16681675
return success();
16691676
});
16701677
patterns.onOp(

python/torch_mlir/extras/onnx_importer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ class Config:
103103
function_expansion_allowlists_by_domain: Optional[Dict[str, set[str]]] = field(
104104
default_factory=lambda: {
105105
# Default domain (ONNX built-in ops)
106-
"": {}
106+
"": {
107+
"MeanVarianceNormalization",
108+
}
107109
}
108110
)
109111

0 commit comments

Comments
 (0)