Description
The following snippet (pixel_unshuffle.mlir) reproduces the issue:
module {
  func.func @main(%arg0: !torch.vtensor<[1,8,9],f32>) -> !torch.vtensor<[1,8,9],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %int1 = torch.constant.int 1
    %0 = torch.operator "torch.aten.pixel_unshuffle"(%arg0, %int1) : (!torch.vtensor<[1,8,9],f32>, !torch.int) -> !torch.vtensor<[1,8,9],f32>
    return %0 : !torch.vtensor<[1,8,9],f32>
  }
}
Running the following command leaves the torch.aten.pixel_unshuffle op unlowered:
$ torch-mlir-opt --convert-torch-to-linalg pixel_unshuffle.mlir
Proposal:
The aten.pixel_unshuffle op can be supported by adding a Torch-dialect-level decomposition, analogous to the existing one for aten.pixel_shuffle.
The decomposition follows the PyTorch specification:
https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pixel_unshuffle.html
and the reference implementation in aten/src/ATen/native/PixelShuffle.cpp (main branch):
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/PixelShuffle.cpp
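As a quick reference for the shape contract in the linked specification, (*, C, H×r, W×r) -> (*, C×r², H, W), the sketch below computes the result shape of pixel_unshuffle for a given input shape and downscale factor r. The helper name `pixel_unshuffle_shape` is hypothetical, and its error checks are an assumption modeled on the documented constraints (at least three dimensions, H and W divisible by r); the real decomposition would express these as Torch-dialect shape checks.

```python
def pixel_unshuffle_shape(shape, r):
    """Return the result shape of pixel_unshuffle (hypothetical helper).

    Follows the documented contract (*, C, H*r, W*r) -> (*, C*r*r, H, W).
    """
    if len(shape) < 3:
        raise ValueError("pixel_unshuffle expects at least 3 dimensions")
    if r <= 0:
        raise ValueError("downscale factor must be positive")
    *batch, c, h, w = shape
    if h % r != 0 or w % r != 0:
        raise ValueError("height and width must be divisible by the downscale factor")
    # Channels grow by r*r, both spatial dims shrink by r; batch dims pass through.
    return (*batch, c * r * r, h // r, w // r)


print(pixel_unshuffle_shape((1, 8, 4, 4), 2))  # (1, 32, 2, 2)
print(pixel_unshuffle_shape((1, 8, 9), 1))     # (1, 8, 9), the reproducer's shape
```

With a factor of 1 the op is an identity on shapes, which is why the rank-3 reproducer above keeps its [1,8,9] type.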
Concretely, aten.pixel_unshuffle can be decomposed into prims.split_dim, aten.permute, and prims.collapse ops.
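The sketch below traces the shapes through those three ops for an input with any number of leading batch dimensions, which is what the decomposition has to handle in general. The helpers `split_dim`, `permute`, and `collapse` are hypothetical pure-shape stand-ins, not the torch-mlir or PyTorch APIs: `split_dim` assumes the "outer length first" convention (dim of size L becomes (outer, L / outer)), and `collapse` here merges an inclusive range of dims.

```python
def split_dim(shape, dim, outer):
    """Split shape[dim] into (outer, shape[dim] // outer)."""
    size = shape[dim]
    assert size % outer == 0, "dimension must be divisible by outer length"
    return shape[:dim] + (outer, size // outer) + shape[dim + 1:]


def permute(shape, perm):
    """Result dim k takes input dim perm[k]."""
    return tuple(shape[p] for p in perm)


def collapse(shape, start, end):
    """Merge dims start..end (inclusive) into a single dim."""
    merged = 1
    for d in shape[start:end + 1]:
        merged *= d
    return shape[:start] + (merged,) + shape[end + 1:]


def unshuffle_trace(shape, r):
    """Return (split shape, permutation, output shape) for pixel_unshuffle."""
    n = len(shape) - 3                      # number of leading batch dims
    c_dim, h_dim, w_dim = n, n + 1, n + 2
    s = split_dim(shape, h_dim, shape[h_dim] // r)  # (..., C, H/r, r, W)
    s = split_dim(s, w_dim + 1, shape[w_dim] // r)  # (..., C, H/r, r, W/r, r)
    # Move both factor dims directly after C: (..., C, r, r, H/r, W/r)
    perm = tuple(range(n)) + (c_dim, c_dim + 2, c_dim + 4, c_dim + 1, c_dim + 3)
    p = permute(s, perm)
    out = collapse(p, c_dim, c_dim + 2)             # (..., C*r*r, H/r, W/r)
    return s, perm, out


print(unshuffle_trace((1, 8, 4, 4), 2))
# ((1, 8, 2, 2, 2, 2), (0, 1, 3, 5, 2, 4), (1, 32, 2, 2))
```

For a rank-4 input the permutation comes out as (0, 1, 3, 5, 2, 4); for a rank-3 input (no batch dim) it becomes (0, 2, 4, 1, 3).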
For example, for an input tensor of shape 1x8x4x4xf32 and a downscale factor of 2, pixel_unshuffle could be lowered in the following steps:
Input: tensor<1x8x4x4xf32>, Factor: 2
Output: tensor<1x32x2x2xf32>
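Element-wise, the mapping the steps must reproduce can be written as follows, with r the downscale factor, c ∈ [0, C), i, j ∈ [0, r), and h, w the output spatial indices. This is restated here from the reshape/permute/reshape order in PixelShuffle.cpp as a reading aid, not quoted from the PyTorch docs.

```math
\text{out}[n,\; c\,r^2 + i\,r + j,\; h,\; w] = \text{in}[n,\; c,\; h\,r + i,\; w\,r + j]
```

For the 1x8x4x4 example (r = 2) this reads out[0, 4c + 2i + j, h, w] = in[0, c, 2h + i, 2w + j], with h, w ∈ {0, 1}.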
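// Step 1: expand [1,8,4,4] -> [1,8,2,2,2,2], i.e. (N, C, H/r, r, W/r, r)
%expanded = tensor.expand_shape %input [[0], [1], [2, 3], [4, 5]] output_shape [1, 8, 2, 2, 2, 2] : tensor<1x8x4x4xf32> into tensor<1x8x2x2x2x2xf32>
// Step 2: permute (N, C, H/r, r, W/r, r) -> (N, C, r, r, H/r, W/r) so both factor dims sit next to C
%init = tensor.empty() : tensor<1x8x2x2x2x2xf32>
%transposed = linalg.transpose ins(%expanded : tensor<1x8x2x2x2x2xf32>) outs(%init : tensor<1x8x2x2x2x2xf32>) permutation = [0, 1, 3, 5, 2, 4]
// Step 3: collapse (C, r, r) into the channel dim: [1,8,2,2,2,2] -> [1,32,2,2]
%result = tensor.collapse_shape %transposed [[0], [1, 2, 3], [4], [5]] : tensor<1x8x2x2x2x2xf32> into tensor<1x32x2x2xf32>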
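To sanity-check the permutation, the pure-Python sketch below applies the three steps to flat row-major data and compares the result with the element-wise definition of pixel_unshuffle. All function names are hypothetical. Step 1 and Step 3 are treated as metadata-only reshapes, which holds for contiguous row-major data, and `transpose` uses the same convention as linalg.transpose (result dim k is input dim permutation[k]). Only a permutation that places both factor dims directly after C, i.e. [0, 1, 3, 5, 2, 4] for rank 4, matches the definition.

```python
from itertools import product


def strides(shape):
    """Row-major (C-contiguous) strides for `shape`."""
    out, acc = [], 1
    for d in reversed(shape):
        out.append(acc)
        acc *= d
    return tuple(reversed(out))


def transpose(data, shape, perm):
    """Materialize a permuted copy of flat row-major `data`."""
    src_strides = strides(shape)
    new_shape = tuple(shape[p] for p in perm)
    result = []
    for idx in product(*(range(d) for d in new_shape)):
        # idx[k] is the index along source dim perm[k].
        offset = sum(idx[k] * src_strides[perm[k]] for k in range(len(perm)))
        result.append(data[offset])
    return result, new_shape


def unshuffle_via_steps(data, shape, r):
    n, c, h, w = shape
    split = (n, c, h // r, r, w // r, r)                    # Step 1: shape change only
    moved, _ = transpose(data, split, (0, 1, 3, 5, 2, 4))   # Step 2: real data movement
    return moved, (n, c * r * r, h // r, w // r)            # Step 3: shape change only


def unshuffle_direct(data, shape, r):
    """out[b, ch*r*r + i*r + j, hh, ww] = in[b, ch, hh*r + i, ww*r + j]."""
    n, c, h, w = shape
    oc, oh, ow = c * r * r, h // r, w // r
    out = [0] * len(data)
    for b, ch, y, x in product(range(n), range(c), range(h), range(w)):
        hh, i = divmod(y, r)
        ww, j = divmod(x, r)
        dst = ((b * oc + ch * r * r + i * r + j) * oh + hh) * ow + ww
        out[dst] = data[((b * c + ch) * h + y) * w + x]
    return out


data = list(range(1 * 8 * 4 * 4))
out, out_shape = unshuffle_via_steps(data, (1, 8, 4, 4), 2)
assert out == unshuffle_direct(data, (1, 8, 4, 4), 2)
print(out_shape, out[:4])  # (1, 32, 2, 2) [0, 2, 8, 10]
```

The first output channel gathers in[0, 0, 2h, 2w], i.e. the elements 0, 2, 8, 10 of the first 4x4 plane, as the formula predicts.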
References:
PyTorch pixel_unshuffle definition:
https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pixel_unshuffle.html
PyTorch implementation:
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/PixelShuffle.cpp