Skip to content

Commit 60ffb91

Browse files
Manewing and Florian Walbroel authored
lib/Dialect/Torch/IR/TorchOps.cpp: fix: use-after-free: erasing an operation during folding (#4274)
This fixes a SEGFAULT in the GreedyPatternRewriteDriver and adds a missing size check to the `torch.aten._assert_tensor_metadata` operation. Erasing an operation during folding is not allowed. Folding the operation may either modify it in place or return a set of replacements, but may not erase the operation. (see https://github.com/llvm/llvm-project/blob/e56384ff540e68f9d0500fa27a95354c0730e37b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp#L492-L508) Doing this causes a SEGFAULT (witnessed on macOS Sequoia 15.5, Apple M4): ``` Stack dump: 0. Program arguments: build/bin/torch-mlir-opt -canonicalize --split-input-file -verify-diagnostics test/Dialect/Torch/invalid_canonicalize.mlir #0 0x0000000104091524 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (build/bin/torch-mlir-opt+0x10140d524) #1 0x000000010408fa5c llvm::sys::RunSignalHandlers() (build/bin/torch-mlir-opt+0x10140ba5c) #2 0x0000000104091bc8 SignalHandler(int, __siginfo*, void*) (build/bin/torch-mlir-opt+0x10140dbc8) #3 0x0000000181e10624 (/usr/lib/system/libsystem_platform.dylib+0x1804ac624) #4 0x0000000103c1f7a8 (anonymous namespace)::GreedyPatternRewriteDriver::processWorklist() (build/bin/torch-mlir-opt+0x100f9b7a8) #5 0x0000000103c1cf4c mlir::applyPatternsGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) (build/bin/torch-mlir-opt+0x100f98f4c) #6 0x0000000102c8f62c (anonymous namespace)::Canonicalizer::runOnOperation() (build/bin/torch-mlir-opt+0x10000b62c) #7 0x0000000103c72fa4 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (build/bin/torch-mlir-opt+0x100feefa4) #8 0x0000000103c750d4 mlir::PassManager::run(mlir::Operation*) (build/bin/torch-mlir-opt+0x100ff10d4) #9 0x0000000102c8d774 performActions(llvm::raw_ostream&, std::__1::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) (build/bin/torch-mlir-opt+0x100009774) #10 0x0000000102c8d35c 
llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_0>(long, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) (build/bin/torch-mlir-opt+0x10000935c) #11 0x000000010403194c mlir::splitAndProcessBuffer(std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef)::$_0::operator()(llvm::StringRef) const (build/bin/torch-mlir-opt+0x1013ad94c) #12 0x00000001040316a4 mlir::splitAndProcessBuffer(std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) (build/bin/torch-mlir-opt+0x1013ad6a4) #13 0x0000000102c87078 mlir::MlirOptMain(llvm::raw_ostream&, std::__1::unique_ptr<llvm::MemoryBuffer, std::__1::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (build/bin/torch-mlir-opt+0x100003078) #14 0x0000000102c8731c mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) (build/bin/torch-mlir-opt+0x10000331c) #15 0x0000000102c87538 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (build/bin/torch-mlir-opt+0x100003538) #16 0x0000000102c85cd0 main (build/bin/torch-mlir-opt+0x100001cd0) #17 0x0000000181a36b98 
build/tools/torch-mlir/test/Dialect/Torch/Output/invalid_canonicalize.mlir.script: line 1: 72586 Segmentation fault: 11 build/bin/torch-mlir-opt -canonicalize --split-input-file -verify-diagnostics test/Dialect/Torch/invalid_canonicalize.mlir ``` Since the `torch.aten._assert_tensor_metadata` operation is only used for static assertion during compile time, the folding can be replaced by a canonicalization that checks the assert and then uses a rewriter to erase the operation. The second commit deals with a missing size check in the assert operation before using a zip operation. Without the explicit check of the size, the assert would not fail in case the sizes of the matching dimensions were the same but there were either fewer or more dimensions in the input than specified in the assert. --------- Signed-off-by: Florian Walbroel <walbroel@roofline.ai> Co-authored-by: Florian Walbroel <walbroel@roofline.ai>
1 parent e5221b9 commit 60ffb91

File tree

5 files changed

+104
-37
lines changed

5 files changed

+104
-37
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14374,7 +14374,7 @@ def Torch_Aten_AssertTensorMetadataOp : Torch_Op<"aten._assert_tensor_metadata",
1437414374
printDefaultTorchOp(printer, *this, 6, 0);
1437514375
}
1437614376
}];
14377-
let hasFolder = 1;
14377+
let hasCanonicalizer = 1;
1437814378
}
1437914379

1438014380
def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5490,48 +5490,64 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
54905490
// Aten_AssertTensorMetadataOp
54915491
//===----------------------------------------------------------------------===//
54925492

5493-
LogicalResult Aten_AssertTensorMetadataOp::fold(
5494-
FoldAdaptor adaptor, SmallVectorImpl<::mlir::OpFoldResult> &results) {
5495-
Value input = getA();
5496-
auto inputType = cast<BaseTensorType>(input.getType());
5497-
if (!inputType.hasDtype() || !inputType.hasSizes())
5498-
return failure();
5493+
namespace {
5494+
class EraseAssertMetadataPattern
5495+
: public OpRewritePattern<Aten_AssertTensorMetadataOp> {
5496+
public:
5497+
using OpRewritePattern<Aten_AssertTensorMetadataOp>::OpRewritePattern;
5498+
5499+
LogicalResult matchAndRewrite(Aten_AssertTensorMetadataOp op,
5500+
PatternRewriter &rewriter) const override {
5501+
Value input = op.getA();
5502+
auto inputType = cast<BaseTensorType>(input.getType());
5503+
if (!inputType.hasDtype() || !inputType.hasSizes())
5504+
return failure();
54995505

5500-
// TODO: Add checks for stride, device, and layout when we can extract that
5501-
// information from the torch tensor. For now, we can only get the shape and
5502-
// dtype info from the tensor hence adding checks for them.
5506+
// TODO: Add checks for stride, device, and layout when we can extract that
5507+
// information from the torch tensor. For now, we can only get the shape and
5508+
// dtype info from the tensor hence adding checks for them.
55035509

5504-
// convert size to a list of integers.
5505-
SmallVector<int64_t> size;
5506-
if (!isa<Torch::NoneType>(getSize().getType())) {
5507-
if (!matchPattern(getSize(), m_TorchListOfConstantInts(size))) {
5508-
return emitOpError("expected dtype to be a constant int");
5510+
// convert size to a list of integers.
5511+
SmallVector<int64_t> size;
5512+
if (!isa<Torch::NoneType>(op.getSize().getType())) {
5513+
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(size))) {
5514+
return op.emitOpError("expected dtype to be a constant int");
5515+
}
5516+
if (inputType.getSizes().size() != size.size() ||
5517+
!llvm::all_of(llvm::zip(inputType.getSizes(), size),
5518+
[](const auto &pair) {
5519+
return std::get<0>(pair) == std::get<1>(pair);
5520+
}))
5521+
return op.emitOpError(
5522+
"Failed to canonicalize the _assert_tensor_metadata op since "
5523+
"the sizes do not match");
55095524
}
5510-
if (!llvm::all_of(llvm::zip(inputType.getSizes(), size),
5511-
[](const auto &pair) {
5512-
return std::get<0>(pair) == std::get<1>(pair);
5513-
}))
5514-
return emitOpError("Failed to fold the _assert_tensor_metadata op since "
5515-
"the sizes do not match");
5516-
}
55175525

5518-
// convert dtype to an integer.
5519-
int64_t dtype;
5520-
if (!isa<Torch::NoneType>(getDtype().getType())) {
5521-
if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype))) {
5522-
return emitOpError("expected dtype to be a constant int");
5526+
// convert dtype to an integer.
5527+
int64_t dtype;
5528+
if (!isa<Torch::NoneType>(op.getDtype().getType())) {
5529+
if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtype))) {
5530+
return op.emitOpError("expected dtype to be a constant int");
5531+
}
5532+
FailureOr<Type> inputDtype =
5533+
getTypeForScalarType(getContext(), (torch_upstream::ScalarType)dtype);
5534+
if (failed(inputDtype))
5535+
return failure();
5536+
if (inputType.getDtype() != inputDtype)
5537+
return op.emitOpError(
5538+
"Failed to canonicalize the _assert_tensor_metadata op since "
5539+
"the dtype does not match");
55235540
}
5524-
FailureOr<Type> inputDtype =
5525-
getTypeForScalarType(getContext(), (torch_upstream::ScalarType)dtype);
5526-
if (failed(inputDtype))
5527-
return failure();
5528-
if (inputType.getDtype() != inputDtype)
5529-
return emitOpError("Failed to fold the _assert_tensor_metadata op since "
5530-
"the dtype does not match");
5541+
5542+
rewriter.eraseOp(op);
5543+
return success();
55315544
}
5545+
};
5546+
} // namespace
55325547

5533-
getOperation()->erase();
5534-
return success();
5548+
void Aten_AssertTensorMetadataOp::getCanonicalizationPatterns(
5549+
RewritePatternSet &patterns, MLIRContext *context) {
5550+
patterns.add<EraseAssertMetadataPattern>(context);
55355551
}
55365552

55375553
//===----------------------------------------------------------------------===//

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,7 @@ def emit_with_mutating_variants(key, **kwargs):
10441044
emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
10451045
emit(
10461046
"aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) -> ()",
1047-
has_folder=True,
1047+
has_canonicalizer=True,
10481048
)
10491049
emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
10501050
emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")

test/Dialect/Torch/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@ func.func @torch.runtime.assert() {
2929
return
3030
}
3131

32+
// CHECK-LABEL: func.func @torch.aten.assert_tensor_metadata
33+
// CHECK-NEXT: return
34+
func.func @torch.aten.assert_tensor_metadata() {
35+
%int4 = torch.constant.int 4
36+
%none = torch.constant.none
37+
%1 = tensor.empty() : tensor<1x1x128x128xi64>
38+
%2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
39+
torch.aten._assert_tensor_metadata %2, %none, %none, %int4, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.none, !torch.none, !torch.int, !torch.none, !torch.none
40+
return
41+
}
42+
3243
// CHECK-LABEL: func.func @torch.aten.ones_item
3344
// CHECK: %[[CONST:.*]] = torch.constant.int 1
3445
// CHECK: return %[[CONST]] : !torch.int
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: torch-mlir-opt -canonicalize --split-input-file -verify-diagnostics %s
2+
3+
func.func @torch.aten.assert_tensor_metadata_invalid_dtype() {
4+
%int8 = torch.constant.int 8
5+
%none = torch.constant.none
6+
%1 = tensor.empty() : tensor<1x1x128x128xi64>
7+
%2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
8+
// expected-error @+1 {{torch.aten._assert_tensor_metadata' op Failed to canonicalize the _assert_tensor_metadata op since the dtype does not match}}
9+
torch.aten._assert_tensor_metadata %2, %none, %none, %int8, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.none, !torch.none, !torch.int, !torch.none, !torch.none
10+
return
11+
}
12+
13+
func.func @torch.aten.assert_tensor_metadata_invalid_size() {
14+
%int0 = torch.constant.int 0
15+
%int2 = torch.constant.int 2
16+
%int3 = torch.constant.int 3
17+
%sizes = torch.prim.ListConstruct %int0, %int2, %int3
18+
: (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
19+
%int4 = torch.constant.int 4
20+
%none = torch.constant.none
21+
%1 = tensor.empty() : tensor<1x1x128x128xi64>
22+
%2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
23+
// expected-error @+1 {{'torch.aten._assert_tensor_metadata' op Failed to canonicalize the _assert_tensor_metadata op since the sizes do not match}}
24+
torch.aten._assert_tensor_metadata %2, %sizes, %none, %int4, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.list<int>, !torch.none, !torch.int, !torch.none, !torch.none
25+
return
26+
}
27+
28+
func.func @torch.aten.assert_tensor_metadata_invalid_size_extra_dim() {
29+
%int1 = torch.constant.int 1
30+
%int4 = torch.constant.int 4
31+
%int128 = torch.constant.int 128
32+
%sizes = torch.prim.ListConstruct %int1, %int1, %int128, %int128, %int4
33+
: (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
34+
%none = torch.constant.none
35+
%1 = tensor.empty() : tensor<1x1x128x128xi64>
36+
%2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
37+
// expected-error @+1 {{'torch.aten._assert_tensor_metadata' op Failed to canonicalize the _assert_tensor_metadata op since the sizes do not match}}
38+
torch.aten._assert_tensor_metadata %2, %sizes, %none, %int4, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.list<int>, !torch.none, !torch.int, !torch.none, !torch.none
39+
return
40+
}

0 commit comments

Comments
 (0)