
Commit 46c3888

[TorchToLinalg] Support lowering AtenReplicationPad3d to linalg (#4233)
Add support for aten.replication_pad3d in the Torch dialect and lower it to the linalg backend. AtenReplicationPad3d is lowered using a sequence of tensor.extract_slice and tensor.concat operations, consistent with the existing lowerings of AtenReplicationPad1d and AtenReplicationPad2d.
1 parent 8366790 commit 46c3888
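
For context: the padding list follows PyTorch's convention for replication_pad3d, ordered from the innermost dimension outward as (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back), so the last three dimensions of a 4-D or 5-D input are padded by replicating edge values. A small PyTorch-level illustration (tensor sizes made up for this example, not taken from the commit):

import torch

x = torch.arange(8.0).reshape(1, 1, 2, 2, 2)  # (N, C, D, H, W)
# Pad W by (1, 1), H by (0, 0), D by (2, 2) using edge replication.
y = torch.ops.aten.replication_pad3d(x, [1, 1, 0, 0, 2, 2])
print(y.shape)  # torch.Size([1, 1, 6, 2, 4])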

9 files changed: +259 additions, -1 deletion

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 24 additions & 0 deletions
@@ -10316,6 +10316,30 @@ def Torch_AtenReplicationPad2dOp : Torch_Op<"aten.replication_pad2d", [
   }];
 }
 
+def Torch_AtenReplicationPad3dOp : Torch_Op<"aten.replication_pad3d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::replication_pad3d : (Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$padding
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenReplicationPad3dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenReplicationPad3dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [
   AllowsTypeRefinement,
   HasValueSemantics,

lib/Conversion/TorchToLinalg/TensorConstructors.cpp

Lines changed: 103 additions & 0 deletions
@@ -426,6 +426,107 @@ class ConvertAtenReplicationPad2dOp
 };
 } // namespace
 
+namespace {
+
+// Lower aten.replication_pad3d operator into a sequence of
+// tensor.extract_slice and tensor.concat operations.
+class ConvertAtenReplicationPad3dOp
+    : public OpConversionPattern<AtenReplicationPad3dOp> {
+
+private:
+  enum sliceLoc { START = 0, END = 1 };
+
+  Value extractSlice(ConversionPatternRewriter &rewriter, Location loc,
+                     Value input, int64_t dimension, sliceLoc sliceLoc) const {
+    auto inputType = llvm::cast<RankedTensorType>(input.getType());
+    int64_t inputRank = inputType.getRank();
+    SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
+
+    SmallVector<OpFoldResult> offsets(inputRank, rewriter.getIndexAttr(0));
+    if (sliceLoc == END) {
+      Value dimSize = inputShape[dimension];
+      Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      Value endIdx = rewriter.create<arith::SubIOp>(loc, dimSize, one);
+      offsets[dimension] = getAsOpFoldResult(endIdx);
+    }
+
+    SmallVector<OpFoldResult> allOneStrides(inputRank,
+                                            rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes(inputRank, rewriter.getIndexAttr(0));
+    for (int i = 0; i < inputRank; ++i)
+      sizes[i] = (i == dimension) ? rewriter.getIndexAttr(1)
+                                  : getAsOpFoldResult(inputShape[i]);
+
+    Value extractedSlice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, input, offsets, sizes, allOneStrides);
+    return extractedSlice;
+  }
+
+  Value createTile(ConversionPatternRewriter &rewriter, Location loc,
+                   Value slice, int64_t tileWidth, int64_t dimension) const {
+    SmallVector<Value> slices(tileWidth, slice);
+    if (tileWidth == 1)
+      return slice;
+    return rewriter.create<tensor::ConcatOp>(loc, dimension, slices);
+  }
+
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenReplicationPad3dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op->getLoc();
+    Value input = adaptor.getSelf();
+    auto inputType = llvm::cast<RankedTensorType>(input.getType());
+    int64_t inputRank = inputType.getRank();
+    unsigned numDims = inputType.getRank();
+    assert(numDims >= 2 && "Not enough input dimensions");
+
+    SmallVector<int64_t> padInts;
+    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant int pad ranges");
+
+    if (padInts.size() != 6)
+      return rewriter.notifyMatchFailure(
+          op, "pad range must have exactly six values");
+
+    Value res = input;
+    int64_t padIdx = 0;
+    for (int64_t dim = inputRank - 1; dim >= inputRank - 3; dim--) {
+      int64_t startTileWidth = padInts[padIdx++];
+      int64_t endTileWidth = padInts[padIdx++];
+
+      SmallVector<Value> resultParts;
+      if (startTileWidth > 0) {
+        Value slice = extractSlice(rewriter, loc, res, dim, sliceLoc::START);
+        Value tile = createTile(rewriter, loc, slice, startTileWidth, dim);
+        resultParts.push_back(tile);
+      }
+
+      resultParts.push_back(res);
+
+      if (endTileWidth > 0) {
+        Value slice = extractSlice(rewriter, loc, res, dim, sliceLoc::END);
+        Value tile = createTile(rewriter, loc, slice, endTileWidth, dim);
+        resultParts.push_back(tile);
+      }
+
+      if (resultParts.size() > 1)
+        res = rewriter.create<tensor::ConcatOp>(loc, dim, resultParts);
+    }
+
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, res);
+    return success();
+  }
+};
+
+} // namespace
 namespace {
 // Converts constant tensor allocation like ops.
 template <typename OpTy, int fillVal>
@@ -696,6 +797,8 @@ void mlir::torch::torch_to_linalg::
                                      RewritePatternSet &patterns,
                                      ConversionTarget &target) {
   MLIRContext *context = patterns.getContext();
+  target.addIllegalOp<AtenReplicationPad3dOp>();
+  patterns.add<ConvertAtenReplicationPad3dOp>(typeConverter, context);
   target.addIllegalOp<AtenReplicationPad2dOp>();
   patterns.add<ConvertAtenReplicationPad2dOp>(typeConverter, context);
   target.addIllegalOp<AtenReplicationPad1dOp>();
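
The loop in matchAndRewrite above walks the last three dimensions from innermost to outermost, consuming two pad values per dimension: each border is a one-element slice taken at the start or end of that dimension, repeated tileWidth times with tensor.concat, then concatenated with the running result. A rough NumPy sketch of the same strategy, purely illustrative (the function name and structure are not part of the commit):

import numpy as np

def replication_pad3d_sketch(x, padding):
    # padding = (left, right, top, bottom, front, back): innermost dimension first.
    res = x
    pad_idx = 0
    for dim in range(x.ndim - 1, x.ndim - 4, -1):  # last three dims, innermost first
        start, end = padding[pad_idx], padding[pad_idx + 1]
        pad_idx += 2
        parts = []
        if start > 0:
            first = np.take(res, [0], axis=dim)  # 1-wide slice at the start
            parts.append(np.concatenate([first] * start, axis=dim))
        parts.append(res)
        if end > 0:
            last = np.take(res, [res.shape[dim] - 1], axis=dim)  # 1-wide slice at the end
            parts.append(np.concatenate([last] * end, axis=dim))
        if len(parts) > 1:
            res = np.concatenate(parts, axis=dim)
    return res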

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 30 additions & 0 deletions
@@ -10921,6 +10921,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
 " return %4 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.replication_pad3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+" %false = torch.constant.bool false\n"
+" %str = torch.constant.str \"AssertionError: padding size expected to be 6\"\n"
+" %none = torch.constant.none\n"
+" %str_0 = torch.constant.str \"AssertionError: \"\n"
+" %int3 = torch.constant.int 3\n"
+" %int6 = torch.constant.int 6\n"
+" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %1 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+" %3 = torch.aten.eq.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %3 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
+" return %4 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.replication_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
@@ -10929,6 +10955,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.replication_pad3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" return %0#1 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.pad\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.str, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
 " %false = torch.constant.bool false\n"
 " %0 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 5 additions & 1 deletion
@@ -8536,9 +8536,13 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
       rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
           op, op.getType(), op.getSelf(), usefulPads);
       break;
+    case 3:
+      rewriter.replaceOpWithNewOp<AtenReplicationPad3dOp>(
+          op, op.getType(), op.getSelf(), usefulPads);
+      break;
     default:
       return rewriter.notifyMatchFailure(
-          op, "unsupported number of dims for 'reflect' mode: " +
+          op, "unsupported number of dims for 'replicate' mode: " +
                  std::to_string(numPadDims));
     }
     return success();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
@@ -859,6 +859,8 @@
     "ReplicationPad2dModule_left0",
     "ReplicationPad2dModule_right0",
     "ReplicationPad2dModule_top0",
+    "ReplicationPad3dModule_basic",
+    "ReplicationPad3dModuleSingleIntPad_basic",
     "ScalarImplicitFloatModule_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScatterAddDynamicModule_basic",
@@ -3954,6 +3956,8 @@
     "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
     "ReplicationPad1dModule_2DInput_basic",
     "ReplicationPad1dModule_3DInput_basic",
+    "ReplicationPad3dModule_basic",
+    "ReplicationPad3dModuleSingleIntPad_basic",
 }
 
 ONNX_TOSA_CRASHING_SET = {
@@ -4804,6 +4808,8 @@
     "RMSNormDynamicModule_basic",
     "ReplicationPad1dModule_2DInput_basic",
     "ReplicationPad1dModule_3DInput_basic",
+    "ReplicationPad3dModule_basic",
+    "ReplicationPad3dModuleSingleIntPad_basic",
     "RollModule_basic",
     "RsubIntModule_noalpha_basic",
     "ScalarConstantTupleModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 9 additions & 0 deletions
@@ -2281,6 +2281,11 @@ def aten〇replication_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]:
     assert len(padding) == 4, 'padding size expected to be 4'
     return pad_shape_fn(self, padding)
 
+def aten〇replication_pad3d〡shape(self: List[int], padding: List[int]) -> List[int]:
+    assert len(self) >= 3
+    assert len(padding) == 6, 'padding size expected to be 6'
+    return pad_shape_fn(self, padding)
+
 def aten〇replication_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
@@ -2289,6 +2294,10 @@ def aten〇replication_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+def aten〇replication_pad3d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
     return pad_shape_fn(self, pad)
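
As a concrete check of the shape function: an input of shape [4, 3, 5] with padding [0, 1, 3, 1, 0, 3] grows the last dimension by 0+1, the middle by 3+1, and the first by 0+3, giving [7, 7, 6] (the same shapes used in the lit test below). A hypothetical helper mirroring what pad_shape_fn computes for this op (illustrative only, not part of the commit):

def expected_replication_pad3d_shape(shape, padding):
    # padding pairs apply to the last three dims, innermost first.
    out = list(shape)
    for i in range(3):
        dim = len(out) - 1 - i
        out[dim] += padding[2 * i] + padding[2 * i + 1]
    return out

assert expected_replication_pad3d_shape([4, 3, 5], [0, 1, 3, 1, 0, 3]) == [7, 7, 6]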

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -811,6 +811,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
     emit("aten::replication_pad1d : (Tensor, int[]) -> (Tensor)")
     emit("aten::replication_pad2d : (Tensor, int[]) -> (Tensor)")
+    emit("aten::replication_pad3d : (Tensor, int[]) -> (Tensor)")
     emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)")
     emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)")
     emit("aten::reflection_pad3d : (Tensor, int[]) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py

Lines changed: 47 additions & 0 deletions
@@ -59,6 +59,53 @@ def ReplicationPad1dModule_2DInput_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ReplicationPad3dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.replication_pad3d(x, [3, 5, 7, 0, 1, 2])
+
+
+@register_test_case(module_factory=lambda: ReplicationPad3dModule())
+def ReplicationPad3dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 15, 20, 1, 10, low=-1))
+
+
+# ==============================================================================
+
+
+class ReplicationPad3dModuleSingleIntPad(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pad = torch.nn.ReplicationPad3d(3)
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.pad(x)
+
+
+@register_test_case(module_factory=lambda: ReplicationPad3dModuleSingleIntPad())
+def ReplicationPad3dModuleSingleIntPad_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 15, 20, 1, 10, low=-1))
+
+
+# ==============================================================================
+
+
 class ReflectionPad2dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
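
The second test exercises the single-int constructor of torch.nn.ReplicationPad3d, which applies the same pad width on all six sides. A PyTorch-level illustration of what that module does to the test input shape (illustrative, not part of the commit):

import torch

pad = torch.nn.ReplicationPad3d(3)  # same as padding (3, 3, 3, 3, 3, 3)
y = pad(torch.rand(1, 15, 20, 1, 10))
print(y.shape)  # torch.Size([1, 15, 26, 7, 16])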

test/Conversion/TorchToLinalg/basic.mlir

Lines changed: 34 additions & 0 deletions
@@ -425,3 +425,37 @@ func.func @test_rotary_embedding(%arg0: !torch.vtensor<[1,3,2,6],f32>, %arg1: !t
   %4 = torch.onnx.rotary_embedding %arg0, %arg1, %arg2, %arg3, %int0, %int0_0, %int0_1, %int0_2, %float1.000000e00 : !torch.vtensor<[1,3,2,6],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[4,3],f32>, !torch.vtensor<[4,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int, !torch.float -> !torch.vtensor<[1,3,2,6],f32>
   return %4 : !torch.vtensor<[1,3,2,6],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.ops.aten.replication_pad3d$basic(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[4,3,5],f32>) -> !torch.vtensor<[7,7,6],f32>
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,3,5],f32> -> tensor<4x3x5xf32>
+// CHECK-DAG:     %[[INT0:.*]] = torch.constant.int 0
+// CHECK-DAG:     %[[INT1:.*]] = torch.constant.int 1
+// CHECK-DAG:     %[[INT3:.*]] = torch.constant.int 3
+// CHECK:         %[[IDX5:.*]] = arith.constant 5 : index
+// CHECK:         %[[IDX1:.*]] = arith.constant 1 : index
+// CHECK:         %[[SUB2:.*]] = arith.subi %[[IDX5]], %[[IDX1]] : index
+// CHECK:         %[[SLICE1:.*]] = tensor.extract_slice %[[T0]][0, 0, %[[SUB2]]] [4, 3, 1] [1, 1, 1] : tensor<4x3x5xf32> to tensor<4x3x1xf32>
+// CHECK:         %[[CONCAT1:.*]] = tensor.concat dim(2) %[[T0]], %[[SLICE1]] : (tensor<4x3x5xf32>, tensor<4x3x1xf32>) -> tensor<4x3x6xf32>
+// CHECK:         %[[SLICE2:.*]] = tensor.extract_slice %[[CONCAT1]][0, 0, 0] [4, 1, 6] [1, 1, 1] : tensor<4x3x6xf32> to tensor<4x1x6xf32>
+// CHECK:         %[[CONCAT2:.*]] = tensor.concat dim(1) %[[SLICE2]], %[[SLICE2]], %[[SLICE2]] : (tensor<4x1x6xf32>, tensor<4x1x6xf32>, tensor<4x1x6xf32>) -> tensor<4x3x6xf32>
+// CHECK:         %[[SUB3:.*]] = arith.subi {{.*}}, {{.*}} : index
+// CHECK:         %[[SLICE3:.*]] = tensor.extract_slice %[[CONCAT1]][0, %[[SUB3]], 0] [4, 1, 6] [1, 1, 1] : tensor<4x3x6xf32> to tensor<4x1x6xf32>
+// CHECK:         %[[CONCAT3:.*]] = tensor.concat dim(1) %[[CONCAT2]], %[[CONCAT1]], %[[SLICE3]] : (tensor<4x3x6xf32>, tensor<4x3x6xf32>, tensor<4x1x6xf32>) -> tensor<4x7x6xf32>
+// CHECK:         %[[SUB4:.*]] = arith.subi {{.*}}, {{.*}} : index
+// CHECK:         %[[SLICE4:.*]] = tensor.extract_slice %[[CONCAT3]][%[[SUB4]], 0, 0] [1, 7, 6] [1, 1, 1] : tensor<4x7x6xf32> to tensor<1x7x6xf32>
+// CHECK:         %[[CONCAT4:.*]] = tensor.concat dim(0) %[[SLICE4]], %[[SLICE4]], %[[SLICE4]] : (tensor<1x7x6xf32>, tensor<1x7x6xf32>, tensor<1x7x6xf32>) -> tensor<3x7x6xf32>
+// CHECK:         %[[CONCAT5:.*]] = tensor.concat dim(0) %[[CONCAT3]], %[[CONCAT4]] : (tensor<4x7x6xf32>, tensor<3x7x6xf32>) -> tensor<7x7x6xf32>
+// CHECK:         %[[CAST:.*]] = tensor.cast %[[CONCAT5]] : tensor<7x7x6xf32> to tensor<7x7x6xf32>
+// CHECK:         %[[OUT:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<7x7x6xf32> -> !torch.vtensor<[7,7,6],f32>
+// CHECK:         return %[[OUT]] : !torch.vtensor<[7,7,6],f32>
+func.func @torch.ops.aten.replication_pad3d$basic(%arg0: !torch.vtensor<[4,3,5],f32>) -> !torch.vtensor<[7,7,6],f32> {
+  %c0 = torch.constant.int 0
+  %c1 = torch.constant.int 1
+  %c3 = torch.constant.int 3
+  %padding = torch.prim.ListConstruct %c0, %c1, %c3, %c1, %c0, %c3 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %0 = torch.aten.replication_pad3d %arg0, %padding : !torch.vtensor<[4,3,5],f32>, !torch.list<int> -> !torch.vtensor<[7,7,6],f32>
+  return %0 : !torch.vtensor<[7,7,6],f32>
+}
