Skip to content

Commit a7d65c7

Browse files
authored
[Backend Tester] Add reduction op tests (#12853)
Add tests for reduction-type ops.
1 parent 51e6626 commit a7d65c7

File tree

6 files changed

+1399
-0
lines changed

6 files changed

+1399
-0
lines changed
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import List, Optional, Tuple, Union
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class AmaxModel(torch.nn.Module):
22+
def __init__(
23+
self,
24+
dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None,
25+
keepdim: bool = False,
26+
):
27+
super().__init__()
28+
self.dim = dim
29+
self.keepdim = keepdim
30+
31+
def forward(self, x):
32+
return torch.amax(x, dim=self.dim, keepdim=self.keepdim)
33+
34+
35+
@operator_test
36+
class Amax(OperatorTest):
37+
@dtype_test
38+
def test_amax_dtype(self, flow: TestFlow, dtype) -> None:
39+
self._test_op(
40+
AmaxModel().to(dtype),
41+
(torch.rand(10, 10).to(dtype),),
42+
flow,
43+
)
44+
45+
def test_amax_dim(self, flow: TestFlow) -> None:
46+
self._test_op(
47+
AmaxModel(dim=0),
48+
(torch.randn(5, 10),),
49+
flow,
50+
)
51+
52+
self._test_op(
53+
AmaxModel(dim=1),
54+
(torch.randn(5, 10),),
55+
flow,
56+
)
57+
58+
self._test_op(
59+
AmaxModel(dim=0),
60+
(torch.randn(3, 4, 5),),
61+
flow,
62+
)
63+
64+
self._test_op(
65+
AmaxModel(dim=1),
66+
(torch.randn(3, 4, 5),),
67+
flow,
68+
)
69+
70+
self._test_op(
71+
AmaxModel(dim=2),
72+
(torch.randn(3, 4, 5),),
73+
flow,
74+
)
75+
76+
self._test_op(
77+
AmaxModel(dim=1),
78+
(torch.randn(2, 3, 4, 5),),
79+
flow,
80+
)
81+
82+
self._test_op(
83+
AmaxModel(dim=-1),
84+
(torch.randn(3, 4, 5),),
85+
flow,
86+
)
87+
88+
self._test_op(
89+
AmaxModel(dim=-2),
90+
(torch.randn(3, 4, 5),),
91+
flow,
92+
)
93+
94+
def test_amax_multi_dim(self, flow: TestFlow) -> None:
95+
self._test_op(
96+
AmaxModel(dim=(0, 1)),
97+
(torch.randn(3, 4, 5),),
98+
flow,
99+
)
100+
101+
self._test_op(
102+
AmaxModel(dim=(0, 2)),
103+
(torch.randn(3, 4, 5),),
104+
flow,
105+
)
106+
107+
self._test_op(
108+
AmaxModel(dim=(1, 2)),
109+
(torch.randn(3, 4, 5),),
110+
flow,
111+
)
112+
113+
self._test_op(
114+
AmaxModel(dim=(1, 3)),
115+
(torch.randn(2, 3, 4, 5),),
116+
flow,
117+
)
118+
119+
self._test_op(
120+
AmaxModel(dim=(0, 2)),
121+
(torch.randn(2, 3, 4, 5),),
122+
flow,
123+
)
124+
125+
self._test_op(
126+
AmaxModel(dim=(-1, -3)),
127+
(torch.randn(2, 3, 4, 5),),
128+
flow,
129+
)
130+
131+
self._test_op(
132+
AmaxModel(dim=(0, 1, 2, 3)),
133+
(torch.randn(2, 3, 4, 5),),
134+
flow,
135+
)
136+
137+
def test_amax_keepdim(self, flow: TestFlow) -> None:
138+
self._test_op(
139+
AmaxModel(dim=0, keepdim=True),
140+
(torch.randn(5, 10),),
141+
flow,
142+
)
143+
144+
self._test_op(
145+
AmaxModel(dim=1, keepdim=True),
146+
(torch.randn(5, 10),),
147+
flow,
148+
)
149+
150+
self._test_op(
151+
AmaxModel(dim=1, keepdim=True),
152+
(torch.randn(3, 4, 5),),
153+
flow,
154+
)
155+
156+
self._test_op(
157+
AmaxModel(dim=2, keepdim=True),
158+
(torch.randn(2, 3, 4, 5),),
159+
flow,
160+
)
161+
162+
self._test_op(
163+
AmaxModel(dim=(1, 2), keepdim=True),
164+
(torch.randn(3, 4, 5),),
165+
flow,
166+
)
167+
168+
def test_amax_shapes(self, flow: TestFlow) -> None:
169+
self._test_op(
170+
AmaxModel(),
171+
(torch.randn(20),),
172+
flow,
173+
)
174+
self._test_op(
175+
AmaxModel(dim=0),
176+
(torch.randn(20),),
177+
flow,
178+
)
179+
180+
self._test_op(
181+
AmaxModel(),
182+
(torch.randn(5, 10),),
183+
flow,
184+
)
185+
186+
self._test_op(
187+
AmaxModel(),
188+
(torch.randn(3, 4, 5),),
189+
flow,
190+
)
191+
192+
self._test_op(
193+
AmaxModel(),
194+
(torch.randn(2, 3, 4, 5),),
195+
flow,
196+
)
197+
198+
self._test_op(
199+
AmaxModel(),
200+
(torch.randn(2, 2, 3, 4, 5),),
201+
flow,
202+
)
203+
204+
def test_amax_edge_cases(self, flow: TestFlow) -> None:
205+
x = torch.tensor([[1.0, float("inf"), 3.0], [4.0, 5.0, float("inf")]])
206+
self._test_op(
207+
AmaxModel(),
208+
(x,),
209+
flow,
210+
use_random_test_inputs=False,
211+
)
212+
self._test_op(
213+
AmaxModel(dim=0),
214+
(x,),
215+
flow,
216+
use_random_test_inputs=False,
217+
)
218+
self._test_op(
219+
AmaxModel(dim=1),
220+
(x,),
221+
flow,
222+
use_random_test_inputs=False,
223+
)
224+
225+
x = torch.tensor([[1.0, float("nan"), 3.0], [4.0, 5.0, float("nan")]])
226+
self._test_op(
227+
AmaxModel(),
228+
(x,),
229+
flow,
230+
use_random_test_inputs=False,
231+
)
232+
self._test_op(
233+
AmaxModel(dim=0),
234+
(x,),
235+
flow,
236+
use_random_test_inputs=False,
237+
)
238+
self._test_op(
239+
AmaxModel(dim=1),
240+
(x,),
241+
flow,
242+
use_random_test_inputs=False,
243+
)
244+
245+
def test_amax_scalar(self, flow: TestFlow) -> None:
246+
self._test_op(
247+
AmaxModel(),
248+
(torch.tensor([5.0]),),
249+
flow,
250+
)
251+
self._test_op(
252+
AmaxModel(dim=0),
253+
(torch.tensor([5.0]),),
254+
flow,
255+
)

0 commit comments

Comments
 (0)