diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp
index 76b9b87cbfe9..db6648de2db1 100644
--- a/lib/Conversion/TorchToTensor/TorchToTensor.cpp
+++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp
@@ -16,11 +16,13 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
 
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
+using namespace mlir::torch::TorchConversion;
 
 namespace {
 
@@ -138,6 +140,97 @@ class ConvertAtenTensorOpPattern : public OpConversionPattern<AtenTensorOp> {
   }
 };
 
+class ConvertAtenAsStridedOp : public OpConversionPattern<AtenAsStridedOp> {
+public:
+  using OpConversionPattern<AtenAsStridedOp>::OpConversionPattern;
+  using OpAdaptor = typename AtenAsStridedOp::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenAsStridedOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // In some cases AtenAsStridedOp is equivalent to a tensor.extract_slice.
+    // We try to match those cases here.
+    auto inputShape =
+        cast<RankedTensorType>(adaptor.getSelf().getType()).getShape();
+    auto outputShape =
+        cast<ValueTensorType>(op.getResult().getType()).getSizes();
+    auto resultTy =
+        cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
+
+    // If the output shape is strictly larger than the input shape in any
+    // dimension, then this AtenAsStridedOp is not equivalent to a slice.
+    for (uint64_t i = 0; i < outputShape.size(); ++i) {
+      if (outputShape[i] > inputShape[i])
+        return failure();
+    }
+
+    // Calculate what the strides attribute should be if the input tensor is
+    // contiguous.
+    SmallVector<int64_t> contiguousStrides(inputShape.size(), 1);
+    for (int i = inputShape.size() - 2; i >= 0; --i) {
+      contiguousStrides[i] = contiguousStrides[i + 1] * inputShape[i + 1];
+    }
+
+    SmallVector<Value> outSizeValues, opStridesValues;
+    if (!getListConstructElements(adaptor.getStride(), opStridesValues))
+      return op.emitError(
+          "unimplemented: the stride list is not from list construct");
+
+    if (!getListConstructElements(adaptor.getSize(), outSizeValues))
+      return op.emitError(
+          "unimplemented: the size list is not from list construct");
+
+    // Get storage offset
+    int64_t offset;
+    if (!matchPattern(op.getStorageOffset(), m_TorchConstantInt(&offset)))
+      offset = 0;
+
+    APInt size;
+    SmallVector<int64_t> outSize(inputShape.size(), 0);
+    for (uint64_t i = 0; i < outSizeValues.size(); ++i) {
+      if (!matchPattern(outSizeValues[i], m_Op<TorchConversion::FromI64Op>(
+                                              m_ConstantInt(&size))) ||
+          !size.isSignedIntN(64))
+        return failure();
+      outSize[i] = size.getSExtValue();
+    }
+    APInt stride;
+    SmallVector<int64_t> opStrides(inputShape.size(), 0);
+    for (uint64_t i = 0; i < opStridesValues.size(); ++i) {
+      if (!matchPattern(opStridesValues[i], m_Op<TorchConversion::FromI64Op>(
+                                                m_ConstantInt(&stride))) ||
+          !stride.isSignedIntN(64))
+        return failure();
+      opStrides[i] = stride.getSExtValue();
+    }
+
+    // Slice dims are the dims where the input and output shapes are not equal.
+    SmallVector<int64_t> sliceDims;
+    for (uint64_t i = 0; i < inputShape.size(); ++i) {
+      if (outSize[i] != inputShape[i])
+        sliceDims.push_back(i);
+    }
+
+    // If there are no slice dims, then the AtenAsStridedOp is equivalent to
+    // the input tensor.
+    if (sliceDims.empty()) {
+      rewriter.replaceOp(op, adaptor.getSelf());
+      return success();
+    }
+
+    SmallVector<int64_t> sliceOffsets(inputShape.size(), 0);
+    SmallVector<int64_t> sliceStrides(opStrides.size(), 1);
+    for (auto dim : sliceDims) {
+      sliceOffsets[dim] = offset / contiguousStrides[dim];
+      sliceStrides[dim] = opStrides[dim] / contiguousStrides[dim];
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+        op, resultTy, adaptor.getSelf(), ValueRange(), ValueRange(),
+        ValueRange(), sliceOffsets, outSize, sliceStrides);
+    return success();
+  }
+};
+
 class ConvertTorchToTensor
     : public ConvertTorchToTensorBase<ConvertTorchToTensor> {
 public:
@@ -153,6 +246,7 @@ class ConvertTorchToTensor
     target.addIllegalOp<AtenItemOp>();
     target.addIllegalOp<Aten_ShapeAsTensorOp>();
     target.addIllegalOp<AtenTensorOp>();
+    target.addIllegalOp<AtenAsStridedOp>();
 
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
@@ -160,7 +254,8 @@ class ConvertTorchToTensor
 
     RewritePatternSet patterns(context);
-    patterns.add<ConvertAtenItemOp, ConvertAtenShapeToTensorPatternOp,
-                 ConvertAtenTensorOpPattern>(typeConverter, context);
+    patterns.add<ConvertAtenItemOp, ConvertAtenShapeToTensorPatternOp,
+                 ConvertAtenTensorOpPattern, ConvertAtenAsStridedOp>(
+        typeConverter, context);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 1ad698db9cc1..7e2fccca01b8 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -6730,3 +6730,26 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: Aten_AssertScalar())
 def Aten_AssertScalar_basic(module, tu: TestUtils):
     module.forward(torch.tensor(4))
+
+
+# ==============================================================================
+
+
+class AtenAsStridedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.as_strided(x, (2, 2), (3, 3), 1)
+
+
+@register_test_case(module_factory=lambda: AtenAsStridedModule())
+def AtenAsStridedModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(25, 1, 1))
diff --git a/test/Conversion/TorchToTensor/torch_to_tensor.mlir b/test/Conversion/TorchToTensor/torch_to_tensor.mlir
index 277dabc3b891..73a62cc4e765 100644
--- a/test/Conversion/TorchToTensor/torch_to_tensor.mlir
+++ b/test/Conversion/TorchToTensor/torch_to_tensor.mlir
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt <%s -convert-torch-to-tensor | FileCheck %s
+// RUN: torch-mlir-opt <%s -split-input-file -convert-torch-to-tensor | FileCheck %s
 
 // CHECK-LABEL: func.func @test_shape
 func.func @test_shape(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3],si64> {
@@ -6,3 +6,28 @@ func.func @test_shape(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3],si64> {
   %0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3],si64>
   return %0 : !torch.vtensor<[3],si64>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_as_strided
+func.func @test_as_strided(%arg0: !torch.vtensor<[1,128,1024,192],f32>) -> !torch.vtensor<[1,128,1024,128],f32> {
+  %c0_i64 = arith.constant 0 : i64
+  %int0 = torch_c.from_i64 %c0_i64
+  %c1_i64 = arith.constant 1 : i64
+  %int1 = torch_c.from_i64 %c1_i64
+  %c128_i64 = arith.constant 128 : i64
+  %int128 = torch_c.from_i64 %c128_i64
+  %c192_i64 = arith.constant 192 : i64
+  %int192 = torch_c.from_i64 %c192_i64
+  %c1024_i64 = arith.constant 1024 : i64
+  %int1024 = torch_c.from_i64 %c1024_i64
+  %c24576_i64 = arith.constant 24576 : i64
+  %int24576 = torch_c.from_i64 %c24576_i64
+  %c25165824_i64 = arith.constant 25165824 : i64
+  %int25165824 = torch_c.from_i64 %c25165824_i64
+  %0 = torch.prim.ListConstruct %int1, %int128, %int1024, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int25165824, %int192, %int24576, %int1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[RESULT:.+]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [1, 128, 1024, 128] [1, 1, 1, 1] : tensor<1x128x1024x192xf32> to tensor<1x128x1024x128xf32>
+  %2 = torch.aten.as_strided %arg0, %0, %1, %int0 : !torch.vtensor<[1,128,1024,192],f32>, !torch.list<int>, !torch.list<int>, !torch.int -> !torch.vtensor<[1,128,1024,128],f32>
+  return %2 : !torch.vtensor<[1,128,1024,128],f32>
+}
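
Reviewer note (not part of the patch): the rewrite maps the op's storage offset and per-dimension strides onto tensor.extract_slice parameters by dividing each by the input's contiguous stride along the sliced dimensions. A minimal PyTorch-level sketch of that arithmetic follows; the sizes, strides, and offset below are illustrative assumptions, not values taken from the lowering.

```python
import torch

# Contiguous 4x6 input; its contiguous strides are (6, 1).
x = torch.arange(24, dtype=torch.float32).reshape(4, 6)

# as_strided with sizes (2, 3), strides (12, 2), storage offset 1.
# Per dimension, the pattern's arithmetic gives:
#   sliceOffsets = offset // contiguousStride   -> (1 // 6, 1 // 1) = (0, 1)
#   sliceStrides = opStride // contiguousStride -> (12 // 6, 2 // 1) = (2, 2)
y = torch.as_strided(x, (2, 3), (12, 2), 1)

# The equivalent strided slice: offsets (0, 1), sizes (2, 3), strides (2, 2).
# The trailing [:2, :3] mirrors extract_slice's explicit result sizes.
assert torch.equal(y, x[0::2, 1::2][:2, :3])
```

The equivalence assumed here holds when the requested strides are exact multiples of the input's contiguous strides along the sliced dimensions, which is the case the new pattern targets.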