From a856e2d041ba04b68d60ae505284ed8fbc46acb1 Mon Sep 17 00:00:00 2001 From: Atri Sarkar Date: Fri, 25 Oct 2024 01:37:03 +0530 Subject: [PATCH 1/7] [TorchToLinalg][GridSample] Add support for border padding mode --- .../TorchToLinalg/Uncategorized.cpp | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e89056355785..0ba9cd032c07 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2587,10 +2587,29 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { return res; }; + auto lambdaBorder = [&](OpBuilder &b, Location loc, Value x, + Value SizeSubOne) -> Value { + Value xMaxZero = b.create(loc, x, zeroFloat); + return b.create(loc, xMaxZero, SizeSubOne); + }; + + auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode, + Value x, Value SizeSubOne) -> Value { + Value border = lambdaBorder(b, loc, x, SizeSubOne); + Value zeroInt = + b.create(loc, b.getIntegerAttr(int64type, 0)); + Value isZero = b.create(loc, arith::CmpIPredicate::eq, + paddingMode, zeroInt); + + return b.create(loc, isZero, x, border); + }; + auto resultType = cast( getTypeConverter()->convertType(op.getResult().getType())); Value alignCorners = adaptor.getAlignCorners(); Value interMode = adaptor.getInterpolationMode(); + Value paddingMode = adaptor.getPaddingMode(); + SmallVector dynamicSizes{}; if (resultType.isDynamicDim(0)) dynamicSizes.push_back(rewriter.create(loc, input, 0)); @@ -2618,10 +2637,14 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Value gplus1 = b.create(loc, gr1, oneFloat); Value gPlusMul0 = b.create(loc, gplus0, innerDim0e); Value gPlusMul1 = b.create(loc, gplus1, innerDim1e); - Value result0 = + Value unnorm0 = b.create(loc, gPlusMul0, gr0HalfSelect); - Value result1 = + Value unnorm1 = b.create(loc, gPlusMul1, gr1HalfSelect); + Value result0 = + lambdaPadding(b, loc, paddingMode, unnorm0, innerDim0d); + Value result1 = + lambdaPadding(b, loc, paddingMode, unnorm1, innerDim1d); Value checkLowerBound0 = b.create( loc, arith::CmpFPredicate::OLT, result0, zeroFloat); Value checkLowerBound1 = b.create( From d9620ff280157cac08c0747a5124bb8cd0dbf507 Mon Sep 17 00:00:00 2001 From: Atri Sarkar Date: Fri, 25 Oct 2024 21:39:28 +0530 Subject: [PATCH 2/7] [TorchToLinalg][GridSample] Add lit test for border padding --- test/Conversion/TorchToLinalg/gridsampler.mlir | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/Conversion/TorchToLinalg/gridsampler.mlir b/test/Conversion/TorchToLinalg/gridsampler.mlir index 2a291f721fed..f56881898d67 100644 --- a/test/Conversion/TorchToLinalg/gridsampler.mlir +++ b/test/Conversion/TorchToLinalg/gridsampler.mlir @@ -96,6 +96,18 @@ func.func @grid_sampler3(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vte // ----- // CHECK-LABEL: func @grid_sampler4 +// CHECK: #map +// CHECK-DAG: %[[Y49:.*]] = arith.maximumf %[[Y47:.*]], %[[CST0:.*]] : f32 +// CHECK-DAG: %[[Y50:.*]] = arith.minimumf %[[Y49:.*]], %[[Y22:.*]] : f32 +// CHECK-DAG: %[[Y51:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[Y52:.*]] = arith.cmpi eq, %[[Y9:.*]], %[[Y51:.*]] : i64 +// CHECK-DAG: %[[Y53:.*]] = arith.select %[[Y52:.*]], %[[Y47:.*]], %[[Y50:.*]] : f32 +// CHECK-DAG: %[[Y54:.*]] = arith.maximumf %[[Y48:.*]], %[[CST0:.*]] : f32 +// CHECK-DAG: %[[Y55:.*]] = arith.minimumf %[[Y54:.*]], %[[Y23:.*]] : f32 +// CHECK-DAG: %[[Y56:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[Y52:.*]] = arith.cmpi eq, %[[Y9:.*]], %[[Y51:.*]] : i64 +// CHECK-DAG: linalg.yield %[[Y60:.*]] : f32 +// CHECK: return %[[X12:.*]] : !torch.vtensor<[?,?,?,?],f32> func.func @grid_sampler4(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %false = torch.constant.bool 1 %int0 = torch.constant.int 0 From f0801b79de6c8b11f09194b016a368d4d5724468 Mon Sep 17 00:00:00 2001 From: Atri Sarkar Date: Fri, 25 Oct 2024 21:32:01 +0530 Subject: [PATCH 3/7] [OnnxToTorch][GridSample] Add support for border padding mode --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 542df9ee4c7b..b278fa062d8b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -140,12 +140,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } std::string padding; + int64_t paddingModeInt; if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros")) return rewriter.notifyMatchFailure(binder.op, "padding_mode bind failure"); - if (padding != "zeros") + if (padding == "zeros") { + paddingModeInt = 0; + } else if (padding == "border") { + paddingModeInt = 1; + } else { return rewriter.notifyMatchFailure( - binder.op, "currently only padding_mode : zeros supported"); + binder.op, + "currently only padding_mode : zeros and border supported"); + } int64_t align; if (binder.s64IntegerAttr(align, "align_corners", 0)) return rewriter.notifyMatchFailure(binder.op, @@ -157,7 +164,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value paddingMode = rewriter.create( binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + paddingModeInt)); bool alignMode = align; Value alignCorners = rewriter.create( From 3018bb66d9431c710ca0f071b78e844706710428 Mon Sep 17 00:00:00 2001 From: Atri Sarkar Date: Fri, 25 Oct 2024 11:35:11 +0530 Subject: [PATCH 4/7] [OnnxToTorch][GridSample] Add lit test for border padding mode --- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index ad18724df52a..09b5c7fc4da4 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -992,6 +992,18 @@ func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !t // ----- +// CHECK-LABEL: @test_grid_sampler03 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[B0:.*]] = torch.constant.bool true +// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT1]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> +func.func @test_grid_sampler03(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.padding_mode = "border"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_oldest_pad func.func @test_oldest_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 1 : si64} { // CHECK: %[[int0:.*]] = torch.constant.int 0 From 5a71e823df1d312917b007d7abe7d8542892f258 Mon Sep 17 00:00:00 2001 From: Atri Sarkar Date: Sun, 17 Nov 2024 21:53:29 +0530 Subject: [PATCH 5/7] Simplify paddingMode lowering Evaluate paddingMode at compile time --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 4 +--- .../TorchToLinalg/Uncategorized.cpp | 21 ++++++++++--------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index b278fa062d8b..905846ee86ae 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -163,9 +163,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt)); Value paddingMode = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - paddingModeInt)); + binder.getLoc(), paddingModeInt); bool alignMode = align; Value alignCorners = rewriter.create( diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 0ba9cd032c07..d5ca980fc7a1 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2593,22 +2593,23 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { return b.create(loc, xMaxZero, SizeSubOne); }; - auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode, + auto lambdaPadding = [&](OpBuilder &b, Location loc, int64_t paddingMode, Value x, Value SizeSubOne) -> Value { - Value border = lambdaBorder(b, loc, x, SizeSubOne); - Value zeroInt = - b.create(loc, b.getIntegerAttr(int64type, 0)); - Value isZero = b.create(loc, arith::CmpIPredicate::eq, - paddingMode, zeroInt); + // Border + if (paddingMode == 1) { + return lambdaBorder(b, loc, x, SizeSubOne); + } - return b.create(loc, isZero, x, border); + return x; }; auto resultType = cast( getTypeConverter()->convertType(op.getResult().getType())); Value alignCorners = adaptor.getAlignCorners(); Value interMode = adaptor.getInterpolationMode(); - Value paddingMode = adaptor.getPaddingMode(); + + int64_t paddingModeInt; + matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingModeInt)); SmallVector dynamicSizes{}; if (resultType.isDynamicDim(0)) @@ -2642,9 +2643,9 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Value unnorm1 = b.create(loc, gPlusMul1, gr1HalfSelect); Value result0 = - lambdaPadding(b, loc, paddingMode, unnorm0, innerDim0d); + lambdaPadding(b, loc, paddingModeInt, unnorm0, innerDim0d); Value result1 = - lambdaPadding(b, loc, paddingMode, unnorm1, innerDim1d); + lambdaPadding(b, loc, paddingModeInt, unnorm1, innerDim1d); Value checkLowerBound0 = b.create( loc, arith::CmpFPredicate::OLT, result0, zeroFloat); Value checkLowerBound1 = b.create( From 523a247602fbec93f2af927b180a7db55fb6dde1 Mon Sep 17 00:00:00 2001 From: Atri Sarkar Date: Sun, 18 May 2025 12:43:51 +0530 Subject: [PATCH 6/7] Change name convention and handling error for match failure --- lib/Conversion/TorchToLinalg/Uncategorized.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index d5ca980fc7a1..386527b9a96a 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2588,16 +2588,16 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { }; auto lambdaBorder = [&](OpBuilder &b, Location loc, Value x, - Value SizeSubOne) -> Value { + Value sizeSubOne) -> Value { Value xMaxZero = b.create(loc, x, zeroFloat); - return b.create(loc, xMaxZero, SizeSubOne); + return b.create(loc, xMaxZero, sizeSubOne); }; auto lambdaPadding = [&](OpBuilder &b, Location loc, int64_t paddingMode, - Value x, Value SizeSubOne) -> Value { + Value x, Value sizeSubOne) -> Value { // Border if (paddingMode == 1) { - return lambdaBorder(b, loc, x, SizeSubOne); + return lambdaBorder(b, loc, x, sizeSubOne); } return x; @@ -2609,7 +2609,10 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Value interMode = adaptor.getInterpolationMode(); int64_t paddingModeInt; - matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingModeInt)); + if(!matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingModeInt))) { + return failure(); + } + SmallVector dynamicSizes{}; if (resultType.isDynamicDim(0)) From 249e5b1effac3eaf07817b0ff0699397a1f07c3a Mon Sep 17 00:00:00 2001 From: Atri Sarkar Date: Mon, 30 Jun 2025 18:24:17 +0530 Subject: [PATCH 7/7] Fix formatting --- lib/Conversion/TorchToLinalg/Uncategorized.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 386527b9a96a..72705c314f97 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2609,11 +2609,11 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Value interMode = adaptor.getInterpolationMode(); int64_t paddingModeInt; - if(!matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingModeInt))) { + if (!matchPattern(op.getPaddingMode(), + m_TorchConstantInt(&paddingModeInt))) { return failure(); } - SmallVector dynamicSizes{}; if (resultType.isDynamicDim(0)) dynamicSizes.push_back(rewriter.create(loc, input, 0));