[Torch] Canonicalize pool ops with single int tuple params. #4250
Conversation
if constexpr (std::is_same<AtenPoolOpT, AtenMaxPool2dOp>() ||
              std::is_same<AtenPoolOpT, AtenMaxPool3dOp>()) {
  if (!matchPattern(op.getDilation(),
                    m_TorchListOfConstantInts(dilationInts)))
    return rewriter.notifyMatchFailure(
        op, "Non-const dilation for pooling op unsupported");

  if (kernelSizeInts.size() != 1 && paddingInts.size() != 1 &&
      strideInts.size() != 1 && dilationInts.size() != 1) {
    return rewriter.notifyMatchFailure(
        op,
        "Expected one of kernel/stride/padding/dilation to be singleton.");
  }

  expand(dilationInts, numSpatialDims);

} else if (kernelSizeInts.size() != 1 && paddingInts.size() != 1 &&
           strideInts.size() != 1) {
  return rewriter.notifyMatchFailure(
      op, "Expected one of kernel/stride/padding to be singleton.");
}
Instead of this, you can also do:

bool isMaxPool = false;
if constexpr (....) {
  // const dilation check
  isMaxPool = true;
}
// Singleton check for dilation based on the isMaxPool flag.
// Singleton check for the rest of the values.
I think the singleton check for dilation cannot be decoupled from the singleton check for the other values for a MaxPool op. For MaxPool, as long as one of the kernel/stride/padding/dilation params is singleton we can canonicalize it. For AvgPool we similarly need one of kernel/stride/padding to be singleton.
So the code will look like:
bool isMaxPool = false;
if constexpr (....) {
  // const dilation check
  isMaxPool = true;
}
if (isMaxPool) {
  // check for one of kernel/stride/padding/dilation to be singleton
} else {
  // check for one of kernel/stride/padding to be singleton
}
if (isMaxPool) {
  expandDilation
}
// expand other params
Is that your suggestion?
super().__init__()
self.apd = torch.nn.AvgPool3d(
    kernel_size=(6, 6, 6),
    stride=(2,),
If the focus of the test is a single value for the stride, then the name of the test could reflect this.
kernel_size=(6, 6, 6),
stride=(2, 2, 2),
padding=(1, 1, 1),
dilation=(2,),
If the goal of the test is the single-value dilation, the name of the test should reflect it.
def __init__(self):
    super().__init__()
    self.mpd = torch.nn.MaxPool2d(
        kernel_size=(6,),
If the goal of the test is the single-value kernel size, the name of the test should reflect it.
self.apd = torch.nn.AvgPool3d(
    kernel_size=(6, 6, 6),
    stride=(2,),
    padding=(1, 1, 1),
What about a test with a single value for padding?
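For what it's worth, a minimal sketch of such a test module, modeled on the existing AvgPool3dSingleIntTupleParamsModule; the class name is hypothetical and the e2e harness decorators are omitted:

class AvgPool3dSingleIntTuplePaddingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Only padding is a single-int tuple; kernel and stride are full 3-tuples.
        self.apd = torch.nn.AvgPool3d(
            kernel_size=(6, 6, 6),
            stride=(2, 2, 2),
            padding=(1,),
        )

    def forward(self, x):
        return self.apd(x)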
@@ -1810,6 +1861,31 @@ def AvgPool3dCountIncludePadFalseWithoutPadding_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 3, 12, 12, 12, low=-1))


class AvgPool3dSingleIntTupleParamsModule(torch.nn.Module):
The kernel size and dilation are tested with max pooling, and the stride with average pooling. If they share the same path, this is fine. But I saw the max pooling taking a separate path for expansion. If the paths are not shared, it might be good to have tests for max pooling and average pooling for all the parameters that could be expanded. If adding these tests does not increase path coverage in the source code, you can ignore this message.
On a second look, the dilation path is only for max pooling; average pooling does not have it (torch.nn.AvgPool3d exposes no dilation argument). The expansion for the other three parameters is shared. You can ignore this message.
Fixes #3885 by repeating the single int to match the expected number of spatial dims.
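As a sanity check of that rewrite, a minimal eager-mode sketch (not the canonicalizer itself) showing that a single-int tuple parameter behaves the same as that int repeated across the spatial dims:

import torch

x = torch.rand(3, 3, 12, 12, 12)
# stride=(2,) behaves exactly like stride=(2, 2, 2) for a 3-d pool.
a = torch.nn.AvgPool3d(kernel_size=(6, 6, 6), stride=(2,))(x)
b = torch.nn.AvgPool3d(kernel_size=(6, 6, 6), stride=(2, 2, 2))(x)
assert torch.equal(a, b)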