From 1649e9ac77b0e726b614cae900517493673ee5a9 Mon Sep 17 00:00:00 2001 From: Vitalii Shutov Date: Mon, 7 Jul 2025 18:11:41 +0100 Subject: [PATCH 1/2] [TOSA] Handle float<->bool cast via i8 in tosaCastTensorToType MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add intermediate i8 cast to support float<->bool conversions, which TOSA doesn’t allow directly. Fixes legalization failures for such cases. Signed-off-by: Vitalii Shutov --- lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 9c1d6d3d0d37..7d858517939c 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -433,6 +433,16 @@ std::optional tosaCastTensorToType(PatternRewriter &rewriter, Value src, // if (failed(checkValidityOfCast(srcElemTy, destElemTy))) // return std::nullopt; + // Check if the source and destination types are boolean or floating-point + if ((srcElemTy.isInteger(1) && llvm::isa(destElemTy)) || + (llvm::isa(srcElemTy) && destElemTy.isInteger(1))) { + // TOSA does not support casting between float<->i8. + // Instead, we cast to i8 and then to the destination type. + TensorType midType = srcType.clone(rewriter.getIntegerType(8)); + Value mid = rewriter.create(op->getLoc(), midType, src); + return tosaCastTensorToType(rewriter, mid, destType); // recurse once + } + if (srcElemTy == destElemTy) return src; From b80a57c8f40b0323a59320944a6931a48f9d21aa Mon Sep 17 00:00:00 2001 From: Vitalii Shutov Date: Thu, 17 Jul 2025 18:47:27 +0100 Subject: [PATCH 2/2] fix the float<->i1 flow and add tests --- .../TorchToTosa/TosaLegalizeUtils.cpp | 26 ++++++++--- .../test_suite/type_conversion.py | 15 +++++++ test/Conversion/TorchToTosa/basic.mlir | 44 +++++++++++++++++++ 3 files changed, 79 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 848f4ee94bdf..bd902d8e2575 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -434,14 +434,28 @@ std::optional tosaCastTensorToType(PatternRewriter &rewriter, Value src, // if (failed(checkValidityOfCast(srcElemTy, destElemTy))) // return std::nullopt; - // Check if the source and destination types are boolean or floating-point - if ((srcElemTy.isInteger(1) && llvm::isa(destElemTy)) || - (llvm::isa(srcElemTy) && destElemTy.isInteger(1))) { - // TOSA does not support casting between float<->i8. - // Instead, we cast to i8 and then to the destination type. + if (llvm::isa(srcElemTy) && destElemTy.isInteger(1)) { + // TOSA does not support casting from float->i1. + // In PyTorch the bool value will be True if any element is non-zero + Value zeroValue = *getConstTensor(rewriter, op, 0.0f, {}, srcElemTy); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue) + .failed()) + return std::nullopt; + + auto cmpTy = srcType.clone(rewriter.getIntegerType(1)); + Value isEq = + rewriter.create(op->getLoc(), cmpTy, src, zeroValue); + return rewriter.create(op->getLoc(), + srcType.clone(destElemTy), isEq); + } + + if (srcElemTy.isInteger(1) && llvm::isa(destElemTy)) { + // TOSA does not support casting from i1->float. + // Instead, we cast to i8 and then to the float. TensorType midType = srcType.clone(rewriter.getIntegerType(8)); Value mid = rewriter.create(op->getLoc(), midType, src); - return tosaCastTensorToType(rewriter, mid, destType); // recurse once + return rewriter.create(op->getLoc(), + srcType.clone(destElemTy), mid); } if (srcElemTy == destElemTy) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py index f8deda462905..156ed7959351 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -27,6 +27,21 @@ def TypeConversionF32ToF64Module_basic(module, tu: TestUtils): module.forward(tu.rand(3, 5)) +class TypeConversionF32ToI1Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.float32, True)]) + def forward(self, x): + return x.to(torch.bool) + + +@register_test_case(module_factory=lambda: TypeConversionF32ToI1Module()) +def TypeConversionF32ToI1Module_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 5)) + + class TypeConversionF64ToF32Module(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 3d2e85acee4a..cb1a69e6a622 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1169,6 +1169,50 @@ func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> ! return %0 : !torch.vtensor<[3,5],si64> } +// ----- +// CHECK-LABEL: func.func @torch.aten.to.dtype$floatToBool( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 11 +// CHECK: %[[VAL_3:.*]] = torch.constant.bool false +// CHECK: %[[VAL_4:.*]] = torch.constant.none +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.equal %[[VAL_1]], %[[VAL_7]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xi1> +// CHECK: %[[VAL_9:.*]] = tosa.logical_not %[[VAL_8]] : (tensor<3x5xi1>) -> tensor<3x5xi1> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,5],i1> +// CHECK: } +func.func @torch.aten.to.dtype$floatToBool(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],i1> { + %int11 = torch.constant.int 11 + %false = torch.constant.bool false + %none = torch.constant.none + %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1> + return %0 : !torch.vtensor<[3,5],i1> +} + +// ----- +// CHECK-LABEL: func.func @torch.aten.to.dtype$boolToFloat( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],i1> -> tensor<3x4xi1> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 6 +// CHECK: %[[VAL_3:.*]] = torch.constant.bool false +// CHECK: %[[VAL_4:.*]] = torch.constant.none +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi1>) -> tensor<3x4xi8> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<3x4xi8>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.to.dtype$boolToFloat(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],f32> { + %int6 = torch.constant.int 6 + %false = torch.constant.bool false + %none = torch.constant.none + %0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[3,4],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + + // ----- // CHECK-LABEL: func.func @torch.aten.gather( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,