
Provide torch to linalg lowering for the torch.aten.channel_shuffle operation #4243

@ivangarcia44

Description


Save the following code as channelshuffle.mlir:

module {
  func.func @main(%arg0: !torch.vtensor<[?,1,8],f32>) -> !torch.vtensor<[?,1,8],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %int1 = torch.constant.int 1
    %0 = torch.operator "torch.aten.channel_shuffle"(%arg0, %int1) : (!torch.vtensor<[?,1,8],f32>, !torch.int) -> !torch.vtensor<[?,1,8],f32> 
    return %0 : !torch.vtensor<[?,1,8],f32>
  }
}

Then run:
torch-mlir-opt --convert-torch-to-linalg channelshuffle.mlir

The torch.aten.channel_shuffle operation remains in the output; it is not lowered to Linalg.

Proposal:

This operator can be supported with a Torch-dialect-level decomposition, similar to the existing one for the pixel_shuffle operation.
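A decomposition pattern of this kind has to do two things before it emits any ops. It must validate the input, and it must compute the intermediate shapes for the expand / permute / collapse steps.

The Python sketch below illustrates only that shape logic. It is a hypothetical helper (`channel_shuffle_shapes`), not torch-mlir code. `None` stands for a dynamic `?` dimension, as in the reproducer above. The checks (rank of at least 3, positive groups, C divisible by groups) are our reading of the PyTorch implementation and should be verified against it.

```python
from typing import List, Optional, Sequence, Tuple

Dim = Optional[int]  # None models a dynamic ("?") dimension


def channel_shuffle_shapes(
    shape: Sequence[Dim], groups: int
) -> Tuple[List[Dim], List[int], List[Dim], List[Dim]]:
    """Hypothetical helper: (split_shape, permutation, permuted_shape, result_shape)."""
    if len(shape) < 3:
        raise ValueError("channel_shuffle expects at least 3 dims (N, C, *spatial)")
    if groups <= 0:
        raise ValueError("groups must be positive")
    n, c, *spatial = shape
    if c is not None and c % groups != 0:
        raise ValueError("number of channels must be divisible by groups")
    # Channels per group; unknown if the channel dim itself is dynamic.
    cpg = None if c is None else c // groups
    # Expand C into [groups, C / groups] (groups is the outer dimension).
    split_shape = [n, groups, cpg, *spatial]
    # Swap the groups and channels-per-group dims; keep all others in place.
    perm = [0, 2, 1, *range(3, len(split_shape))]
    permuted_shape = [split_shape[p] for p in perm]
    # Collapsing dims 1..2 restores the original shape.
    result_shape = list(shape)
    return split_shape, perm, permuted_shape, result_shape
```

For the 1x8x4x4, groups = 2 case, this yields a split shape of [1, 2, 4, 4, 4] and a permuted shape of [1, 4, 2, 4, 4]. A dynamic batch dimension simply propagates through unchanged.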

The decomposition is based on the PyTorch specification:
https://docs.pytorch.org/docs/stable/generated/torch.nn.ChannelShuffle.html
and on the reference implementation in
aten/src/ATen/native/ChanelShuffle.cpp:
https://github.com/pytorch/pytorch/blob/23491519d288dedb2a54cfad5fef7fcb2ad8eade/aten/src/ATen/native/ChanelShuffle.cpp#L4
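The reshape / permute / reshape sequence in that implementation can be read as a channel index map. This is our own derivation; the symbols $g$ (groups), $C$ (channels) and $k = C / g$ (channels per group) are introduced here only for notation.

An input channel $jk + i$ sits at position $(j, i)$ after the split. It moves to $(i, j)$ after the permute and becomes channel $ig + j$ after the collapse:

```math
\text{out}[n,\; i g + j,\; \ldots] = \text{in}[n,\; j k + i,\; \ldots], \qquad 0 \le i < k,\; 0 \le j < g
```

For example, with $C = 8$ and $g = 2$ (so $k = 4$), the output channels come from input channels 0, 4, 1, 5, 2, 6, 3, 7, in that order.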

Note that the operator consists of three steps:

1. Expand the channel dimension C into two dimensions [groups, C / groups].
2. Permute (swap) those two dimensions.
3. Collapse them back into a single channel dimension of the original size.

For example, an input of shape 1x8x4x4 with groups = 2 would be decomposed into the Torch-dialect IR below.

module {
  func.func @channel_shuffle(%arg0: !torch.vtensor<[1, 8, 4, 4], f32>) -> !torch.vtensor<[1, 8, 4, 4], f32> {
    %c0 = torch.constant.int 0
    %c1 = torch.constant.int 1
    %c2 = torch.constant.int 2
    %c3 = torch.constant.int 3
    %c4 = torch.constant.int 4
    %dims = torch.prim.ListConstruct %c0, %c2, %c1, %c3, %c4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>

    // Split the channel dim (dim 1) into [groups = 2, channels_per_group = 4]; the last operand is the outer length.
    %reshaped = torch.prims.split_dim %arg0, %c1, %c2 : !torch.vtensor<[1, 8, 4, 4], f32>, !torch.int, !torch.int -> !torch.vtensor<[1, 2, 4, 4, 4], f32>

    // Swap the groups and channels_per_group dims.
    %permuted = torch.aten.permute %reshaped, %dims : !torch.vtensor<[1, 2, 4, 4, 4], f32>, !torch.list<int> -> !torch.vtensor<[1, 4, 2, 4, 4], f32>

    // Collapse dims 1..2 (inclusive) back into a single channel dim of size 8.
    %collapsed = torch.prims.collapse %permuted, %c1, %c2 : !torch.vtensor<[1, 4, 2, 4, 4], f32>, !torch.int, !torch.int -> !torch.vtensor<[1, 8, 4, 4], f32>

    return %collapsed : !torch.vtensor<[1, 8, 4, 4], f32>
  }
}
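To sanity-check the expected numerics of this decomposition, the same split / permute / collapse sequence can be applied to a flat row-major buffer. The sketch below does this in pure Python.

`channel_shuffle_reference` is a hypothetical name and is not part of torch-mlir or PyTorch. For real testing, one would compare the lowered output against `torch.nn.ChannelShuffle` instead.

```python
from math import prod
from typing import List, Sequence


def channel_shuffle_reference(
    data: Sequence[float], shape: Sequence[int], groups: int
) -> List[float]:
    """Hypothetical reference: channel shuffle on a flat row-major (N, C, *spatial) buffer."""
    n, c, *spatial = shape
    if groups <= 0 or c % groups != 0:
        raise ValueError("channels must be divisible by a positive groups value")
    k = c // groups          # channels per group
    inner = prod(spatial)    # elements in one (batch, channel) slice
    if len(data) != n * c * inner:
        raise ValueError("data length does not match shape")
    out = list(data)
    for b in range(n):
        for j in range(groups):
            for i in range(k):
                c_in = j * k + i         # position (j, i) after split_dim
                c_out = i * groups + j   # position (i, j) after permute, then collapse
                src = (b * c + c_in) * inner
                dst = (b * c + c_out) * inner
                out[dst:dst + inner] = data[src:src + inner]
    return out


# One element per channel makes the channel permutation visible.
print(channel_shuffle_reference(list(range(8)), [1, 8, 1, 1], 2))  # [0, 4, 1, 5, 2, 6, 3, 7]
```

With groups = 1 or groups = C, the map reduces to the identity. That gives two cheap edge cases for a lit or e2e test.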

References:
PyTorch ChannelShuffle definition:
https://docs.pytorch.org/docs/stable/generated/torch.nn.ChannelShuffle.html
ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices (2017):
https://arxiv.org/pdf/1707.01083
A Lightweight Dendritic ShuffleNet for Medical Image Classification (2025):
https://www.jstage.jst.go.jp/article/transinf/advpub/0/advpub_2024EDP7059/_pdf
PyTorch implementation:
aten/src/ATen/native/ChanelShuffle.cpp
https://github.com/pytorch/pytorch/blob/23491519d288dedb2a54cfad5fef7fcb2ad8eade/aten/src/ATen/native/ChanelShuffle.cpp#L4

@silvasean @rsuderman @zjgarvey
