Description
Save this code in channelshuffle.mlir
module {
  func.func @main(%arg0: !torch.vtensor<[?,1,8],f32>) -> !torch.vtensor<[?,1,8],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %int1 = torch.constant.int 1
    %0 = torch.operator "torch.aten.channel_shuffle"(%arg0, %int1) : (!torch.vtensor<[?,1,8],f32>, !torch.int) -> !torch.vtensor<[?,1,8],f32>
    return %0 : !torch.vtensor<[?,1,8],f32>
  }
}
And execute:
torch-mlir-opt --convert-torch-to-linalg channelshuffle.mlir
Notice that the torch.aten.channel_shuffle operation is not lowered.
Proposal:
This operator can be supported by adding a torch-dialect-level decomposition, as is already done for the pixel_shuffle operation.
The decomposition is based on this specification:
https://docs.pytorch.org/docs/stable/generated/torch.nn.ChannelShuffle.html
and implementation:
aten/src/ATen/native/ChanelShuffle.cpp
https://github.com/pytorch/pytorch/blob/23491519d288dedb2a54cfad5fef7fcb2ad8eade/aten/src/ATen/native/ChanelShuffle.cpp#L4
Note that the operator consists of three steps: an expansion of the channel dimension, a permute of the expanded channel dimensions, and a contraction of the channel dimensions back to the original size. For example, an input array of shape 1x8x4x4 with a group size of 2 would generate the torch dialect MLIR code below.
module {
  func.func @channel_shuffle(%arg0: !torch.vtensor<[1, 8, 4, 4], f32>) -> !torch.vtensor<[1, 8, 4, 4], f32> {
    %c0 = torch.constant.int 0
    %c1 = torch.constant.int 1
    %c2 = torch.constant.int 2
    %c3 = torch.constant.int 3
    %c4 = torch.constant.int 4
    // Permutation that swaps the two expanded channel dimensions (1 and 2).
    %dims = torch.prim.ListConstruct %c0, %c2, %c1, %c3, %c4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    // Expansion: split_dim takes the outer length, so splitting the 8 channels
    // with outer length 2 (read here as the number of groups) gives [1, 2, 4, 4, 4].
    %reshaped = torch.prims.split_dim %arg0, %c1, %c2 : !torch.vtensor<[1, 8, 4, 4], f32>, !torch.int, !torch.int -> !torch.vtensor<[1, 2, 4, 4, 4], f32>
    %permuted = torch.aten.permute %reshaped, %dims : !torch.vtensor<[1, 2, 4, 4, 4], f32>, !torch.list<int> -> !torch.vtensor<[1, 4, 2, 4, 4], f32>
    // Contraction: collapse dimensions 1..2 back into a single channel dimension.
    %collapsed = torch.prims.collapse %permuted, %c1, %c2 : !torch.vtensor<[1, 4, 2, 4, 4], f32>, !torch.int, !torch.int -> !torch.vtensor<[1, 8, 4, 4], f32>
    return %collapsed : !torch.vtensor<[1, 8, 4, 4], f32>
  }
}
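The expand, permute, and contract steps can be sanity-checked on channel indices alone, without any tensor library. The sketch below is plain Python, not the PyTorch or torch-mlir implementation. The helper names `channel_shuffle_order` and `channel_shuffle` are hypothetical, and it assumes the value 2 in the example is the `groups` argument of torch.nn.ChannelShuffle.

```python
def channel_shuffle_order(channels, groups):
    """Return, for each output channel, the index of the input channel it reads."""
    if groups <= 0 or channels % groups != 0:
        raise ValueError("channels must be divisible by groups")
    per_group = channels // groups
    # Expansion: view the channel axis as a (groups, per_group) grid.
    grid = [[g * per_group + k for k in range(per_group)] for g in range(groups)]
    # Permute: swap the two expanded axes, giving (per_group, groups).
    swapped = [[grid[g][k] for g in range(groups)] for k in range(per_group)]
    # Contraction: flatten back into a single channel axis.
    return [c for row in swapped for c in row]


def channel_shuffle(x, groups):
    """Shuffle a list whose elements stand for whole channels."""
    return [x[src] for src in channel_shuffle_order(len(x), groups)]


if __name__ == "__main__":
    # 8 channels, 2 groups: the two halves of the channel axis are interleaved.
    print(channel_shuffle_order(8, 2))  # [0, 4, 1, 5, 2, 6, 3, 7]
```

With `groups` equal to 1, as in the reproducer at the top of this issue, the order is the identity, so a decomposition should leave the tensor unchanged in that case.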
References:
PyTorch ChannelShuffle definition:
https://docs.pytorch.org/docs/stable/generated/torch.nn.ChannelShuffle.html
ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices (2017):
https://arxiv.org/pdf/1707.01083
A Lightweight Dendritic ShuffleNet for Medical Image Classification (2025):
https://www.jstage.jst.go.jp/article/transinf/advpub/0/advpub_2024EDP7059/_pdf
PyTorch implementation:
aten/src/ATen/native/ChanelShuffle.cpp
https://github.com/pytorch/pytorch/blob/23491519d288dedb2a54cfad5fef7fcb2ad8eade/aten/src/ATen/native/ChanelShuffle.cpp#L4