@@ -45,85 +45,68 @@ static SmallVector<int64_t> getReduceOutputShape(ArrayRef<int64_t> inputShape,
45
45
static Value createInitialValueForReduceOp (Operation *op, Type elementTy,
46
46
PatternRewriter &rewriter) {
47
47
auto constType = RankedTensorType::get ({}, elementTy);
48
+ DenseElementsAttr constAttr = nullptr ;
48
49
if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
49
50
AtenLinalgVectorNormOp>(op)) {
50
51
if (isa<mlir::FloatType>(elementTy)) {
51
- auto constAttr = DenseElementsAttr::get (
52
+ constAttr = DenseElementsAttr::get (
52
53
constType, {APFloat::getZero (
53
54
cast<mlir::FloatType>(elementTy).getFloatSemantics (),
54
55
/* negative=*/ false )});
55
- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
56
- constAttr);
57
56
} else if (isa<mlir::IntegerType>(elementTy)) {
58
- auto constAttr = DenseElementsAttr::get (
57
+ constAttr = DenseElementsAttr::get (
59
58
constType, {APInt::getZero (elementTy.getIntOrFloatBitWidth ())});
60
- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
61
- constAttr);
62
59
}
63
60
}
64
61
65
62
if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
66
63
if (isa<mlir::FloatType>(elementTy)) {
67
- auto constAttr = DenseElementsAttr::get (
64
+ constAttr = DenseElementsAttr::get (
68
65
constType,
69
66
{APFloat::getInf (cast<mlir::FloatType>(elementTy).getFloatSemantics (),
70
67
/* negative=*/ true )});
71
- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
72
- constAttr);
73
68
} else if (isa<mlir::IntegerType>(elementTy)) {
74
- auto constAttr = DenseElementsAttr::get (
69
+ constAttr = DenseElementsAttr::get (
75
70
constType,
76
71
{APInt::getSignedMinValue (elementTy.getIntOrFloatBitWidth ())});
77
- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
78
- constAttr);
79
72
}
80
73
}
81
74
82
75
if (isa<AtenAminOp, AtenMinOp, AtenMinDimOp, AtenArgminOp>(op)) {
83
76
if (isa<mlir::FloatType>(elementTy)) {
84
- auto constAttr = DenseElementsAttr::get (
77
+ constAttr = DenseElementsAttr::get (
85
78
constType,
86
79
{APFloat::getInf (cast<mlir::FloatType>(elementTy).getFloatSemantics (),
87
80
/* negative=*/ false )});
88
- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
89
- constAttr);
90
81
} else if (isa<mlir::IntegerType>(elementTy)) {
91
- auto constAttr = DenseElementsAttr::get (
82
+ constAttr = DenseElementsAttr::get (
92
83
constType,
93
84
{APInt::getSignedMaxValue (elementTy.getIntOrFloatBitWidth ())});
94
- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
95
- constAttr);
96
85
}
97
86
}
98
87
99
88
if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
100
89
if (isa<mlir::FloatType>(elementTy)) {
101
90
APFloat one (cast<mlir::FloatType>(elementTy).getFloatSemantics (), 1 );
102
- auto constAttr = DenseElementsAttr::get (constType, one);
103
- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
104
- constAttr);
91
+ constAttr = DenseElementsAttr::get (constType, one);
105
92
} else if (isa<mlir::IntegerType>(elementTy)) {
106
93
APInt one (elementTy.getIntOrFloatBitWidth (), 1 );
107
- auto constAttr = DenseElementsAttr::get (constType, one);
108
- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
109
- constAttr);
94
+ constAttr = DenseElementsAttr::get (constType, one);
110
95
}
111
96
}
112
97
113
98
if (isa<AtenAllOp, AtenAllDimOp>(op)) {
114
- auto constAttr =
115
- DenseElementsAttr::get (constType, {APInt (/* numBits=*/ 1 , 1 )});
116
- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
117
- constAttr);
99
+ constAttr = DenseElementsAttr::get (constType, {APInt (/* numBits=*/ 1 , 1 )});
118
100
}
119
101
120
102
if (isa<AtenAnyOp, AtenAnyDimOp, AtenAnyDimsOp>(op)) {
121
- auto constAttr =
122
- DenseElementsAttr::get (constType, {APInt (/* numBits=*/ 1 , 0 )});
103
+ constAttr = DenseElementsAttr::get (constType, {APInt (/* numBits=*/ 1 , 0 )});
104
+ }
105
+
106
+ if (constAttr != nullptr ) {
123
107
return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
124
108
constAttr);
125
109
}
126
-
127
110
op->emitError (" unimplemented lowering in "
128
111
" createInitialValueForReduceOp" );
129
112
return nullptr ;
@@ -483,7 +466,7 @@ class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp<AtenOpT> {
483
466
return rewriter.notifyMatchFailure (
484
467
op, " non-const integer `dim` is not supported" );
485
468
}
486
- if (inputDims.size () == 0 ) {
469
+ if (inputDims.empty () ) {
487
470
dims = llvm::to_vector (llvm::seq<int64_t >(0 , inputTy.getRank ()));
488
471
} else {
489
472
for (auto d : inputDims) {
@@ -570,7 +553,7 @@ class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp<AtenOpT> {
570
553
return rewriter.notifyMatchFailure (
571
554
op, " failed to get dimension sizes of the input" );
572
555
}
573
- auto inputShapeVec = *inputShapeInfo;
556
+ auto & inputShapeVec = *inputShapeInfo;
574
557
575
558
if (op.getResult (1 ).use_empty ()) {
576
559
llvm::SmallVector<int64_t > outputShape (inputTy.getShape ());
@@ -643,7 +626,7 @@ LogicalResult ConvertAtenReductionOp<AtenAnyDimsOp>::matchAndRewrite(
643
626
return rewriter.notifyMatchFailure (
644
627
op, " non-const integer `dim` is not supported" );
645
628
}
646
- if (inputDims.size () == 0 ) {
629
+ if (inputDims.empty () ) {
647
630
rewriter.replaceOp (op, input);
648
631
return success ();
649
632
}
@@ -722,7 +705,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
722
705
return rewriter.notifyMatchFailure (
723
706
op, " non-const integer `dim` is not supported" );
724
707
}
725
- if (inputDims.size () == 0 ) {
708
+ if (inputDims.empty () ) {
726
709
inputDims = llvm::to_vector<4 >(llvm::seq<int64_t >(0 , inputTy.getRank ()));
727
710
}
728
711
}
0 commit comments