Commit ed2bd35

fix: crop testing, adhere to conventions
1 parent 2219ee5 commit ed2bd35

File tree

5 files changed: +51 -31 lines changed

test/test_transforms_v2.py

Lines changed: 40 additions & 23 deletions
@@ -3506,6 +3506,9 @@ def test_kernel_video(self):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_functional(self, make_input):
@@ -3521,6 +3524,11 @@ def test_functional(self, make_input):
             (F.crop_mask, tv_tensors.Mask),
             (F.crop_video, tv_tensors.Video),
             (F.crop_keypoints, tv_tensors.KeyPoints),
+            pytest.param(
+                F._geometry._crop_cvcuda,
+                _import_cvcuda().Tensor,
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
         ],
     )
     def test_functional_signature(self, kernel, input_type):
@@ -3549,15 +3557,18 @@ def test_functional_image_correctness(self, kwargs):
             make_segmentation_mask,
             make_video,
             make_keypoints,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
         ],
     )
     def test_transform(self, param, value, make_input):
-        input = make_input(self.INPUT_SIZE)
+        input_data = make_input(self.INPUT_SIZE)
 
         check_sample_input = True
         if param == "fill":
             if isinstance(value, (tuple, list)):
-                if isinstance(input, tv_tensors.Mask):
+                if isinstance(input_data, tv_tensors.Mask):
                     pytest.skip("F.pad_mask doesn't support non-scalar fill.")
                 else:
                     check_sample_input = False
@@ -3566,14 +3577,14 @@ def test_transform(self, param, value, make_input):
                 # 1. size is required
                 # 2. the fill parameter only has an effect if we need padding
                 size=[s + 4 for s in self.INPUT_SIZE],
-                fill=adapt_fill(value, dtype=input.dtype if isinstance(input, torch.Tensor) else torch.uint8),
+                fill=adapt_fill(value, dtype=input_data.dtype if isinstance(input_data, torch.Tensor) else torch.uint8),
             )
         else:
             kwargs = {param: value}
 
         check_transform(
             transforms.RandomCrop(**kwargs, pad_if_needed=True),
-            input,
+            input_data,
             check_v1_compatibility=param != "fill" or isinstance(value, (int, float)),
             check_sample_input=check_sample_input,
         )
@@ -3637,6 +3648,31 @@ def test_transform_image_correctness(self, param, value, seed):
 
         assert_equal(actual, expected)
 
+    @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+    @pytest.mark.parametrize("size", [(10, 5), (25, 15), (25, 5), (10, 15), (10, 10)])
+    @pytest.mark.parametrize("seed", list(range(5)))
+    def test_transform_cvcuda_correctness(self, size, seed):
+        pad_if_needed = False
+        if size[0] > self.INPUT_SIZE[0] or size[1] > self.INPUT_SIZE[1]:
+            pad_if_needed = True
+        transform = transforms.RandomCrop(size, pad_if_needed=pad_if_needed)
+
+        image = make_image(size=self.INPUT_SIZE, batch_dims=(1,), device="cuda")
+        cv_image = F.to_cvcuda_tensor(image)
+
+        with freeze_rng_state():
+            torch.manual_seed(seed)
+            actual = transform(cv_image)
+
+            torch.manual_seed(seed)
+            expected = transform(image)
+
+        if not pad_if_needed:
+            torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=0, atol=0)
+        else:
+            # if padding is required, CV-CUDA will always fill with zeros
+            torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=0, atol=get_max_value(image.dtype))
+
     def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width):
         affine_matrix = np.array(
             [
@@ -3765,25 +3801,6 @@ def test_errors(self):
             transforms.RandomCrop([10, 12], padding=1, padding_mode="abc")
 
 
-@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="cvcuda not available")
-@needs_cuda
-class TestCropCVCUDA:
-    def test_functional(self):
-        check_functional(
-            F.crop, make_image_cvcuda(TestCrop.INPUT_SIZE, batch_dims=(1,)), **TestCrop.MINIMAL_CROP_KWARGS
-        )
-
-    def test_functional_signature(self):
-        check_functional_kernel_signature_match(F.crop, kernel=F.crop_cvcuda, input_type=cvcuda.Tensor)
-
-    @pytest.mark.parametrize("size", [(10, 5), (25, 15), (25, 5), (10, 15)])
-    def test_functional_correctness(self, size):
-        image = make_image_cvcuda(TestCrop.INPUT_SIZE, batch_dims=(1,))
-        actual = F.crop(image, 0, 0, *size)
-        expected = F.crop(F.cvcuda_to_tensor(image), 0, 0, *size)
-        assert_equal(F.cvcuda_to_tensor(actual), expected)
-
-
 class TestErase:
     INPUT_SIZE = (17, 11)
     FUNCTIONAL_KWARGS = dict(
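
A note on the new test_transform_cvcuda_correctness above: re-seeding the RNG before each call makes transforms.RandomCrop sample the same top/left offsets for the CV-CUDA input and the plain image, so the two outputs are comparable element-wise. A minimal sketch of that seeding pattern, using plain tensors so it runs without CV-CUDA:

import torch
from torchvision.transforms import v2 as transforms

# Same seed before each call -> RandomCrop draws identical crop parameters,
# so both calls crop the same region of the input.
image = torch.randint(0, 256, (1, 3, 17, 11), dtype=torch.uint8)
transform = transforms.RandomCrop((10, 5))

torch.manual_seed(0)
first = transform(image)
torch.manual_seed(0)
second = transform(image)
assert torch.equal(first, second)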

torchvision/transforms/v2/_transform.py

Lines changed: 5 additions & 3 deletions
@@ -11,7 +11,7 @@
 from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
 from torchvision.utils import _log_api_usage_once
 
-from .functional._utils import _get_kernel
+from .functional._utils import _get_kernel, is_cvcuda_tensor
 
 
 class Transform(nn.Module):
@@ -23,7 +23,7 @@ class Transform(nn.Module):
 
     # Class attribute defining transformed types. Other types are passed-through without any transformation
     # We support both Types and callables that are able to do further checks on the type of the input.
-    _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
+    _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)
 
     def __init__(self) -> None:
         super().__init__()
@@ -90,7 +90,9 @@ def _needs_transform_list(self, flat_inputs: list[Any]) -> list[bool]:
         # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
 
         needs_transform_list = []
-        transform_pure_tensor = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)
+        transform_pure_tensor = not has_any(
+            flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image, is_cvcuda_tensor
+        )
         for inpt in flat_inputs:
             needs_transform = True
 
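Because _transformed_types accepts callables as well as plain types, is_cvcuda_tensor can act as the membership check for CV-CUDA inputs without importing cvcuda at module load. Its body is not part of this diff; a hypothetical sketch of such a predicate:

from typing import Any

def is_cvcuda_tensor(inpt: Any) -> bool:
    # Hypothetical sketch -- the real implementation lives in
    # torchvision/transforms/v2/functional/_utils.py and is not shown in this diff.
    try:
        import cvcuda  # lazy import: torchvision must keep working without CV-CUDA
    except ImportError:
        return False
    return isinstance(inpt, cvcuda.Tensor)
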
torchvision/transforms/v2/functional/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -1,6 +1,6 @@
 from torchvision.transforms import InterpolationMode  # usort: skip
 
-from ._utils import is_pure_tensor, register_kernel  # usort: skip
+from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor  # usort: skip
 
 from ._meta import (
     clamp_bounding_boxes,
@@ -76,14 +76,12 @@
     affine_video,
     center_crop,
     center_crop_bounding_boxes,
-    center_crop_cvcuda,
     center_crop_image,
     center_crop_keypoints,
     center_crop_mask,
     center_crop_video,
     crop,
     crop_bounding_boxes,
-    crop_cvcuda,
     crop_image,
     crop_keypoints,
     crop_mask,
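
Dropping crop_cvcuda and center_crop_cvcuda from the re-exports keeps backend-specific kernels private, in line with the conventions the commit message refers to; callers go through the generic dispatchers instead. A small sketch of the resulting import surface, assuming torchvision is built from this commit:

from torchvision.transforms.v2 import functional as F

assert not hasattr(F, "crop_cvcuda")   # no longer re-exported
assert hasattr(F, "is_cvcuda_tensor")  # newly re-exported from ._utils
assert hasattr(F, "crop")              # generic dispatcher, unchanged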

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 4 additions & 2 deletions
@@ -1924,13 +1924,15 @@ def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int
     return crop_image(video, top, left, height, width)
 
 
-def crop_cvcuda(
+def _crop_cvcuda(
     image: "cvcuda.Tensor",
     top: int,
     left: int,
     height: int,
     width: int,
 ) -> "cvcuda.Tensor":
+    cvcuda = _import_cvcuda()
+
     image_height, image_width, channels = image.shape[1:]
     top_diff = 0
     left_diff = 0
@@ -1963,7 +1965,7 @@ def crop_cvcuda(
 
 
 if CVCUDA_AVAILABLE:
-    _register_kernel_internal(crop, cvcuda.Tensor)(crop_cvcuda)
+    _crop_cvcuda_registered = _register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda)
 
 
 def perspective(
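
With the kernel renamed to _crop_cvcuda and registered through _register_kernel_internal, CV-CUDA inputs are still served by the public F.crop entry point. A usage sketch, assuming a CUDA device and the cvcuda package are available:

import torch
from torchvision.transforms.v2 import functional as F

image = torch.randint(0, 256, (1, 3, 32, 32), dtype=torch.uint8, device="cuda")
cv_image = F.to_cvcuda_tensor(image)  # convert to a cvcuda.Tensor
cropped = F.crop(cv_image, top=2, left=2, height=10, width=10)  # dispatches to _crop_cvcuda
assert F.cvcuda_to_tensor(cropped).shape[-2:] == (10, 10)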

torchvision/transforms/v2/functional/_utils.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 import torch
 from torchvision import tv_tensors
 
+
 _FillType = Union[int, float, Sequence[int], Sequence[float], None]
 _FillTypeJIT = Optional[list[float]]
 