[Torch] Canonicalize pool ops with single int tuple params. #4250

Open
sahas3 wants to merge 5 commits into main
Conversation

sahas3 (Member) commented on Jul 1, 2025

Fixes #3885 by repeating the single int to match the expected number of spatial dims.
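For context, eager PyTorch already accepts a single-int tuple for these parameters and broadcasts it across all spatial dimensions; the canonicalization reproduces that expansion in Torch IR. A minimal illustration (shapes and values are arbitrary, assuming a recent PyTorch):

import torch

x = torch.randn(1, 3, 12, 12, 12)

# A single-int tuple is broadcast to every spatial dim, so these two
# modules compute the same result; the canonicalization rewrites the
# first form into the second in Torch IR.
a = torch.nn.AvgPool3d(kernel_size=(6,), stride=(2,), padding=(1,))(x)
b = torch.nn.AvgPool3d(kernel_size=(6, 6, 6), stride=(2, 2, 2), padding=(1, 1, 1))(x)
assert torch.equal(a, b)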

sahas3 requested review from vivekkhandelwal1 and sjarus on Jul 2, 2025
Comment on lines +5634 to +5654
if constexpr (std::is_same<AtenPoolOpT, AtenMaxPool2dOp>() ||
std::is_same<AtenPoolOpT, AtenMaxPool3dOp>()) {
if (!matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilationInts)))
return rewriter.notifyMatchFailure(
op, "Non-const dilation for pooling op unsupported");

if (kernelSizeInts.size() != 1 && paddingInts.size() != 1 &&
strideInts.size() != 1 && dilationInts.size() != 1) {
return rewriter.notifyMatchFailure(
op,
"Expected one of kernel/stride/padding/dilation to be singleton.");
}

expand(dilationInts, numSpatialDims);

} else if (kernelSizeInts.size() != 1 && paddingInts.size() != 1 &&
strideInts.size() != 1) {
return rewriter.notifyMatchFailure(
op, "Expected one of kernel/stride/padding to be singleton.");
}
Collaborator
Instead of this, you can also do:

bool isMaxPool = false;
if constexpr(....) {
    // const dilation check
    isMaxPool = true;
}

// Singleton check for dilation based on the maxpool flag.
// Singleton check for rest of the values.

Member Author

I think the singleton check for dilation cannot be decoupled from the singleton check for the other values for the MaxPool op. For MaxPool, as long as one of kernel/stride/padding/dilation is singleton we can canonicalize it. Similarly, for AvgPool we can canonicalize as long as one of kernel/stride/padding is singleton.

So the code will look like:

bool isMaxPool = false;
if constexpr(....) { 
    // const dilation check 
    isMaxPool = true; 
}

if (isMaxPool) {
    // check for one of kernel/stride/padding/dilation to be singleton
} else {
   // one of kernel/stride/padding to be singleton
}

if (isMaxPool) {
    expandDilation
} 

// expand other params

Is that your suggestion?

super().__init__()
self.apd = torch.nn.AvgPool3d(
kernel_size=(6, 6, 6),
stride=(2,),
Contributor

If the focus of the test is a single value for the stride, then the name of the test could reflect this.

kernel_size=(6, 6, 6),
stride=(2, 2, 2),
padding=(1, 1, 1),
dilation=(2,),
Contributor

If the goal of the test is the single-value dilation, the name of the test should reflect it.

def __init__(self):
super().__init__()
self.mpd = torch.nn.MaxPool2d(
kernel_size=(6,),
Contributor

If the goal of the test is the single-value kernel size, the name of the test should reflect it.

self.apd = torch.nn.AvgPool3d(
kernel_size=(6, 6, 6),
stride=(2,),
padding=(1, 1, 1),
Contributor

What about a test for a single value for padding? A sketch of what that could look like follows.
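A hypothetical sketch of such a test, following the pattern of the other modules in this diff (the class name, shapes, and harness decorators are assumed from the surrounding test file):

class AvgPool3dSingleIntTuplePaddingModule(torch.nn.Module):
    # Hypothetical test: only padding is given as a single-int tuple.
    def __init__(self):
        super().__init__()
        self.apd = torch.nn.AvgPool3d(
            kernel_size=(6, 6, 6),
            stride=(2, 2, 2),
            padding=(1,),
        )

    @export
    @annotate_args([None, ([-1, -1, -1, -1, -1], torch.float32, True)])
    def forward(self, x):
        return self.apd(x)


@register_test_case(module_factory=lambda: AvgPool3dSingleIntTuplePaddingModule())
def AvgPool3dSingleIntTuplePaddingModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 3, 12, 12, 12, low=-1))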

@@ -1810,6 +1861,31 @@ def AvgPool3dCountIncludePadFalseWithoutPadding_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 3, 12, 12, 12, low=-1))


class AvgPool3dSingleIntTupleParamsModule(torch.nn.Module):
Contributor

The kernel size and dilation are tested with max pooling, and the stride with average pooling. If they share the same path this is fine, but I saw max pooling taking a separate path for expansion. If the paths are not shared, it might be good to have tests for both max pooling and average pooling for all the parameters that can be expanded. If adding these tests does not increase path coverage in the source code, you can ignore this message.

Contributor

On a second look, the dilation path is only for max pooling; average pooling does not have it. The expansion for the other three parameters is shared. You can ignore this message.

Successfully merging this pull request may close these issues.

Assertion when lowering from Torch IR for AvgPool2d when kernel is a tuple of a single int
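
For reference, a minimal eager-mode module of the shape the issue title describes (hypothetical repro; the kernel value and input shape are illustrative):

import torch

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # kernel_size as a tuple of a single int: accepted by eager
        # PyTorch, but this Torch IR shape previously hit the assertion
        # during lowering.
        self.pool = torch.nn.AvgPool2d(kernel_size=(3,))

    def forward(self, x):
        return self.pool(x)

print(Repro()(torch.randn(1, 1, 8, 8)).shape)  # torch.Size([1, 1, 2, 2])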