diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index 727a4ba5d5e5..bd902d8e2575 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -434,6 +434,30 @@ std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
   // if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
   //   return std::nullopt;
 
+  if (llvm::isa<mlir::FloatType>(srcElemTy) && destElemTy.isInteger(1)) {
+    // TOSA does not support casting from float -> i1. In PyTorch, a float
+    // element converts to True exactly when it is non-zero (src != 0).
+    Value zeroValue = *getConstTensor<float>(rewriter, op, 0.0f, {}, srcElemTy);
+    if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue)
+            .failed())
+      return std::nullopt;
+
+    auto cmpTy = srcType.clone(rewriter.getIntegerType(1));
+    Value isEq =
+        rewriter.create<tosa::EqualOp>(op->getLoc(), cmpTy, src, zeroValue);
+    return rewriter.create<tosa::LogicalNotOp>(op->getLoc(),
+                                               srcType.clone(destElemTy), isEq);
+  }
+
+  if (srcElemTy.isInteger(1) && llvm::isa<mlir::FloatType>(destElemTy)) {
+    // TOSA does not support casting from i1 -> float.
+    // Instead, cast to i8 first and then to the destination float type.
+    TensorType midType = srcType.clone(rewriter.getIntegerType(8));
+    Value mid = rewriter.create<tosa::CastOp>(op->getLoc(), midType, src);
+    return rewriter.create<tosa::CastOp>(op->getLoc(),
+                                         srcType.clone(destElemTy), mid);
+  }
+
   if (srcElemTy == destElemTy)
     return src;
 
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py
index f8deda462905..156ed7959351 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py
@@ -27,6 +27,21 @@ def TypeConversionF32ToF64Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
 
 
+class TypeConversionF32ToI1Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1], torch.float32, True)])
+    def forward(self, x):
+        return x.to(torch.bool)
+
+
+@register_test_case(module_factory=lambda: TypeConversionF32ToI1Module())
+def TypeConversionF32ToI1Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5))
+
+
 class TypeConversionF64ToF32Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 3d2e85acee4a..cb1a69e6a622 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -1169,6 +1169,50 @@ func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !
   return %0 : !torch.vtensor<[3,5],si64>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.to.dtype$floatToBool(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],i1> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 11
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.none
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor<f32>, !tosa.shape<2>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_8:.*]] = tosa.equal %[[VAL_1]], %[[VAL_7]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xi1>
+// CHECK: %[[VAL_9:.*]] = tosa.logical_not %[[VAL_8]] : (tensor<3x5xi1>) -> tensor<3x5xi1>
+// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1>
+// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,5],i1>
+// CHECK: }
+func.func @torch.aten.to.dtype$floatToBool(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],i1> {
+  %int11 = torch.constant.int 11
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1>
+  return %0 : !torch.vtensor<[3,5],i1>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.to.dtype$boolToFloat(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],i1> -> tensor<3x4xi1>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 6
+// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_4:.*]] = torch.constant.none
+// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi1>) -> tensor<3x4xi8>
+// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<3x4xi8>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32>
+// CHECK: }
+func.func @torch.aten.to.dtype$boolToFloat(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],f32> {
+  %int6 = torch.constant.int 6
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[3,4],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+
 // -----
 // CHECK-LABEL: func.func @torch.aten.gather(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,
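For reference, the PyTorch semantics the two new lowerings must reproduce can be checked with plain PyTorch, independently of the patch: float -> bool is element-wise truthiness, i.e. logical_not(equal(x, 0)), and bool -> float maps True/False to 1.0/0.0, which the i1 -> i8 -> float cast chain preserves. A minimal sketch:

    import torch

    # float -> bool: True exactly where the element is non-zero. This is
    # what the lowering computes as logical_not(equal(x, 0)).
    x = torch.tensor([[0.0, 1.5], [-2.0, 0.0]])
    b = x.to(torch.bool)
    assert b.tolist() == [[False, True], [True, False]]

    # bool -> float: True -> 1.0, False -> 0.0. TOSA has no direct i1 -> float
    # cast, hence the two-step i1 -> i8 -> f32 chain in the patch.
    f = b.to(torch.float32)
    assert f.tolist() == [[0.0, 1.0], [1.0, 0.0]]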