[TorchToLinalg] Add lowering of torch.aten.pixel_unshuffle op to linalg dialect #4260

@alaa-ali

The following snippet, pixel_unshuffle.mlir, reproduces the issue:

module {
  func.func @main(%arg0: !torch.vtensor<[1,8,9],f32>) -> !torch.vtensor<[1,8,9],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %int1 = torch.constant.int 1
    %0 = torch.operator "torch.aten.pixel_unshuffle"(%arg0, %int1) : (!torch.vtensor<[1,8,9],f32>, !torch.int) -> !torch.vtensor<[1,8,9],f32> 
    return %0 : !torch.vtensor<[1,8,9],f32>
  }
}

Running the following command does not lower the torch.aten.pixel_unshuffle op:
$ torch-mlir-opt --convert-torch-to-linalg pixel_unshuffle.mlir

Proposal:
The aten.pixel_unshuffle op can be supported by adding a Torch-dialect-level decomposition, similar to the one for the pixel_shuffle op.

The decomposition is based on this specification:
https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pixel_unshuffle.html
The PyTorch implementation is in aten/src/ATen/native/PixelShuffle.cpp:
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/PixelShuffle.cpp

Note that the aten.pixel_unshuffle operator decomposes into prims.split_dim, aten.permute, and prims.collapse operations.
For example, for an input tensor of shape <1x8x4x4xf32> and a downscale factor of 2, the pixel_unshuffle op could be lowered in the following steps:

Input: tensor<1x8x4x4xf32>, Factor: 2
Output: tensor<1x32x2x2xf32>
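
The output shape above follows from the rule in the PyTorch documentation: an input of shape (*, C, H*r, W*r) becomes (*, C*r*r, H, W). As a minimal sketch of that arithmetic (the helper name `pixel_unshuffle_shape` is hypothetical and not part of torch-mlir or PyTorch):

```python
def pixel_unshuffle_shape(shape, r):
    """Output shape of pixel_unshuffle for an input of shape (*, C, H*r, W*r)."""
    if len(shape) < 3:
        raise ValueError("expected at least 3 dimensions")
    if r <= 0:
        raise ValueError("downscale factor must be positive")
    *batch, c, h, w = shape
    if h % r or w % r:
        raise ValueError("height and width must be divisible by the factor")
    # Channels grow by r*r; both spatial dims shrink by r.
    return [*batch, c * r * r, h // r, w // r]


print(pixel_unshuffle_shape([1, 8, 4, 4], 2))  # [1, 32, 2, 2]
print(pixel_unshuffle_shape([1, 8, 9], 1))     # [1, 8, 9]
```

The second call corresponds to the rank-3 reproducer above, where a factor of 1 leaves the shape unchanged.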

// Step 1: Reshape to split H and W by the factor: [1,8,4,4] -> [1,8,2,2,2,2], i.e. (N, C, H, r, W, r)
%reshaped1 = tensor.reshape %input : tensor<1x8x4x4xf32> into tensor<1x8x2x2x2x2xf32>
// Step 2: Permute to move both factor dims next to the channel dim: (N, C, H, r, W, r) -> (N, C, r, r, H, W)
%transposed = linalg.transpose %reshaped1 [0,1,3,5,2,4] : tensor<1x8x2x2x2x2xf32> to tensor<1x8x2x2x2x2xf32>
// Step 3: Final reshape to [1,32,2,2], i.e. (N, C*r*r, H, W)
%result = tensor.reshape %transposed : tensor<1x8x2x2x2x2xf32> into tensor<1x32x2x2xf32>
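
As a sanity check on the three steps, here is a small pure-Python sketch that runs them on a flat row-major buffer. It assumes a rank-4 input and uses the permutation [0, 1, 3, 5, 2, 4], which is the one that satisfies out[n, c*r*r + i*r + j, y, x] = in[n, c, y*r + i, x*r + j]. The helper names (`strides`, `permute`, `pixel_unshuffle`) are illustrative only and do not correspond to torch-mlir APIs.

```python
from itertools import product


def strides(shape):
    """Row-major strides for a shape."""
    out, acc = [], 1
    for d in reversed(shape):
        out.append(acc)
        acc *= d
    return out[::-1]


def permute(data, shape, perm):
    """Permute a flat row-major buffer; output dim i takes input dim perm[i]."""
    in_strides = strides(shape)
    out_shape = [shape[p] for p in perm]
    out = []
    # product() varies the last index fastest, i.e. row-major output order.
    for idx in product(*(range(d) for d in out_shape)):
        src = sum(idx[i] * in_strides[perm[i]] for i in range(len(perm)))
        out.append(data[src])
    return out, out_shape


def pixel_unshuffle(data, shape, r):
    n, c, hr, wr = shape
    h, w = hr // r, wr // r
    # Step 1: split H*r -> (H, r) and W*r -> (W, r); the flat buffer is unchanged.
    split_shape = [n, c, h, r, w, r]
    # Step 2: (N, C, H, r, W, r) -> (N, C, r, r, H, W).
    permuted, _ = permute(data, split_shape, [0, 1, 3, 5, 2, 4])
    # Step 3: collapse (C, r, r) -> C*r*r; the flat buffer is unchanged.
    return permuted, [n, c * r * r, h, w]


out, out_shape = pixel_unshuffle(list(range(16)), [1, 1, 4, 4], 2)
print(out_shape)  # [1, 4, 2, 2]
print(out)        # [0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15]
```

Each output channel holds one of the four positions within every 2x2 input block, which is the expected pixel_unshuffle behavior. Steps 1 and 3 only reinterpret the shape, so the permutation in step 2 is the only step that moves data.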

References:
PyTorch Pixel_Unshuffle definition:
https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pixel_unshuffle.html
PyTorch implementation:
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/PixelShuffle.cpp
