Skip to content

Commit 9b721ef

Browse files
committed
simplify test for center crop cvcuda
1 parent 37a91e0 commit 9b721ef

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

test/test_transforms_v2.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5032,25 +5032,30 @@ def test_transform(self, make_input):
50325032
check_transform(transforms.CenterCrop(self.OUTPUT_SIZES[0]), make_input(self.INPUT_SIZE))
50335033

50345034
@pytest.mark.parametrize("output_size", OUTPUT_SIZES)
5035+
@pytest.mark.parametrize(
5036+
"make_input",
5037+
[
5038+
make_image,
5039+
pytest.param(
5040+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5041+
),
5042+
],
5043+
)
50355044
@pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
5036-
def test_image_correctness(self, output_size, fn):
5037-
image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
5045+
def test_image_correctness(self, output_size, make_input, fn):
5046+
image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
50385047

50395048
actual = fn(image, output_size)
5040-
expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))
50415049

5042-
assert_equal(actual, expected)
5043-
5044-
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5045-
@pytest.mark.parametrize("output_size", OUTPUT_SIZES)
5046-
@pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
5047-
def test_cvcuda_correctness(self, output_size, fn):
5048-
image = make_image_cvcuda(self.INPUT_SIZE, dtype=torch.uint8, device="cuda")
5050+
if make_input == make_image_cvcuda:
5051+
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
5052+
actual = actual.squeeze(0)
5053+
image = F.cvcuda_to_tensor(image).to(device="cpu")
5054+
image = image.squeeze(0)
50495055

5050-
actual = fn(image, output_size)
5051-
expected = F.center_crop(F.cvcuda_to_tensor(image), output_size)
5056+
expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))
50525057

5053-
assert_equal(F.cvcuda_to_tensor(actual), expected)
5058+
assert_equal(actual, expected)
50545059

50555060
def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
50565061
image_height, image_width = bounding_boxes.canvas_size

0 commit comments

Comments
 (0)