Skip to content

Commit 0479dcd

Browse files
authored
[Backend Tester] Add slice and reshape tests (#12851)
Add tests for a view-type ops, cat, and slice.
1 parent ec23e87 commit 0479dcd

File tree

10 files changed

+1126
-0
lines changed

10 files changed

+1126
-0
lines changed
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import torch
11+
from executorch.backends.test.suite.flow import TestFlow
12+
13+
from executorch.backends.test.suite.operators import (
14+
dtype_test,
15+
operator_test,
16+
OperatorTest,
17+
)
18+
19+
20+
class CatModel(torch.nn.Module):
21+
def __init__(self, dim: int = 0):
22+
super().__init__()
23+
self.dim = dim
24+
25+
def forward(self, x1, x2, x3):
26+
return torch.cat([x1, x2, x3], dim=self.dim)
27+
28+
29+
@operator_test
30+
class Cat(OperatorTest):
31+
@dtype_test
32+
def test_cat_dtype(self, flow: TestFlow, dtype) -> None:
33+
self._test_op(
34+
CatModel(),
35+
(
36+
torch.rand(8, 32).to(dtype),
37+
torch.rand(12, 32).to(dtype),
38+
torch.rand(16, 32).to(dtype),
39+
),
40+
flow,
41+
)
42+
43+
def test_cat_dimensions(self, flow: TestFlow) -> None:
44+
self._test_op(
45+
CatModel(dim=0),
46+
(
47+
torch.randn(8, 32),
48+
torch.randn(12, 32),
49+
torch.randn(16, 32),
50+
),
51+
flow,
52+
)
53+
54+
self._test_op(
55+
CatModel(dim=1),
56+
(
57+
torch.randn(16, 8),
58+
torch.randn(16, 12),
59+
torch.randn(16, 16),
60+
),
61+
flow,
62+
)
63+
64+
self._test_op(
65+
CatModel(dim=2),
66+
(
67+
torch.randn(4, 8, 4),
68+
torch.randn(4, 8, 8),
69+
torch.randn(4, 8, 12),
70+
),
71+
flow,
72+
)
73+
74+
def test_cat_negative_dim(self, flow: TestFlow) -> None:
75+
self._test_op(
76+
CatModel(dim=-1),
77+
(
78+
torch.randn(16, 8),
79+
torch.randn(16, 12),
80+
torch.randn(16, 16),
81+
),
82+
flow,
83+
)
84+
85+
self._test_op(
86+
CatModel(dim=-2),
87+
(
88+
torch.randn(8, 32),
89+
torch.randn(12, 32),
90+
torch.randn(16, 32),
91+
),
92+
flow,
93+
)
94+
95+
def test_cat_different_shapes(self, flow: TestFlow) -> None:
96+
self._test_op(
97+
CatModel(),
98+
(
99+
torch.randn(128),
100+
torch.randn(256),
101+
torch.randn(384),
102+
),
103+
flow,
104+
)
105+
106+
self._test_op(
107+
CatModel(dim=0),
108+
(
109+
torch.randn(4, 8, 16),
110+
torch.randn(8, 8, 16),
111+
torch.randn(12, 8, 16),
112+
),
113+
flow,
114+
)
115+
116+
self._test_op(
117+
CatModel(dim=1),
118+
(
119+
torch.randn(8, 4, 16),
120+
torch.randn(8, 8, 16),
121+
torch.randn(8, 12, 16),
122+
),
123+
flow,
124+
)
125+
126+
self._test_op(
127+
CatModel(dim=2),
128+
(
129+
torch.randn(8, 12, 4),
130+
torch.randn(8, 12, 8),
131+
torch.randn(8, 12, 12),
132+
),
133+
flow,
134+
)
135+
136+
def test_cat_broadcast(self, flow: TestFlow) -> None:
137+
self._test_op(
138+
CatModel(dim=0),
139+
(
140+
torch.randn(2, 16, 32),
141+
torch.randn(4, 16, 32),
142+
torch.randn(6, 16, 32),
143+
),
144+
flow,
145+
)
146+
147+
self._test_op(
148+
CatModel(dim=1),
149+
(
150+
torch.randn(8, 8, 16),
151+
torch.randn(8, 16, 16),
152+
torch.randn(8, 24, 16),
153+
),
154+
flow,
155+
)
156+
157+
self._test_op(
158+
CatModel(dim=2),
159+
(
160+
torch.randn(4, 16, 8),
161+
torch.randn(4, 16, 16),
162+
torch.randn(4, 16, 24),
163+
),
164+
flow,
165+
)
166+
167+
def test_cat_same_shapes(self, flow: TestFlow) -> None:
168+
self._test_op(
169+
CatModel(),
170+
(
171+
torch.randn(8, 32),
172+
torch.randn(8, 32),
173+
torch.randn(8, 32),
174+
),
175+
flow,
176+
)
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import List
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class ExpandModel(torch.nn.Module):
22+
def __init__(self, shape: List[int]):
23+
super().__init__()
24+
self.shape = shape
25+
26+
def forward(self, x):
27+
return x.expand(self.shape)
28+
29+
30+
@operator_test
31+
class Expand(OperatorTest):
32+
@dtype_test
33+
def test_expand_dtype(self, flow: TestFlow, dtype) -> None:
34+
self._test_op(
35+
ExpandModel(shape=[8, 32]),
36+
(torch.rand(1, 32).to(dtype),),
37+
flow,
38+
)
39+
40+
def test_expand_dimensions(self, flow: TestFlow) -> None:
41+
self._test_op(
42+
ExpandModel(shape=[8, 32]),
43+
(torch.randn(1, 32),),
44+
flow,
45+
)
46+
47+
self._test_op(
48+
ExpandModel(shape=[16, 20]),
49+
(torch.randn(1, 1),),
50+
flow,
51+
)
52+
53+
self._test_op(
54+
ExpandModel(shape=[4, 1, 32]),
55+
(torch.randn(1, 32),),
56+
flow,
57+
)
58+
59+
self._test_op(
60+
ExpandModel(shape=[8, 4, 16]),
61+
(torch.randn(8, 1, 16),),
62+
flow,
63+
)
64+
65+
self._test_op(
66+
ExpandModel(shape=[6, 16, 8]),
67+
(torch.randn(6, 16, 1),),
68+
flow,
69+
)
70+
71+
def test_expand_keep_original_size(self, flow: TestFlow) -> None:
72+
self._test_op(
73+
ExpandModel(shape=[8, -1]),
74+
(torch.randn(1, 32),),
75+
flow,
76+
)
77+
78+
self._test_op(
79+
ExpandModel(shape=[-1, 32]),
80+
(torch.randn(4, 1),),
81+
flow,
82+
)
83+
84+
self._test_op(
85+
ExpandModel(shape=[-1, 16, -1]),
86+
(torch.randn(4, 1, 8),),
87+
flow,
88+
)
89+
90+
def test_expand_rank_increase(self, flow: TestFlow) -> None:
91+
# Test expanding 2D tensor to 3D
92+
self._test_op(
93+
ExpandModel(shape=[6, 8, 16]),
94+
(torch.randn(8, 16),),
95+
flow,
96+
)
97+
98+
# Test expanding 2D tensor to 4D
99+
self._test_op(
100+
ExpandModel(shape=[3, 4, 8, 16]),
101+
(torch.randn(8, 16),),
102+
flow,
103+
)
104+
105+
def test_expand_singleton_dimensions(self, flow: TestFlow) -> None:
106+
self._test_op(
107+
ExpandModel(shape=[512]),
108+
(torch.randn(1),),
109+
flow,
110+
)
111+
112+
self._test_op(
113+
ExpandModel(shape=[16, 20]),
114+
(torch.randn(1, 1),),
115+
flow,
116+
)
117+
118+
self._test_op(
119+
ExpandModel(shape=[8, 32]),
120+
(torch.randn(32),),
121+
flow,
122+
)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import List
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class ReshapeModel(torch.nn.Module):
22+
def __init__(self, shape: List[int]):
23+
super().__init__()
24+
self.shape = shape
25+
26+
def forward(self, x):
27+
return torch.reshape(x, self.shape)
28+
29+
30+
@operator_test
31+
class Reshape(OperatorTest):
32+
@dtype_test
33+
def test_reshape_dtype(self, flow: TestFlow, dtype) -> None:
34+
self._test_op(
35+
ReshapeModel(shape=[3, 5]),
36+
(torch.rand(15).to(dtype),),
37+
flow,
38+
)
39+
40+
def test_reshape_dimensions(self, flow: TestFlow) -> None:
41+
self._test_op(
42+
ReshapeModel(shape=[3, 5]),
43+
(torch.randn(15),),
44+
flow,
45+
)
46+
47+
self._test_op(
48+
ReshapeModel(shape=[20]),
49+
(torch.randn(4, 5),),
50+
flow,
51+
)
52+
53+
self._test_op(
54+
ReshapeModel(shape=[2, 2, 5]),
55+
(torch.randn(4, 5),),
56+
flow,
57+
)
58+
59+
self._test_op(
60+
ReshapeModel(shape=[6, 4]),
61+
(torch.randn(3, 2, 4),),
62+
flow,
63+
)
64+
65+
def test_reshape_inferred_dimension(self, flow: TestFlow) -> None:
66+
self._test_op(
67+
ReshapeModel(shape=[3, -1]),
68+
(torch.randn(15),),
69+
flow,
70+
)
71+
72+
self._test_op(
73+
ReshapeModel(shape=[-1, 5]),
74+
(torch.randn(15),),
75+
flow,
76+
)
77+
78+
self._test_op(
79+
ReshapeModel(shape=[2, -1, 3]),
80+
(torch.randn(24),),
81+
flow,
82+
)

0 commit comments

Comments
 (0)