Skip to content

Commit 8366790

Browse files
authored
[Stablehlo] Refactor utility functions for reduction (#4277)
1 parent 6930bf2 commit 8366790

File tree

1 file changed

+18
-35
lines changed

1 file changed

+18
-35
lines changed

lib/Conversion/TorchToStablehlo/Reduction.cpp

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -45,85 +45,68 @@ static SmallVector<int64_t> getReduceOutputShape(ArrayRef<int64_t> inputShape,
4545
static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
4646
PatternRewriter &rewriter) {
4747
auto constType = RankedTensorType::get({}, elementTy);
48+
DenseElementsAttr constAttr = nullptr;
4849
if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
4950
AtenLinalgVectorNormOp>(op)) {
5051
if (isa<mlir::FloatType>(elementTy)) {
51-
auto constAttr = DenseElementsAttr::get(
52+
constAttr = DenseElementsAttr::get(
5253
constType, {APFloat::getZero(
5354
cast<mlir::FloatType>(elementTy).getFloatSemantics(),
5455
/*negative=*/false)});
55-
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
56-
constAttr);
5756
} else if (isa<mlir::IntegerType>(elementTy)) {
58-
auto constAttr = DenseElementsAttr::get(
57+
constAttr = DenseElementsAttr::get(
5958
constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
60-
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
61-
constAttr);
6259
}
6360
}
6461

6562
if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
6663
if (isa<mlir::FloatType>(elementTy)) {
67-
auto constAttr = DenseElementsAttr::get(
64+
constAttr = DenseElementsAttr::get(
6865
constType,
6966
{APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
7067
/*negative=*/true)});
71-
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
72-
constAttr);
7368
} else if (isa<mlir::IntegerType>(elementTy)) {
74-
auto constAttr = DenseElementsAttr::get(
69+
constAttr = DenseElementsAttr::get(
7570
constType,
7671
{APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
77-
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
78-
constAttr);
7972
}
8073
}
8174

8275
if (isa<AtenAminOp, AtenMinOp, AtenMinDimOp, AtenArgminOp>(op)) {
8376
if (isa<mlir::FloatType>(elementTy)) {
84-
auto constAttr = DenseElementsAttr::get(
77+
constAttr = DenseElementsAttr::get(
8578
constType,
8679
{APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
8780
/*negative=*/false)});
88-
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
89-
constAttr);
9081
} else if (isa<mlir::IntegerType>(elementTy)) {
91-
auto constAttr = DenseElementsAttr::get(
82+
constAttr = DenseElementsAttr::get(
9283
constType,
9384
{APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())});
94-
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
95-
constAttr);
9685
}
9786
}
9887

9988
if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
10089
if (isa<mlir::FloatType>(elementTy)) {
10190
APFloat one(cast<mlir::FloatType>(elementTy).getFloatSemantics(), 1);
102-
auto constAttr = DenseElementsAttr::get(constType, one);
103-
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
104-
constAttr);
91+
constAttr = DenseElementsAttr::get(constType, one);
10592
} else if (isa<mlir::IntegerType>(elementTy)) {
10693
APInt one(elementTy.getIntOrFloatBitWidth(), 1);
107-
auto constAttr = DenseElementsAttr::get(constType, one);
108-
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
109-
constAttr);
94+
constAttr = DenseElementsAttr::get(constType, one);
11095
}
11196
}
11297

11398
if (isa<AtenAllOp, AtenAllDimOp>(op)) {
114-
auto constAttr =
115-
DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
116-
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
117-
constAttr);
99+
constAttr = DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
118100
}
119101

120102
if (isa<AtenAnyOp, AtenAnyDimOp, AtenAnyDimsOp>(op)) {
121-
auto constAttr =
122-
DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)});
103+
constAttr = DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)});
104+
}
105+
106+
if (constAttr != nullptr) {
123107
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
124108
constAttr);
125109
}
126-
127110
op->emitError("unimplemented lowering in "
128111
"createInitialValueForReduceOp");
129112
return nullptr;
@@ -483,7 +466,7 @@ class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp<AtenOpT> {
483466
return rewriter.notifyMatchFailure(
484467
op, "non-const integer `dim` is not supported");
485468
}
486-
if (inputDims.size() == 0) {
469+
if (inputDims.empty()) {
487470
dims = llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
488471
} else {
489472
for (auto d : inputDims) {
@@ -570,7 +553,7 @@ class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp<AtenOpT> {
570553
return rewriter.notifyMatchFailure(
571554
op, "failed to get dimension sizes of the input");
572555
}
573-
auto inputShapeVec = *inputShapeInfo;
556+
auto &inputShapeVec = *inputShapeInfo;
574557

575558
if (op.getResult(1).use_empty()) {
576559
llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
@@ -643,7 +626,7 @@ LogicalResult ConvertAtenReductionOp<AtenAnyDimsOp>::matchAndRewrite(
643626
return rewriter.notifyMatchFailure(
644627
op, "non-const integer `dim` is not supported");
645628
}
646-
if (inputDims.size() == 0) {
629+
if (inputDims.empty()) {
647630
rewriter.replaceOp(op, input);
648631
return success();
649632
}
@@ -722,7 +705,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
722705
return rewriter.notifyMatchFailure(
723706
op, "non-const integer `dim` is not supported");
724707
}
725-
if (inputDims.size() == 0) {
708+
if (inputDims.empty()) {
726709
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
727710
}
728711
}

0 commit comments

Comments
 (0)