Skip to content

Commit 73a523d

Browse files
committed
[Backend Tester] Add upsample tests
ghstack-source-id: 9de53e0 ghstack-comment-id: 3116317107 Pull-Request: #12856
1 parent 15f6345 commit 73a523d

File tree

2 files changed

+518
-0
lines changed

2 files changed

+518
-0
lines changed
Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import Optional, Tuple, Union
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class ModelWithSize(torch.nn.Module):
22+
def __init__(
23+
self,
24+
size: Optional[Tuple[int, int]] = None,
25+
align_corners: Optional[bool] = None,
26+
):
27+
super().__init__()
28+
self.size = size
29+
self.align_corners = align_corners
30+
31+
def forward(self, x):
32+
return torch.nn.functional.interpolate(
33+
x, size=self.size, mode="bilinear", align_corners=self.align_corners
34+
)
35+
36+
37+
class ModelWithScale(torch.nn.Module):
38+
def __init__(
39+
self,
40+
scale_factor: Union[float, Tuple[float, float]] = 2.0,
41+
align_corners: Optional[bool] = None,
42+
):
43+
super().__init__()
44+
self.scale_factor = scale_factor
45+
self.align_corners = align_corners
46+
47+
def forward(self, x):
48+
return torch.nn.functional.interpolate(
49+
x,
50+
scale_factor=self.scale_factor,
51+
mode="bilinear",
52+
align_corners=self.align_corners,
53+
)
54+
55+
56+
@operator_test
57+
class TestUpsampleBilinear2d(OperatorTest):
58+
@dtype_test
59+
def test_upsample_bilinear2d_dtype(self, flow: TestFlow, dtype) -> None:
60+
# Input shape: (batch_size, channels, height, width)
61+
model = ModelWithSize(size=(10, 10), align_corners=False).to(dtype)
62+
self._test_op(model, (torch.rand(2, 3, 5, 5).to(dtype),), flow)
63+
64+
def test_upsample_bilinear2d_basic(self, flow: TestFlow) -> None:
65+
# Basic test with default parameters
66+
self._test_op(
67+
ModelWithSize(size=(10, 10), align_corners=False),
68+
(torch.randn(2, 3, 5, 5),),
69+
flow,
70+
)
71+
self._test_op(
72+
ModelWithSize(size=(10, 10), align_corners=True),
73+
(torch.randn(2, 3, 5, 5),),
74+
flow,
75+
)
76+
77+
def test_upsample_bilinear2d_sizes(self, flow: TestFlow) -> None:
78+
# Test with different input and output sizes
79+
80+
# Small input, larger output
81+
self._test_op(
82+
ModelWithSize(size=(8, 8), align_corners=False),
83+
(torch.randn(1, 2, 4, 4),),
84+
flow,
85+
)
86+
self._test_op(
87+
ModelWithSize(size=(8, 8), align_corners=True),
88+
(torch.randn(1, 2, 4, 4),),
89+
flow,
90+
)
91+
92+
# Larger input, even larger output
93+
self._test_op(
94+
ModelWithSize(size=(16, 16), align_corners=False),
95+
(torch.randn(1, 2, 8, 8),),
96+
flow,
97+
)
98+
self._test_op(
99+
ModelWithSize(size=(16, 16), align_corners=True),
100+
(torch.randn(1, 2, 8, 8),),
101+
flow,
102+
)
103+
104+
# Different height and width
105+
self._test_op(
106+
ModelWithSize(size=(16, 8), align_corners=False),
107+
(torch.randn(1, 2, 8, 4),),
108+
flow,
109+
)
110+
self._test_op(
111+
ModelWithSize(size=(16, 8), align_corners=True),
112+
(torch.randn(1, 2, 8, 4),),
113+
flow,
114+
)
115+
116+
# Asymmetric upsampling
117+
self._test_op(
118+
ModelWithSize(size=(20, 10), align_corners=False),
119+
(torch.randn(1, 2, 5, 5),),
120+
flow,
121+
)
122+
self._test_op(
123+
ModelWithSize(size=(20, 10), align_corners=True),
124+
(torch.randn(1, 2, 5, 5),),
125+
flow,
126+
)
127+
128+
def test_upsample_bilinear2d_scale_factors(self, flow: TestFlow) -> None:
129+
# Test with different scale factors
130+
131+
# Scale by 2
132+
self._test_op(
133+
ModelWithScale(scale_factor=2.0, align_corners=False),
134+
(torch.randn(1, 2, 5, 5),),
135+
flow,
136+
)
137+
self._test_op(
138+
ModelWithScale(scale_factor=2.0, align_corners=True),
139+
(torch.randn(1, 2, 5, 5),),
140+
flow,
141+
)
142+
143+
# Scale by 3
144+
self._test_op(
145+
ModelWithScale(scale_factor=3.0, align_corners=False),
146+
(torch.randn(1, 2, 5, 5),),
147+
flow,
148+
)
149+
self._test_op(
150+
ModelWithScale(scale_factor=3.0, align_corners=True),
151+
(torch.randn(1, 2, 5, 5),),
152+
flow,
153+
)
154+
155+
# Scale by 1.5
156+
self._test_op(
157+
ModelWithScale(scale_factor=1.5, align_corners=False),
158+
(torch.randn(1, 2, 6, 6),),
159+
flow,
160+
)
161+
self._test_op(
162+
ModelWithScale(scale_factor=1.5, align_corners=True),
163+
(torch.randn(1, 2, 6, 6),),
164+
flow,
165+
)
166+
167+
# Different scales for height and width
168+
self._test_op(
169+
ModelWithScale(scale_factor=(2.0, 1.5), align_corners=False),
170+
(torch.randn(1, 2, 5, 6),),
171+
flow,
172+
generate_random_test_inputs=False,
173+
)
174+
self._test_op(
175+
ModelWithScale(scale_factor=(2.0, 1.5), align_corners=True),
176+
(torch.randn(1, 2, 5, 6),),
177+
flow,
178+
generate_random_test_inputs=False,
179+
)
180+
181+
def test_upsample_bilinear2d_batch_sizes(self, flow: TestFlow) -> None:
182+
# Test with different batch sizes
183+
self._test_op(
184+
ModelWithSize(size=(10, 10), align_corners=False),
185+
(torch.randn(1, 3, 5, 5),),
186+
flow,
187+
)
188+
self._test_op(
189+
ModelWithSize(size=(10, 10), align_corners=False),
190+
(torch.randn(4, 3, 5, 5),),
191+
flow,
192+
)
193+
self._test_op(
194+
ModelWithSize(size=(10, 10), align_corners=False),
195+
(torch.randn(8, 3, 5, 5),),
196+
flow,
197+
)
198+
199+
def test_upsample_bilinear2d_channels(self, flow: TestFlow) -> None:
200+
# Test with different numbers of channels
201+
self._test_op(
202+
ModelWithSize(size=(10, 10), align_corners=False),
203+
(torch.randn(2, 1, 5, 5),),
204+
flow,
205+
) # Grayscale
206+
self._test_op(
207+
ModelWithSize(size=(10, 10), align_corners=False),
208+
(torch.randn(2, 3, 5, 5),),
209+
flow,
210+
) # RGB
211+
self._test_op(
212+
ModelWithSize(size=(10, 10), align_corners=False),
213+
(torch.randn(2, 4, 5, 5),),
214+
flow,
215+
) # RGBA
216+
self._test_op(
217+
ModelWithSize(size=(10, 10), align_corners=False),
218+
(torch.randn(2, 16, 5, 5),),
219+
flow,
220+
) # Multi-channel
221+
222+
def test_upsample_bilinear2d_same_size(self, flow: TestFlow) -> None:
223+
# Test with output size same as input size (should be identity)
224+
self._test_op(
225+
ModelWithSize(size=(5, 5), align_corners=False),
226+
(torch.randn(2, 3, 5, 5),),
227+
flow,
228+
generate_random_test_inputs=False,
229+
)
230+
self._test_op(
231+
ModelWithSize(size=(5, 5), align_corners=True),
232+
(torch.randn(2, 3, 5, 5),),
233+
flow,
234+
generate_random_test_inputs=False,
235+
)
236+
self._test_op(
237+
ModelWithScale(scale_factor=1.0, align_corners=False),
238+
(torch.randn(2, 3, 5, 5),),
239+
flow,
240+
generate_random_test_inputs=False,
241+
)
242+
self._test_op(
243+
ModelWithScale(scale_factor=1.0, align_corners=True),
244+
(torch.randn(2, 3, 5, 5),),
245+
flow,
246+
generate_random_test_inputs=False,
247+
)
248+
249+
def test_upsample_bilinear2d_downsampling(self, flow: TestFlow) -> None:
250+
# Test downsampling
251+
self._test_op(
252+
ModelWithSize(size=(4, 4), align_corners=False),
253+
(torch.randn(2, 3, 8, 8),),
254+
flow,
255+
)
256+
self._test_op(
257+
ModelWithSize(size=(4, 4), align_corners=True),
258+
(torch.randn(2, 3, 8, 8),),
259+
flow,
260+
)
261+
self._test_op(
262+
ModelWithScale(scale_factor=0.5, align_corners=False),
263+
(torch.randn(2, 3, 8, 8),),
264+
flow,
265+
generate_random_test_inputs=False,
266+
)
267+
self._test_op(
268+
ModelWithScale(scale_factor=0.5, align_corners=True),
269+
(torch.randn(2, 3, 8, 8),),
270+
flow,
271+
generate_random_test_inputs=False,
272+
)
273+
274+
# Test with non-integer downsampling factor
275+
self._test_op(
276+
ModelWithScale(scale_factor=0.75, align_corners=False),
277+
(torch.randn(2, 3, 8, 8),),
278+
flow,
279+
generate_random_test_inputs=False,
280+
)
281+
self._test_op(
282+
ModelWithScale(scale_factor=0.75, align_corners=True),
283+
(torch.randn(2, 3, 8, 8),),
284+
flow,
285+
generate_random_test_inputs=False,
286+
)
287+
288+
def test_upsample_bilinear2d_large_scale(self, flow: TestFlow) -> None:
289+
# Test with large scale factor
290+
self._test_op(
291+
ModelWithScale(scale_factor=4.0, align_corners=False),
292+
(torch.randn(1, 2, 4, 4),),
293+
flow,
294+
generate_random_test_inputs=False,
295+
)
296+
self._test_op(
297+
ModelWithScale(scale_factor=4.0, align_corners=True),
298+
(torch.randn(1, 2, 4, 4),),
299+
flow,
300+
generate_random_test_inputs=False,
301+
)
302+
303+
def test_upsample_bilinear2d_non_square(self, flow: TestFlow) -> None:
304+
# Test with non-square input
305+
self._test_op(
306+
ModelWithSize(size=(10, 20), align_corners=False),
307+
(torch.randn(2, 3, 5, 10),),
308+
flow,
309+
)
310+
self._test_op(
311+
ModelWithSize(size=(10, 20), align_corners=True),
312+
(torch.randn(2, 3, 5, 10),),
313+
flow,
314+
)
315+
self._test_op(
316+
ModelWithScale(scale_factor=2.0, align_corners=False),
317+
(torch.randn(2, 3, 5, 10),),
318+
flow,
319+
)
320+
self._test_op(
321+
ModelWithScale(scale_factor=2.0, align_corners=True),
322+
(torch.randn(2, 3, 5, 10),),
323+
flow,
324+
)
325+
326+
def test_upsample_bilinear2d_odd_sizes(self, flow: TestFlow) -> None:
327+
# Test with odd input and output sizes (where interpolation behavior might be more noticeable)
328+
self._test_op(
329+
ModelWithSize(size=(9, 9), align_corners=False),
330+
(torch.randn(2, 3, 5, 5),),
331+
flow,
332+
)
333+
self._test_op(
334+
ModelWithSize(size=(9, 9), align_corners=True),
335+
(torch.randn(2, 3, 5, 5),),
336+
flow,
337+
)
338+
self._test_op(
339+
ModelWithSize(size=(7, 7), align_corners=False),
340+
(torch.randn(2, 3, 3, 3),),
341+
flow,
342+
)
343+
self._test_op(
344+
ModelWithSize(size=(7, 7), align_corners=True),
345+
(torch.randn(2, 3, 3, 3),),
346+
flow,
347+
)
348+
self._test_op(
349+
ModelWithScale(scale_factor=1.5, align_corners=False),
350+
(torch.randn(2, 3, 5, 5),),
351+
flow,
352+
generate_random_test_inputs=False,
353+
)
354+
self._test_op(
355+
ModelWithScale(scale_factor=1.5, align_corners=True),
356+
(torch.randn(2, 3, 5, 5),),
357+
flow,
358+
generate_random_test_inputs=False,
359+
)

0 commit comments

Comments
 (0)