Skip to content

Commit 692b2c0

Browse files
authored
Use linalg.index to lower aten.reflection_pad2d (#4105)
"aten.reflection_pad2d" was lowered to linalg using affine maps of the form {indexing_maps = [affine_map<(d0, d1) -> (d0, -d1 + 1)>, affine_map<(d0, d1) -> (d0, d1)>]}. The negative coefficient in such maps causes failures in downstream passes, for example: "BinaryOpExpr(AffineBinaryOpExpr): Assertion `cast(expr.getRHS()).getValue() > 0 && "nonpositive multiplying coefficient"' failed." Using linalg.index together with a tensor.extract op, instead of the affine map above, lets the same program compile successfully. Signed-off-by: Praveen G <Praveen.G2@amd.com>
1 parent 6daa20e commit 692b2c0

File tree

2 files changed

+73
-29
lines changed

2 files changed

+73
-29
lines changed

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -434,17 +434,6 @@ class ConvertAtenReflectionPad2dOp
434434
for (auto v : {TOP, BOTTOM})
435435
tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType);
436436

437-
// Helper to reflect/reverse the i-th dimension of an affine map
438-
// without symbols. This only works if applied on a tensor
439-
// for which the corresponding dimension has a statically
440-
// known size which is good enough since we only apply
441-
// it to reflect the padding slices.
442-
auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i,
443-
int64_t size) {
444-
AffineExpr d = map.getResult(i);
445-
return map.replace(d, size - d - 1, numDims, 0);
446-
};
447-
448437
// Create output shape and tensor
449438
SmallVector<Value> resultShape{inputShape};
450439
resultShape[vDim] =
@@ -538,26 +527,41 @@ class ConvertAtenReflectionPad2dOp
538527
Value tile = rewriter.create<tensor::ExtractSliceOp>(
539528
loc, input, extractOffsets, extractShape, allOneStrides);
540529

541-
// Reverse the tile along the horizontal, vertical, or both
542-
// dimensions.
543530
auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
544-
if (shouldHReflect(horizontalPos)) {
545-
inputMap =
546-
reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos));
547-
}
548-
if (shouldVReflect(verticalPos)) {
549-
inputMap =
550-
reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos));
551-
}
552531

553-
tile = rewriter
554-
.create<linalg::GenericOp>(
555-
loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
556-
tile, ArrayRef({inputMap, idMap}), iteratorTypes,
557-
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
558-
b.create<linalg::YieldOp>(nestedLoc, args[0]);
559-
})
560-
.getResult(0);
532+
tile =
533+
rewriter
534+
.create<linalg::GenericOp>(
535+
loc, llvm::cast<RankedTensorType>(tile.getType()), tile, tile,
536+
ArrayRef({inputMap, idMap}), iteratorTypes,
537+
[&](OpBuilder &b, Location nestedLoc, ValueRange args) {
538+
// Use linalg.index to reflect the dims
539+
SmallVector<Value> extractIndices(numDims);
540+
for (unsigned i = 0; i < numDims; i++)
541+
extractIndices[i] =
542+
b.create<linalg::IndexOp>(nestedLoc, i);
543+
544+
auto reflectDim = [&](int64_t padSize, Value dim) {
545+
Value reflectDimSize = getConstant(
546+
rewriter, loc, padSize - 1, rewriter.getIndexType());
547+
return b.create<arith::SubIOp>(loc, reflectDimSize, dim);
548+
};
549+
550+
// Reverse the tile along the horizontal, vertical, or both
551+
// dimensions.
552+
if (shouldHReflect(horizontalPos))
553+
extractIndices[hDim] = reflectDim(
554+
getHPadArgument(horizontalPos), extractIndices[hDim]);
555+
556+
if (shouldVReflect(verticalPos))
557+
extractIndices[vDim] = reflectDim(
558+
getVPadArgument(verticalPos), extractIndices[vDim]);
559+
560+
Value extractValue = rewriter.create<tensor::ExtractOp>(
561+
nestedLoc, tile, extractIndices);
562+
b.create<linalg::YieldOp>(nestedLoc, extractValue);
563+
})
564+
.getResult(0);
561565

562566
// Insert the tile in the resultTensor.
563567
SmallVector<Value> insertOffsets(numDims, zero);

test/Conversion/TorchToLinalg/datamovement.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,43 @@ func.func @torch.aten.permute$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vte
3232
%1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[],f32>
3333
return %1 : !torch.vtensor<[],f32>
3434
}
35+
36+
// -----
37+
38+
// CHECK: #[[$INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
39+
// CHECK-LABEL: func.func @torch.aten.reflection_pad2d(
40+
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,8,9],f32> {
41+
// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
42+
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
43+
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
44+
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,4,4],f32> -> tensor<1x1x4x4xf32>
45+
// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<1x1x8x9xf32>
46+
// CHECK: %[[VAL_6:.*]] = linalg.fill ins(%[[VAL_1]] : f32) outs(%[[VAL_5]] : tensor<1x1x8x9xf32>) -> tensor<1x1x8x9xf32>
47+
// CHECK: %[[VAL_7:.*]] = tensor.extract_slice %[[VAL_4]][0, 0, 1, 1] [1, 1, 2, 2] [1, 1, 1, 1] : tensor<1x1x4x4xf32> to tensor<1x1x2x2xf32>
48+
// CHECK: %[[VAL_8:.*]] = tensor.extract_slice %[[VAL_4]][0, 0, 1, 1] [1, 1, 2, 2] [1, 1, 1, 1] : tensor<1x1x4x4xf32> to tensor<1x1x2x2xf32>
49+
// CHECK: %[[VAL_9:.*]] = tensor.extract_slice %[[VAL_4]][0, 0, 1, 1] [1, 1, 2, 2] [1, 1, 1, 1] : tensor<1x1x4x4xf32> to tensor<1x1x2x2xf32>
50+
// CHECK: %[[VAL_10:.*]] = linalg.generic {indexing_maps = [#[[$INPUT_MAP]], #[[$INPUT_MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_9]] : tensor<1x1x2x2xf32>) outs(%[[VAL_8]] : tensor<1x1x2x2xf32>) {
51+
// CHECK: ^bb0(%[[VAL_11:.*]]: f32, %[[VAL_12:.*]]: f32):
52+
// CHECK: %[[VAL_13:.*]] = linalg.index 0 : index
53+
// CHECK: %[[VAL_14:.*]] = linalg.index 1 : index
54+
// CHECK: %[[VAL_15:.*]] = linalg.index 2 : index
55+
// CHECK: %[[VAL_16:.*]] = linalg.index 3 : index
56+
// CHECK: %[[VAL_17:.*]] = arith.subi %[[VAL_3]], %[[VAL_16]] : index
57+
// CHECK: %[[VAL_18:.*]] = arith.subi %[[VAL_3]], %[[VAL_15]] : index
58+
// CHECK: %[[VAL_19:.*]] = tensor.extract %[[VAL_7]]{{\[}}%[[VAL_13]], %[[VAL_14]], %[[VAL_18]], %[[VAL_17]]] : tensor<1x1x2x2xf32>
59+
// CHECK: linalg.yield %[[VAL_19]] : f32
60+
// CHECK: } -> tensor<1x1x2x2xf32>
61+
// CHECK: %[[VAL_20:.*]] = tensor.insert_slice %[[VAL_10]] into %[[VAL_6]][0, 0, 0, 0] [1, 1, 2, 2] [1, 1, 1, 1] : tensor<1x1x2x2xf32> into tensor<1x1x8x9xf32>
62+
// CHECK-COUNT-8: linalg.generic
63+
// CHECK: %[[VAL_123:.*]] = tensor.insert_slice
64+
// CHECK: %[[VAL_124:.*]] = torch_c.from_builtin_tensor %[[VAL_123]] : tensor<1x1x8x9xf32> -> !torch.vtensor<[1,1,8,9],f32>
65+
// CHECK: return %[[VAL_124]] : !torch.vtensor<[1,1,8,9],f32>
66+
// CHECK: }
67+
68+
func.func @torch.aten.reflection_pad2d(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,8,9],f32> {
69+
%int2 = torch.constant.int 2
70+
%int3 = torch.constant.int 3
71+
%0 = torch.prim.ListConstruct %int2, %int3, %int2, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
72+
%1 = torch.aten.reflection_pad2d %arg0, %0 : !torch.vtensor<[1,1,4,4],f32>, !torch.list<int> -> !torch.vtensor<[1,1,8,9],f32>
73+
return %1 : !torch.vtensor<[1,1,8,9],f32>
74+
}

0 commit comments

Comments
 (0)