Skip to content

Commit 3582c58

Browse files
committed
Fix: update center crop
1 parent ed2bd35 commit 3582c58

File tree

2 files changed

+26
-29
lines changed

2 files changed

+26
-29
lines changed

test/test_transforms_v2.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4965,6 +4965,9 @@ def test_kernel_video(self):
49654965
make_segmentation_mask,
49664966
make_video,
49674967
make_keypoints,
4968+
pytest.param(
4969+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
4970+
),
49684971
],
49694972
)
49704973
def test_functional(self, make_input):
@@ -4980,6 +4983,11 @@ def test_functional(self, make_input):
49804983
(F.center_crop_mask, tv_tensors.Mask),
49814984
(F.center_crop_video, tv_tensors.Video),
49824985
(F.center_crop_keypoints, tv_tensors.KeyPoints),
4986+
pytest.param(
4987+
F._geometry._center_crop_cvcuda,
4988+
_import_cvcuda().Tensor,
4989+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
4990+
),
49834991
],
49844992
)
49854993
def test_functional_signature(self, kernel, input_type):
@@ -4995,6 +5003,9 @@ def test_functional_signature(self, kernel, input_type):
49955003
make_segmentation_mask,
49965004
make_video,
49975005
make_keypoints,
5006+
pytest.param(
5007+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5008+
),
49985009
],
49995010
)
50005011
def test_transform(self, make_input):
@@ -5010,6 +5021,17 @@ def test_image_correctness(self, output_size, fn):
50105021

50115022
assert_equal(actual, expected)
50125023

5024+
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
5025+
@pytest.mark.parametrize("output_size", OUTPUT_SIZES)
5026+
@pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
5027+
def test_cvcuda_correctness(self, output_size, fn):
5028+
image = make_image_cvcuda(self.INPUT_SIZE, dtype=torch.uint8, device="cuda")
5029+
5030+
actual = fn(image, output_size)
5031+
expected = F.center_crop(F.cvcuda_to_tensor(image), output_size)
5032+
5033+
assert_equal(F.cvcuda_to_tensor(actual), expected)
5034+
50135035
def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
50145036
image_height, image_width = bounding_boxes.canvas_size
50155037
if isinstance(output_size, int):
@@ -5081,33 +5103,6 @@ def test_keypoints_correctness(self, output_size, dtype, device, fn):
50815103
assert_equal(actual, expected)
50825104

50835105

5084-
@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda not available")
5085-
@needs_cuda
5086-
class TestCenterCropCVCUDA:
5087-
def test_functional(self):
5088-
check_functional(
5089-
F.center_crop,
5090-
make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,)),
5091-
output_size=TestCenterCrop.OUTPUT_SIZES[0],
5092-
)
5093-
5094-
def test_functional_signature(self):
5095-
check_functional_kernel_signature_match(F.center_crop, kernel=F.center_crop_cvcuda, input_type=cvcuda.Tensor)
5096-
5097-
@pytest.mark.parametrize("output_size", TestCenterCrop.OUTPUT_SIZES)
5098-
def test_functional_correctness(self, output_size):
5099-
image = make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,))
5100-
actual = F.center_crop(image, output_size)
5101-
expected = F.center_crop(F.cvcuda_to_tensor(image), output_size)
5102-
assert_equal(F.cvcuda_to_tensor(actual), expected)
5103-
5104-
def test_transform(self):
5105-
check_transform(
5106-
transforms.CenterCrop(TestCenterCrop.OUTPUT_SIZES[0]),
5107-
make_image_cvcuda(TestCenterCrop.INPUT_SIZE, batch_dims=(1,)),
5108-
)
5109-
5110-
51115106
class TestPerspective:
51125107
COEFFICIENTS = [
51135108
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2718,7 +2718,7 @@ def center_crop_video(video: torch.Tensor, output_size: list[int]) -> torch.Tens
27182718
return center_crop_image(video, output_size)
27192719

27202720

2721-
def center_crop_cvcuda(
2721+
def _center_crop_cvcuda(
27222722
image: "cvcuda.Tensor",
27232723
output_size: list[int],
27242724
) -> "cvcuda.Tensor":
@@ -2754,7 +2754,9 @@ def center_crop_cvcuda(
27542754

27552755

27562756
if CVCUDA_AVAILABLE:
2757-
_register_kernel_internal(center_crop, cvcuda.Tensor)(center_crop_cvcuda)
2757+
_center_crop_cvcuda_registered = _register_kernel_internal(center_crop, _import_cvcuda().Tensor)(
2758+
_center_crop_cvcuda
2759+
)
27582760

27592761

27602762
def resized_crop(

0 commit comments

Comments
 (0)