Description
I am running the Torch MLIR to TOSA MLIR conversion pipeline on a Llama 3 attention layer (command below), but the conversion fails because torch.aten.scaled_dot_product_attention is marked illegal. Can someone help me figure out the right pass(es) to lower this op for TOSA?
Batch Size : 1
Seq Length : 12
Hidden Size: 4096
Dimension (Head): 128
Dimension (Model): 4096
With this configuration, I generated Torch MLIR for the self-attention layer with the following input tensors:
Hidden State, Shape: torch.Size([1, 12, 4096]), Dtype: torch.bfloat16
Position Embeddings - Cos, Shape: torch.Size([1, 12, 128]), Dtype: torch.bfloat16
Position Embeddings - Sin, Shape: torch.Size([1, 12, 128]), Dtype: torch.bfloat16
Attention Mask, Shape: torch.Size([1, 1, 1, 12]), Dtype: torch.bfloat16
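For context, here is a minimal eager-mode sketch (not my actual export script; it assumes a recent PyTorch where F.scaled_dot_product_attention accepts a scale keyword) that exercises the same op with the shapes and dtypes listed above:

```python
# Minimal repro sketch of the op that later fails to legalize.
# Shapes/dtypes match the ones listed above; values are random.
import math
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 1, 32, 12, 128  # 32 heads * 128 = 4096 hidden size

q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.bfloat16)
k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.bfloat16)
v = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.bfloat16)
attn_mask = torch.zeros(batch, 1, 1, seq_len, dtype=torch.bfloat16)  # additive mask

# scale = 1/sqrt(head_dim) ≈ 8.8388e-02, matching float8.838830e-02 in the IR below
out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=attn_mask,
    dropout_p=0.0,
    is_causal=False,
    scale=1.0 / math.sqrt(head_dim),
)
print(out.shape)  # torch.Size([1, 32, 12, 128])
```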
Now I am converting the Torch MLIR to TOSA MLIR as follows:
torch-mlir-opt --torch-decompose-complex-ops --torch-function-to-torch-backend-pipeline --torch-backend-to-tosa-backend-pipeline layer0_self_attn.torch.mlir -o layer0_self_attn.tosa.mlir
Following is the error I get when converting scaled_dot_product_attention:
layer0_self_attn.torch.mlir:66:11: error: failed to legalize operation 'torch.aten.scaled_dot_product_attention' that was explicitly marked illegal
%47 = torch.aten.scaled_dot_product_attention %28, %41, %46, %arg3, %float0.000000e00, %false, %float8.838830e-02, %false : !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,1,1,12],bf16>, !torch.float, !torch.bool, !torch.float, !torch.bool -> !torch.vtensor<[1,32,12,128],bf16>
^
layer0_self_attn.torch.mlir:66:11: note: see current operation: %164 = "torch.aten.scaled_dot_product_attention"(%112, %154, %163, %arg3, %6, %8, %5, %8) : (!torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,32,12,128],bf16>, !torch.vtensor<[1,1,1,12],bf16>, !torch.float, !torch.bool, !torch.float, !torch.bool) -> !torch.vtensor<[1,32,12,128],bf16>
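For reference, and to clarify what a TOSA-legal lowering would have to compute, this op (with dropout_p = 0.0, is_causal = False, and an additive mask) should be numerically equivalent to a plain matmul/softmax decomposition like the sketch below. This is only an illustration of the op's semantics, not the actual decomposition pass I am asking about:

```python
# Hand-written reference for what aten.scaled_dot_product_attention computes here.
import math
import torch

def sdpa_reference(q, k, v, attn_mask, scale):
    # q, k, v: [batch, heads, seq, head_dim]; attn_mask broadcasts over the score matrix
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # [B, H, S, S]
    scores = scores + attn_mask                            # additive bf16 mask
    weights = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    return torch.matmul(weights, v)                        # [B, H, S, head_dim]

q = torch.randn(1, 32, 12, 128, dtype=torch.bfloat16)
k = torch.randn(1, 32, 12, 128, dtype=torch.bfloat16)
v = torch.randn(1, 32, 12, 128, dtype=torch.bfloat16)
mask = torch.zeros(1, 1, 1, 12, dtype=torch.bfloat16)

out = sdpa_reference(q, k, v, mask, scale=1.0 / math.sqrt(128))
print(out.shape)  # torch.Size([1, 32, 12, 128])
```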