-
Notifications
You must be signed in to change notification settings - Fork 640
[Backend Tester] Add permute, transpose, and masked_fill tests #12850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
fd73fb9
Update
GregoryComer 2de680e
Update
GregoryComer 8b11366
Update
GregoryComer ca1b887
Update
GregoryComer 6acef0f
Update
GregoryComer cc3332e
Update
GregoryComer 4f4d58d
Update
GregoryComer 346ab2e
Update
GregoryComer cfc76a3
Update
GregoryComer e2ccfec
Update
GregoryComer 3d398bc
Update
GregoryComer c50e393
Update
GregoryComer e99cd87
Update
GregoryComer f41e9a7
Update
GregoryComer 4193235
Update
GregoryComer a265cca
Update
GregoryComer 69742a8
Update
GregoryComer 831f814
Update
GregoryComer File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-unsafe | ||
|
||
from typing import Union | ||
|
||
import torch | ||
from executorch.backends.test.suite.flow import TestFlow | ||
|
||
from executorch.backends.test.suite.operators import ( | ||
dtype_test, | ||
operator_test, | ||
OperatorTest, | ||
) | ||
|
||
|
||
class MaskedFillModel(torch.nn.Module): | ||
def __init__(self, value: Union[float, int]): | ||
super().__init__() | ||
self.value = value | ||
|
||
def forward(self, x, mask): | ||
return x.masked_fill(mask, self.value) | ||
|
||
|
||
@operator_test | ||
class MaskedFill(OperatorTest): | ||
@dtype_test | ||
def test_masked_fill_dtype(self, flow: TestFlow, dtype) -> None: | ||
mask = torch.randint(0, 2, (16, 32), dtype=torch.bool) | ||
self._test_op( | ||
MaskedFillModel(value=0.0), | ||
( | ||
torch.rand(16, 32).to(dtype), | ||
mask, | ||
), | ||
flow, | ||
) | ||
|
||
def test_masked_fill_different_values(self, flow: TestFlow) -> None: | ||
mask = torch.randint(0, 2, (16, 32), dtype=torch.bool) | ||
|
||
self._test_op( | ||
MaskedFillModel(value=5.0), | ||
( | ||
torch.randn(16, 32), | ||
mask, | ||
), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
MaskedFillModel(value=-5.0), | ||
( | ||
torch.randn(16, 32), | ||
mask, | ||
), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
MaskedFillModel(value=1), | ||
( | ||
torch.randn(16, 32), | ||
mask, | ||
), | ||
flow, | ||
) | ||
|
||
def test_masked_fill_different_shapes(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
MaskedFillModel(value=0.0), | ||
( | ||
torch.randn(512), | ||
torch.randint(0, 2, (512,), dtype=torch.bool), | ||
), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
MaskedFillModel(value=0.0), | ||
( | ||
torch.randn(4, 8, 16), | ||
torch.randint(0, 2, (4, 8, 16), dtype=torch.bool), | ||
), | ||
flow, | ||
) | ||
|
||
def test_masked_fill_broadcast(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
MaskedFillModel(value=0.0), | ||
( | ||
torch.randn(16, 32), | ||
torch.randint(0, 2, (32,), dtype=torch.bool), | ||
), | ||
flow, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-unsafe | ||
|
||
from typing import List | ||
|
||
import torch | ||
from executorch.backends.test.suite.flow import TestFlow | ||
|
||
from executorch.backends.test.suite.operators import ( | ||
dtype_test, | ||
operator_test, | ||
OperatorTest, | ||
) | ||
|
||
|
||
class PermuteModel(torch.nn.Module): | ||
def __init__(self, dims: List[int]): | ||
super().__init__() | ||
self.dims = dims | ||
|
||
def forward(self, x): | ||
return x.permute(self.dims) | ||
|
||
|
||
@operator_test | ||
class Permute(OperatorTest): | ||
@dtype_test | ||
def test_permute_dtype(self, flow: TestFlow, dtype) -> None: | ||
self._test_op( | ||
PermuteModel(dims=[1, 0]), | ||
(torch.rand(20, 32).to(dtype),), | ||
flow, | ||
) | ||
|
||
def test_permute_3d(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
PermuteModel(dims=[2, 0, 1]), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
PermuteModel(dims=[1, 2, 0]), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
PermuteModel(dims=[0, 2, 1]), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
|
||
def test_permute_4d(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
PermuteModel(dims=[3, 2, 1, 0]), | ||
(torch.randn(4, 6, 8, 10),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
PermuteModel(dims=[0, 2, 1, 3]), | ||
(torch.randn(4, 6, 8, 10),), | ||
flow, | ||
) | ||
|
||
def test_permute_identity(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
PermuteModel(dims=[0, 1]), | ||
(torch.randn(20, 32),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
PermuteModel(dims=[0, 1, 2]), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
|
||
def test_permute_negative_dims(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
PermuteModel(dims=[-1, -3, -2, -4]), | ||
(torch.randn(4, 6, 8, 10),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
PermuteModel(dims=[-4, -2, -3, -1]), | ||
(torch.randn(4, 6, 8, 10),), | ||
flow, | ||
) | ||
|
||
def test_permute_different_shapes(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
PermuteModel(dims=[0]), | ||
(torch.randn(512),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
PermuteModel(dims=[4, 3, 2, 1, 0]), | ||
(torch.randn(2, 3, 4, 5, 6),), | ||
flow, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-unsafe | ||
|
||
|
||
import torch | ||
from executorch.backends.test.suite.flow import TestFlow | ||
|
||
from executorch.backends.test.suite.operators import ( | ||
dtype_test, | ||
operator_test, | ||
OperatorTest, | ||
) | ||
|
||
|
||
class TransposeModel(torch.nn.Module): | ||
def __init__(self, dim0: int, dim1: int): | ||
super().__init__() | ||
self.dim0 = dim0 | ||
self.dim1 = dim1 | ||
|
||
def forward(self, x): | ||
return torch.transpose(x, self.dim0, self.dim1) | ||
|
||
|
||
@operator_test | ||
class Transpose(OperatorTest): | ||
@dtype_test | ||
def test_transpose_dtype(self, flow: TestFlow, dtype) -> None: | ||
self._test_op( | ||
TransposeModel(dim0=0, dim1=1), | ||
(torch.rand(20, 32).to(dtype),), | ||
flow, | ||
) | ||
|
||
def test_transpose_basic(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
TransposeModel(dim0=0, dim1=1), | ||
(torch.randn(20, 32),), | ||
flow, | ||
) | ||
|
||
def test_transpose_3d(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
TransposeModel(dim0=0, dim1=1), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
TransposeModel(dim0=0, dim1=2), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
TransposeModel(dim0=1, dim1=2), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
|
||
def test_transpose_4d(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
TransposeModel(dim0=0, dim1=3), | ||
(torch.randn(4, 6, 8, 10),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
TransposeModel(dim0=1, dim1=2), | ||
(torch.randn(4, 6, 8, 10),), | ||
flow, | ||
) | ||
|
||
def test_transpose_identity(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
TransposeModel(dim0=0, dim1=0), | ||
(torch.randn(20, 32),), | ||
flow, | ||
) | ||
self._test_op( | ||
TransposeModel(dim0=1, dim1=1), | ||
(torch.randn(20, 32),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
TransposeModel(dim0=0, dim1=0), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
self._test_op( | ||
TransposeModel(dim0=1, dim1=1), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
self._test_op( | ||
TransposeModel(dim0=2, dim1=2), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
|
||
def test_transpose_negative_dims(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
TransposeModel(dim0=-3, dim1=-1), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
TransposeModel(dim0=-2, dim1=-1), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
|
||
def test_transpose_different_shapes(self, flow: TestFlow) -> None: | ||
self._test_op( | ||
TransposeModel(dim0=0, dim1=1), | ||
(torch.randn(20, 32),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
TransposeModel(dim0=0, dim1=2), | ||
(torch.randn(8, 10, 12),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
TransposeModel(dim0=1, dim1=3), | ||
(torch.randn(4, 6, 8, 10),), | ||
flow, | ||
) | ||
|
||
self._test_op( | ||
TransposeModel(dim0=0, dim1=4), | ||
(torch.randn(2, 3, 4, 5, 6),), | ||
flow, | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
some test generator or Facto or even a meta class to generate permutations for permute might be better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense. Should be easy to use itertools for this, I think. I'll take this as a follow-up.