[TOSA] Handle float<->bool cast via i8 in tosaCastTensorToType #4257

Open · wants to merge 3 commits into base: main
24 changes: 24 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -434,6 +434,30 @@ std::optional<Value> tosaCastTensorToType(PatternRewriter &rewriter, Value src,
// if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
// return std::nullopt;

if (llvm::isa<FloatType>(srcElemTy) && destElemTy.isInteger(1)) {
// TOSA does not support casting from float to i1.
// PyTorch maps each nonzero element to True, so emulate the cast as
// logical_not(equal(src, 0)).
Value zeroValue = *getConstTensor<float>(rewriter, op, 0.0f, {}, srcElemTy);
if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue)
.failed())
return std::nullopt;

auto cmpTy = srcType.clone(rewriter.getIntegerType(1));
Value isEq =
rewriter.create<tosa::EqualOp>(op->getLoc(), cmpTy, src, zeroValue);
return rewriter.create<tosa::LogicalNotOp>(op->getLoc(),
srcType.clone(destElemTy), isEq);
}

if (srcElemTy.isInteger(1) && llvm::isa<FloatType>(destElemTy)) {
// TOSA does not support casting from i1 to float.
// Instead, cast i1 to i8 first, then cast i8 to the destination float type.
TensorType midType = srcType.clone(rewriter.getIntegerType(8));
Value mid = rewriter.create<tosa::CastOp>(op->getLoc(), midType, src);
return rewriter.create<tosa::CastOp>(op->getLoc(),
srcType.clone(destElemTy), mid);
}

if (srcElemTy == destElemTy)
return src;

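A minimal PyTorch sketch of the float-to-bool semantics the first new branch emulates (illustrative only, not part of this diff): each nonzero element maps to True, which is exactly logical_not(equal(src, 0)).

import torch

x = torch.tensor([[0.0, -1.5], [2.0, 0.0]])
as_bool = x.to(torch.bool)                       # nonzero -> True, zero -> False
emulated = torch.logical_not(torch.eq(x, 0.0))   # mirrors the equal + logical_not lowering
assert torch.equal(as_bool, emulated)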
@@ -27,6 +27,21 @@ def TypeConversionF32ToF64Module_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))


class TypeConversionF32ToI1Module(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1, -1], torch.float32, True)])
def forward(self, x):
return x.to(torch.bool)


@register_test_case(module_factory=lambda: TypeConversionF32ToI1Module())
def TypeConversionF32ToI1Module_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5))
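A companion e2e test for the reverse (i1 -> float) direction could follow the same pattern; the sketch below is hypothetical: the class name, dtype choice, and tu.randint usage are assumptions, not part of this diff.

class TypeConversionI1ToF32Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1, -1], torch.bool, True)])
    def forward(self, x):
        return x.to(torch.float32)


@register_test_case(module_factory=lambda: TypeConversionI1ToF32Module())
def TypeConversionI1ToF32Module_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 5, low=0, high=2).to(torch.bool))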


class TypeConversionF64ToF32Module(torch.nn.Module):
def __init__(self):
super().__init__()
44 changes: 44 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -1169,6 +1169,50 @@ func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !
return %0 : !torch.vtensor<[3,5],si64>
}

// -----
// CHECK-LABEL: func.func @torch.aten.to.dtype$floatToBool(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],i1> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 11
// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
// CHECK: %[[VAL_4:.*]] = torch.constant.none
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_6]] : (tensor<f32>, !tosa.shape<2>) -> tensor<1x1xf32>
// CHECK: %[[VAL_8:.*]] = tosa.equal %[[VAL_1]], %[[VAL_7]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xi1>
// CHECK: %[[VAL_9:.*]] = tosa.logical_not %[[VAL_8]] : (tensor<3x5xi1>) -> tensor<3x5xi1>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1>
// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,5],i1>
// CHECK: }
func.func @torch.aten.to.dtype$floatToBool(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],i1> {
%int11 = torch.constant.int 11
%false = torch.constant.bool false
%none = torch.constant.none
%0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[3,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],i1>
return %0 : !torch.vtensor<[3,5],i1>
}
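One subtlety of the equal + logical_not emulation exercised above: NaN compares unequal to zero and therefore maps to True, while negative zero compares equal and maps to False, matching PyTorch's own bool cast. A quick check (illustrative only, assumes PyTorch):

import torch

x = torch.tensor([0.0, -0.0, float("nan"), 3.0])
print(x.to(torch.bool))                        # tensor([False, False,  True,  True])
print(torch.logical_not(torch.eq(x, 0.0)))     # same values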

// -----
// CHECK-LABEL: func.func @torch.aten.to.dtype$boolToFloat(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],i1> -> tensor<3x4xi1>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 6
// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
// CHECK: %[[VAL_4:.*]] = torch.constant.none
// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi1>) -> tensor<3x4xi8>
// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<3x4xi8>) -> tensor<3x4xf32>
// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32>
// CHECK: }
func.func @torch.aten.to.dtype$boolToFloat(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],f32> {
%int6 = torch.constant.int 6
%false = torch.constant.bool false
%none = torch.constant.none
%0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[3,4],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}
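The boolToFloat test checks that the lowering routes i1 through an i8 intermediate before casting to f32. At the value level the detour is lossless; a small PyTorch check of the expected result (illustrative only):

import torch

b = torch.tensor([[True, False], [False, True]])
via_i8 = b.to(torch.int8).to(torch.float32)    # mirrors the two-step i1 -> i8 -> f32 lowering
direct = b.to(torch.float32)                   # True -> 1.0, False -> 0.0
assert torch.equal(via_i8, direct)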


// -----
// CHECK-LABEL: func.func @torch.aten.gather(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,