Failed to legalize operation 'torch.aten.scaled_dot_product_attention' (Self attention torch -> tosa conversion) #4279

@HemKava

Description

I am running a Torch MLIR to TOSA MLIR conversion pipeline (command below) on a Llama 3 attention layer, but torch.aten.scaled_dot_product_attention is reported as an illegal op. Can someone help me figure out the correct pass setup for the TOSA conversion?

Batch Size: 1
Seq Length: 12
Hidden Size: 4096
Head Dimension: 128
Model Dimension: 4096

With this, I generated Torch MLIR for the self-attention layer with the following state vectors:
Hidden State, Shape: torch.Size([1, 12, 4096]), Dtype: torch.bfloat16
Position Embeddings - Cos, Shape: torch.Size([1, 12, 128]), Dtype: torch.bfloat16
Position Embeddings - Sin, Shape: torch.Size([1, 12, 128]), Dtype: torch.bfloat16
Attention Mask, Shape: torch.Size([1, 1, 1, 12]), Dtype: torch.bfloat16

Now, I am converting the Torch MLIR to TOSA MLIR as follows:

torch-mlir-opt --torch-decompose-complex-ops --torch-function-to-torch-backend-pipeline --torch-backend-to-tosa-backend-pipeline layer0_self_attn.torch.mlir -o layer0_self_attn.tosa.mlir

The following error appears when the pipeline reaches scaled_dot_product_attention:

layer0_self_attn.torch.mlir:66:11: error: failed to legalize operation 'torch.aten.scaled_dot_product_attention' that was explicitly marked illegal
%47 = torch.aten.scaled_dot_product_attention %28, %41, %46, %arg3, %float0.000000e00, %false, %float8.838830e-02, %false : !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,1,1,12],bf16>, !torch.float, !torch.bool, !torch.float, !torch.bool -> !torch.vtensor<[1,32,12,128],bf16>
^
layer0_self_attn.torch.mlir:66:11: note: see current operation: %164 = "torch.aten.scaled_dot_product_attention"(%112, %154, %163, %arg3, %6, %8, %5, %8) : (!torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,1,1,12],bf16>, !torch.float, !torch.bool, !torch.float, !torch.bool) -> !torch.vtensor<[1,32,12,128],bf16>
