Skip to content

[TOSA] Handle float<->bool cast via i8 in tosaCastTensorToType #4257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

Lallapallooza
Copy link

Add intermediate i8 cast to support float<->bool conversions, which TOSA doesn’t allow directly. Fixes legalization failures for such cases.

Add intermediate i8 cast to support float<->bool conversions, which TOSA
doesn’t allow directly. Fixes legalization failures for such cases.

Signed-off-by: Vitalii Shutov <vitalii.shutov@arm.com>
Copy link
Collaborator

@sjarus sjarus left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What legalization failures are now fixed ? Does projects/pt1/e2e_testing/xfail_sets.py need to be updated ?

@Lallapallooza
Copy link
Author

What legalization failures are now fixed ? Does projects/pt1/e2e_testing/xfail_sets.py need to be updated ?

I have added dedicated tests to cover the affected legalization flow. While the existing e2e test TypeConversionI1ToF32Module_basic was already passing before, the generated TOSA IR was not standards-compliant and failed verification under strict profiles.

IR Before

module {
  func.func @main(%arg0: tensor<1x3x1x5xf32>) -> tensor<1x3x1x5xf32> {
    %0 = tosa.cast %arg0 : (tensor<1x3x1x5xf32>) -> tensor<1x3x1x5xi1>
    %1 = tosa.cast %0 : (tensor<1x3x1x5xi1>) -> tensor<1x3x1x5xf32>
    return %1 : tensor<1x3x1x5xf32>
  }
}

Which is failed to verify

$ ./build/bin/mlir-opt --tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment" tosa.mlir
tosa.mlir:3:10: error: 'tosa.cast' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,i1), did you mean (i8,i1)? Otherwise, please refer to the 'supported data types' for 'tosa.cast' in the specification.
    %0 = tosa.cast %arg0 : (tensor<1x3x1x5xf32>) -> tensor<1x3x1x5xi1>
         ^
tosa.mlir:3:10: note: see current operation: %0 = "tosa.cast"(%arg0) : (tensor<1x3x1x5xf32>) -> tensor<1x3x1x5xi1>
tosa.mlir:4:10: error: 'tosa.cast' op illegal: operation operand/result data types did not align with any profile or extension, got (i1,f32), did you mean (i1,i8)? Otherwise, please refer to the 'supported data types' for 'tosa.cast' in the specification.
    %1 = tosa.cast %0 : (tensor<1x3x1x5xi1>) -> tensor<1x3x1x5xf32>
         ^
tosa.mlir:4:10: note: see current operation: %1 = "tosa.cast"(%0) : (tensor<1x3x1x5xi1>) -> tensor<1x3x1x5xf32> 

With new patch the generated TOSA is

module {
  func.func @main(%arg0: tensor<1x3x1x5xf32>) -> tensor<1x3x1x5xf32> {
    %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32>
    %1 = tosa.equal %arg0, %0 : (tensor<1x3x1x5xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x1x5xi1>
    %2 = tosa.logical_not %1 : (tensor<1x3x1x5xi1>) -> tensor<1x3x1x5xi1>
    %3 = tosa.cast %2 : (tensor<1x3x1x5xi1>) -> tensor<1x3x1x5xi8>
    %4 = tosa.cast %3 : (tensor<1x3x1x5xi8>) -> tensor<1x3x1x5xf32>
    return %4 : tensor<1x3x1x5xf32>
  }
}

This passes the validator and produces results identical to PyTorch.

@Lallapallooza Lallapallooza requested a review from sjarus July 18, 2025 09:28
Copy link
Collaborator

@sjarus sjarus left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice addition, thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants