
Commit 111dfda

Merge branch 'llvm:main' into pixel.unshuffle
2 parents: f409948 + 46c3888

10 files changed (+277, -36 lines)


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
@@ -10340,6 +10340,30 @@ def Torch_AtenReplicationPad2dOp : Torch_Op<"aten.replication_pad2d", [
   }];
 }
 
+def Torch_AtenReplicationPad3dOp : Torch_Op<"aten.replication_pad3d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::replication_pad3d : (Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$padding
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenReplicationPad3dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenReplicationPad3dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [
     AllowsTypeRefinement,
     HasValueSemantics,
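
For a quick sanity check of the signature recorded in the summary above (aten::replication_pad3d : (Tensor, int[]) -> (Tensor)), here is a hedged eager-mode sketch in Python; it assumes a recent PyTorch install, and the shapes and pad values are invented for the example:

import torch

# The padding list is [w_left, w_right, h_top, h_bottom, d_front, d_back] and
# applies to the last three dims of a 4-D or 5-D input.
x = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
y = torch.ops.aten.replication_pad3d(x, [1, 2, 0, 1, 1, 0])
assert y.shape == (2, 4, 5, 8)  # each padded dim grows by the sum of its pad pair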

lib/Conversion/TorchToLinalg/TensorConstructors.cpp

Lines changed: 103 additions & 0 deletions
@@ -426,6 +426,107 @@ class ConvertAtenReplicationPad2dOp
 };
 } // namespace
 
+namespace {
+
+// Lower aten.replication_pad3d operator into a sequence of
+// tensor.extract_slice and tensor.concat operations.
+class ConvertAtenReplicationPad3dOp
+    : public OpConversionPattern<AtenReplicationPad3dOp> {
+
+private:
+  enum sliceLoc { START = 0, END = 1 };
+
+  Value extractSlice(ConversionPatternRewriter &rewriter, Location loc,
+                     Value input, int64_t dimension, sliceLoc sliceLoc) const {
+    auto inputType = llvm::cast<RankedTensorType>(input.getType());
+    int64_t inputRank = inputType.getRank();
+    SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
+
+    SmallVector<OpFoldResult> offsets(inputRank, rewriter.getIndexAttr(0));
+    if (sliceLoc == END) {
+      Value dimSize = inputShape[dimension];
+      Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      Value endIdx = rewriter.create<arith::SubIOp>(loc, dimSize, one);
+      offsets[dimension] = getAsOpFoldResult(endIdx);
+    }
+
+    SmallVector<OpFoldResult> allOneStrides(inputRank,
+                                            rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes(inputRank, rewriter.getIndexAttr(0));
+    for (int i = 0; i < inputRank; ++i)
+      sizes[i] = (i == dimension) ? rewriter.getIndexAttr(1)
+                                  : getAsOpFoldResult(inputShape[i]);
+
+    Value extractedSlice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, input, offsets, sizes, allOneStrides);
+    return extractedSlice;
+  }
+
+  Value createTile(ConversionPatternRewriter &rewriter, Location loc,
+                   Value slice, int64_t tileWidth, int64_t dimension) const {
+    SmallVector<Value> slices(tileWidth, slice);
+    if (tileWidth == 1)
+      return slice;
+    return rewriter.create<tensor::ConcatOp>(loc, dimension, slices);
+  }
+
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenReplicationPad3dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op->getLoc();
+    Value input = adaptor.getSelf();
+    auto inputType = llvm::cast<RankedTensorType>(input.getType());
+    int64_t inputRank = inputType.getRank();
+    unsigned numDims = inputType.getRank();
+    assert(numDims >= 2 && "Not enough input dimensions");
+
+    SmallVector<int64_t> padInts;
+    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant int pad ranges");
+
+    if (padInts.size() != 6)
+      return rewriter.notifyMatchFailure(
+          op, "pad range must have exactly six values");
+
+    Value res = input;
+    int64_t padIdx = 0;
+    for (int64_t dim = inputRank - 1; dim >= inputRank - 3; dim--) {
+      int64_t startTileWidth = padInts[padIdx++];
+      int64_t endTileWidth = padInts[padIdx++];
+
+      SmallVector<Value> resultParts;
+      if (startTileWidth > 0) {
+        Value slice = extractSlice(rewriter, loc, res, dim, sliceLoc::START);
+        Value tile = createTile(rewriter, loc, slice, startTileWidth, dim);
+        resultParts.push_back(tile);
+      }
+
+      resultParts.push_back(res);
+
+      if (endTileWidth > 0) {
+        Value slice = extractSlice(rewriter, loc, res, dim, sliceLoc::END);
+        Value tile = createTile(rewriter, loc, slice, endTileWidth, dim);
+        resultParts.push_back(tile);
+      }
+
+      if (resultParts.size() > 1)
+        res = rewriter.create<tensor::ConcatOp>(loc, dim, resultParts);
+    }
+
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, res);
+    return success();
+  }
+};
+
+} // namespace
 namespace {
 // Converts constant tensor allocation like ops.
 template <typename OpTy, int fillVal>
@@ -696,6 +797,8 @@ void mlir::torch::torch_to_linalg::
                                        RewritePatternSet &patterns,
                                        ConversionTarget &target) {
   MLIRContext *context = patterns.getContext();
+  target.addIllegalOp<AtenReplicationPad3dOp>();
+  patterns.add<ConvertAtenReplicationPad3dOp>(typeConverter, context);
   target.addIllegalOp<AtenReplicationPad2dOp>();
   patterns.add<ConvertAtenReplicationPad2dOp>(typeConverter, context);
   target.addIllegalOp<AtenReplicationPad1dOp>();
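
The comment above summarizes the lowering strategy: for each of the last three dims, extract a width-1 slice at the start or end, tile it pad-many times, and concatenate it around the running result. Below is a hedged PyTorch sketch of that same slice-and-concat recipe (the helper name is invented and not part of the patch); it mirrors the loop structure of matchAndRewrite:

import torch

def replicate_pad3d_by_concat(x, padding):
    # Mirrors the lowering: innermost dim first, consuming (start, end) pad pairs.
    assert len(padding) == 6
    res, pad_idx = x, 0
    for dim in range(x.dim() - 1, x.dim() - 4, -1):  # last three dims
        start, end = padding[pad_idx], padding[pad_idx + 1]
        pad_idx += 2
        parts = []
        if start > 0:
            first = res.narrow(dim, 0, 1)                  # slice at START
            parts.append(torch.cat([first] * start, dim))  # tile via concat
        parts.append(res)
        if end > 0:
            last = res.narrow(dim, res.size(dim) - 1, 1)   # slice at END
            parts.append(torch.cat([last] * end, dim))
        if len(parts) > 1:
            res = torch.cat(parts, dim)
    return res

x = torch.randn(1, 2, 3, 4, 5)
pad = [1, 1, 2, 0, 0, 1]
assert torch.equal(replicate_pad3d_by_concat(x, pad),
                   torch.ops.aten.replication_pad3d(x, pad))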

lib/Conversion/TorchToStablehlo/Reduction.cpp

Lines changed: 18 additions & 35 deletions
@@ -45,85 +45,68 @@ static SmallVector<int64_t> getReduceOutputShape(ArrayRef<int64_t> inputShape,
 static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
                                            PatternRewriter &rewriter) {
   auto constType = RankedTensorType::get({}, elementTy);
+  DenseElementsAttr constAttr = nullptr;
   if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
           AtenLinalgVectorNormOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
-      auto constAttr = DenseElementsAttr::get(
+      constAttr = DenseElementsAttr::get(
           constType, {APFloat::getZero(
                           cast<mlir::FloatType>(elementTy).getFloatSemantics(),
                           /*negative=*/false)});
-      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
-                                                    constAttr);
     } else if (isa<mlir::IntegerType>(elementTy)) {
-      auto constAttr = DenseElementsAttr::get(
+      constAttr = DenseElementsAttr::get(
           constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())});
-      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
-                                                    constAttr);
     }
   }
 
   if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
-      auto constAttr = DenseElementsAttr::get(
+      constAttr = DenseElementsAttr::get(
          constType,
          {APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
                           /*negative=*/true)});
-      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
-                                                    constAttr);
     } else if (isa<mlir::IntegerType>(elementTy)) {
-      auto constAttr = DenseElementsAttr::get(
+      constAttr = DenseElementsAttr::get(
          constType,
          {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())});
-      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
-                                                    constAttr);
     }
   }
 
   if (isa<AtenAminOp, AtenMinOp, AtenMinDimOp, AtenArgminOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
-      auto constAttr = DenseElementsAttr::get(
+      constAttr = DenseElementsAttr::get(
          constType,
          {APFloat::getInf(cast<mlir::FloatType>(elementTy).getFloatSemantics(),
                           /*negative=*/false)});
-      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
-                                                    constAttr);
     } else if (isa<mlir::IntegerType>(elementTy)) {
-      auto constAttr = DenseElementsAttr::get(
+      constAttr = DenseElementsAttr::get(
          constType,
          {APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())});
-      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
-                                                    constAttr);
     }
   }
 
   if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       APFloat one(cast<mlir::FloatType>(elementTy).getFloatSemantics(), 1);
-      auto constAttr = DenseElementsAttr::get(constType, one);
-      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
-                                                    constAttr);
+      constAttr = DenseElementsAttr::get(constType, one);
     } else if (isa<mlir::IntegerType>(elementTy)) {
       APInt one(elementTy.getIntOrFloatBitWidth(), 1);
-      auto constAttr = DenseElementsAttr::get(constType, one);
-      return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
-                                                    constAttr);
+      constAttr = DenseElementsAttr::get(constType, one);
     }
   }
 
   if (isa<AtenAllOp, AtenAllDimOp>(op)) {
-    auto constAttr =
-        DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
-    return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
-                                                  constAttr);
+    constAttr = DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
   }
 
   if (isa<AtenAnyOp, AtenAnyDimOp, AtenAnyDimsOp>(op)) {
-    auto constAttr =
-        DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)});
+    constAttr = DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)});
+  }
+
+  if (constAttr != nullptr) {
     return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
                                                   constAttr);
   }
-
   op->emitError("unimplemented lowering in "
                 "createInitialValueForReduceOp");
   return nullptr;
@@ -483,7 +466,7 @@ class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp<AtenOpT> {
       return rewriter.notifyMatchFailure(
           op, "non-const integer `dim` is not supported");
     }
-    if (inputDims.size() == 0) {
+    if (inputDims.empty()) {
      dims = llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
     } else {
      for (auto d : inputDims) {
@@ -570,7 +553,7 @@ class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp<AtenOpT> {
       return rewriter.notifyMatchFailure(
           op, "failed to get dimension sizes of the input");
     }
-    auto inputShapeVec = *inputShapeInfo;
+    auto &inputShapeVec = *inputShapeInfo;
 
     if (op.getResult(1).use_empty()) {
       llvm::SmallVector<int64_t> outputShape(inputTy.getShape());
@@ -643,7 +626,7 @@ LogicalResult ConvertAtenReductionOp<AtenAnyDimsOp>::matchAndRewrite(
       return rewriter.notifyMatchFailure(
           op, "non-const integer `dim` is not supported");
     }
-    if (inputDims.size() == 0) {
+    if (inputDims.empty()) {
       rewriter.replaceOp(op, input);
       return success();
     }
@@ -722,7 +705,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
       return rewriter.notifyMatchFailure(
           op, "non-const integer `dim` is not supported");
     }
-    if (inputDims.size() == 0) {
+    if (inputDims.empty()) {
       inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
     }
   }
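
The first hunk is a pure refactor: each branch now only selects the reduction's identity element into constAttr, and a single stablehlo.constant is emitted at the end (or an error if no branch matched). As a hedged reference, not itself part of the patch, the identities the branches choose are:

import math

# Identity element per reduction family, as selected by the branches above.
REDUCTION_IDENTITY = {
    "sum / frobenius_norm / linalg_vector_norm": 0,   # zero of the element type
    "max / amax / argmax": -math.inf,                 # INT_MIN for integer types
    "min / amin / argmin": math.inf,                  # INT_MAX for integer types
    "prod": 1,
    "all": True,                                      # i1 constant 1
    "any": False,                                     # i1 constant 0
}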

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 30 additions & 0 deletions
@@ -10971,6 +10971,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
 " return %4 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.replication_pad3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+" %false = torch.constant.bool false\n"
+" %str = torch.constant.str \"AssertionError: padding size expected to be 6\"\n"
+" %none = torch.constant.none\n"
+" %str_0 = torch.constant.str \"AssertionError: \"\n"
+" %int3 = torch.constant.int 3\n"
+" %int6 = torch.constant.int 6\n"
+" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %1 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+" %3 = torch.aten.eq.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %3 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
+" return %4 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.replication_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
@@ -10979,6 +11005,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.replication_pad3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" return %0#1 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.pad\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.str, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
 " %false = torch.constant.bool false\n"
 " %0 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 5 additions & 1 deletion
@@ -8708,9 +8708,13 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
       rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
           op, op.getType(), op.getSelf(), usefulPads);
       break;
+    case 3:
+      rewriter.replaceOpWithNewOp<AtenReplicationPad3dOp>(
+          op, op.getType(), op.getSelf(), usefulPads);
+      break;
     default:
       return rewriter.notifyMatchFailure(
-          op, "unsupported number of dims for 'reflect' mode: " +
+          op, "unsupported number of dims for 'replicate' mode: " +
              std::to_string(numPadDims));
     }
     return success();
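
With the new case 3, aten.pad in 'replicate' mode with three pad pairs now decomposes to aten.replication_pad3d (and the failure message correctly says 'replicate' rather than 'reflect'). A hedged eager-mode illustration of the equivalence the decomposition relies on, with invented shapes and pad values:

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 3, 4, 5)
pad = [1, 1, 2, 0, 0, 1]  # three (low, high) pairs, innermost dim first
# F.pad in 'replicate' mode with six pad values is the case that now maps
# to aten.replication_pad3d.
assert torch.equal(F.pad(x, pad, mode="replicate"),
                   torch.ops.aten.replication_pad3d(x, pad))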

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
@@ -859,6 +859,8 @@
     "ReplicationPad2dModule_left0",
     "ReplicationPad2dModule_right0",
     "ReplicationPad2dModule_top0",
+    "ReplicationPad3dModule_basic",
+    "ReplicationPad3dModuleSingleIntPad_basic",
     "ScalarImplicitFloatModule_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScatterAddDynamicModule_basic",
@@ -3959,6 +3961,8 @@
     "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
     "ReplicationPad1dModule_2DInput_basic",
     "ReplicationPad1dModule_3DInput_basic",
+    "ReplicationPad3dModule_basic",
+    "ReplicationPad3dModuleSingleIntPad_basic",
 }
 
 ONNX_TOSA_CRASHING_SET = {
@@ -4814,6 +4818,8 @@
     "RMSNormDynamicModule_basic",
     "ReplicationPad1dModule_2DInput_basic",
     "ReplicationPad1dModule_3DInput_basic",
+    "ReplicationPad3dModule_basic",
+    "ReplicationPad3dModuleSingleIntPad_basic",
     "RollModule_basic",
     "RsubIntModule_noalpha_basic",
     "ScalarConstantTupleModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 9 additions & 0 deletions
@@ -2294,6 +2294,11 @@ def aten〇replication_pad2d〡shape(self: List[int], padding: List[int]) -> Lis
     assert len(padding) == 4, 'padding size expected to be 4'
     return pad_shape_fn(self, padding)
 
+def aten〇replication_pad3d〡shape(self: List[int], padding: List[int]) -> List[int]:
+    assert len(self) >= 3
+    assert len(padding) == 6, 'padding size expected to be 6'
+    return pad_shape_fn(self, padding)
+
 def aten〇replication_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
@@ -2302,6 +2307,10 @@ def aten〇replication_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+def aten〇replication_pad3d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
     return pad_shape_fn(self, pad)
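
The new shape function only asserts rank >= 3 and len(padding) == 6, then defers to pad_shape_fn. A hedged plain-Python sketch of the resulting shape rule (the helper name is invented for illustration):

from typing import List

def replication_pad3d_out_shape(self: List[int], padding: List[int]) -> List[int]:
    # Pad pairs are consumed from the innermost dimension outward,
    # and each (low, high) pair widens its dimension by low + high.
    assert len(self) >= 3 and len(padding) == 6
    out = list(self)
    for i in range(3):
        out[-(i + 1)] += padding[2 * i] + padding[2 * i + 1]
    return out

assert replication_pad3d_out_shape([2, 3, 4, 5], [1, 2, 0, 1, 1, 0]) == [2, 4, 5, 8]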
