
Commit 6a0035d

begin work on finalizing the crop PR to include five and ten crop, adhere to new PR reviews for flip

1 parent 9b721ef

5 files changed: +160 -17 lines

test/common_utils.py
Lines changed: 11 additions & 0 deletions

@@ -276,6 +276,17 @@ def combinations_grid(**kwargs):
     return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


+def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor:
+    tensor = cvcuda_to_tensor(tensor)
+    if tensor.ndim != 4:
+        raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.")
+    if tensor.shape[0] != 1:
+        raise ValueError(
+            f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
+        )
+    return tensor.squeeze(0).cpu()
+
+
 class ImagePair(TensorLikePair):
     def __init__(
         self,
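For context, this helper is used by the CV-CUDA correctness tests below to build a PIL-comparable reference input. A minimal usage sketch (assumes CV-CUDA is installed; make_image_cvcuda is the existing test helper that produces a batched cvcuda.Tensor, and the sizes here are illustrative):

    image = make_image_cvcuda((32, 32))                  # batched cvcuda.Tensor with batch size 1
    reference = cvcuda_to_pil_compatible_tensor(image)   # torch.Tensor on CPU with the batch dim squeezed
    expected = F.to_image(F.center_crop(F.to_pil_image(reference), output_size=[16, 16]))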

test/test_transforms_v2.py
Lines changed: 64 additions & 14 deletions

@@ -3677,10 +3677,7 @@ def test_transform_image_correctness(self, param, value, seed, make_input):
         torch.manual_seed(seed)

         if make_input == make_image_cvcuda:
-            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
-            actual = actual.squeeze(0)
-            image = F.cvcuda_to_tensor(image).to(device="cpu")
-            image = image.squeeze(0)
+            image = cvcuda_to_pil_compatible_tensor(image)

         expected = F.to_image(transform(F.to_pil_image(image)))

@@ -5048,10 +5045,7 @@ def test_image_correctness(self, output_size, make_input, fn):
         actual = fn(image, output_size)

         if make_input == make_image_cvcuda:
-            actual = F.cvcuda_to_tensor(actual).to(device="cpu")
-            actual = actual.squeeze(0)
-            image = F.cvcuda_to_tensor(image).to(device="cpu")
-            image = image.squeeze(0)
+            image = cvcuda_to_pil_compatible_tensor(image)

         expected = F.to_image(F.center_crop(F.to_pil_image(image), output_size=output_size))

@@ -6327,7 +6321,15 @@ def wrapper(*args, **kwargs):

     @pytest.mark.parametrize(
         "make_input",
-        [make_image_tensor, make_image_pil, make_image, make_video],
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            make_video,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
     )
     @pytest.mark.parametrize("functional", [F.five_crop, F.ten_crop])
     def test_functional(self, make_input, functional):

@@ -6345,13 +6347,27 @@ def test_functional(self, make_input, functional):
             (F.five_crop, F._geometry._five_crop_image_pil, PIL.Image.Image),
             (F.five_crop, F.five_crop_image, tv_tensors.Image),
             (F.five_crop, F.five_crop_video, tv_tensors.Video),
+            pytest.param(
+                F.five_crop,
+                F._geometry._five_crop_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
             (F.ten_crop, F.ten_crop_image, torch.Tensor),
             (F.ten_crop, F._geometry._ten_crop_image_pil, PIL.Image.Image),
             (F.ten_crop, F.ten_crop_image, tv_tensors.Image),
             (F.ten_crop, F.ten_crop_video, tv_tensors.Video),
+            pytest.param(
+                F.ten_crop,
+                F._geometry._ten_crop_cvcuda,
+                "cvcuda.Tensor",
+                marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
+            ),
         ],
     )
     def test_functional_signature(self, functional, kernel, input_type):
+        if input_type == "cvcuda.Tensor":
+            input_type = _import_cvcuda().Tensor
         check_functional_kernel_signature_match(functional, kernel=kernel, input_type=input_type)

     class _TransformWrapper(nn.Module):

@@ -6373,7 +6389,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

     @pytest.mark.parametrize(
         "make_input",
-        [make_image_tensor, make_image_pil, make_image, make_video],
+        [
+            make_image_tensor,
+            make_image_pil,
+            make_image,
+            make_video,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
     )
     @pytest.mark.parametrize("transform_cls", [transforms.FiveCrop, transforms.TenCrop])
     def test_transform(self, make_input, transform_cls):

@@ -6391,29 +6415,55 @@ def test_transform_error(self, make_input, transform_cls):
         with pytest.raises(TypeError, match="not supported"):
             transform(make_input(self.INPUT_SIZE))

+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn", [F.five_crop, transform_cls_to_functional(transforms.FiveCrop)])
-    def test_correctness_image_five_crop(self, fn):
-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
+    def test_correctness_image_five_crop(self, make_input, fn):
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

         actual = fn(image, size=self.OUTPUT_SIZE)
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
         expected = F.five_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE)

         assert isinstance(actual, tuple)
         assert_equal(actual, [F.to_image(e) for e in expected])

+    @pytest.mark.parametrize(
+        "make_input",
+        [
+            make_image,
+            pytest.param(
+                make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
+            ),
+        ],
+    )
     @pytest.mark.parametrize("fn_or_class", [F.ten_crop, transforms.TenCrop])
     @pytest.mark.parametrize("vertical_flip", [False, True])
-    def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip):
+    def test_correctness_image_ten_crop(self, make_input, fn_or_class, vertical_flip):
         if fn_or_class is transforms.TenCrop:
             fn = transform_cls_to_functional(fn_or_class, size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)
             kwargs = dict()
         else:
             fn = fn_or_class
             kwargs = dict(size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)

-        image = make_image(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")
+        image = make_input(self.INPUT_SIZE, dtype=torch.uint8, device="cpu")

         actual = fn(image, **kwargs)
+
+        if make_input is make_image_cvcuda:
+            image = cvcuda_to_pil_compatible_tensor(image)
+
         expected = F.ten_crop(F.to_pil_image(image), size=self.OUTPUT_SIZE, vertical_flip=vertical_flip)

         assert isinstance(actual, tuple)
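As a reading aid, the CV-CUDA cases added above reuse the existing PIL-reference pattern: run the functional on the CV-CUDA input, then recompute the expectation from a PIL copy of the same pixels. A rough sketch (assumes CV-CUDA is installed; the sizes are illustrative):

    image = make_image_cvcuda((26, 26), dtype=torch.uint8)
    actual = F.five_crop(image, size=(18, 18))            # tuple of five cvcuda.Tensor crops
    reference = cvcuda_to_pil_compatible_tensor(image)
    expected = F.five_crop(F.to_pil_image(reference), size=(18, 18))
    assert isinstance(actual, tuple) and len(actual) == len(expected) == 5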

torchvision/transforms/v2/_geometry.py
Lines changed: 11 additions & 0 deletions

@@ -26,6 +26,7 @@
     get_bounding_boxes,
     has_all,
     has_any,
+    is_cvcuda_tensor,
     is_pure_tensor,
     query_size,
 )

@@ -194,6 +195,8 @@ class CenterCrop(Transform):

     _v1_transform_cls = _transforms.CenterCrop

+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(self, size: Union[int, Sequence[int]]):
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

@@ -360,6 +363,8 @@ class FiveCrop(Transform):

     _v1_transform_cls = _transforms.FiveCrop

+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(self, size: Union[int, Sequence[int]]) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

@@ -404,6 +409,8 @@ class TenCrop(Transform):

     _v1_transform_cls = _transforms.TenCrop

+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

@@ -811,6 +818,8 @@ class RandomCrop(Transform):

     _v1_transform_cls = _transforms.RandomCrop

+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
     def _extract_params_for_v1_transform(self) -> dict[str, Any]:
         params = super()._extract_params_for_v1_transform()


@@ -1121,6 +1130,8 @@ class RandomIoUCrop(Transform):
            Default, 40.
    """

+    _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)
+
    def __init__(
        self,
        min_scale: float = 0.3,
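These additions are the per-transform opt-in: each crop transform extends the class-level _transformed_types tuple with the is_cvcuda_tensor predicate, so CV-CUDA tensors are transformed rather than passed through. A minimal sketch of the pattern (the class name here is hypothetical):

    class MyCvcudaAwareCrop(Transform):
        # Type entries are matched with isinstance(); callables such as is_cvcuda_tensor are called on the input.
        _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)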

torchvision/transforms/v2/_transform.py
Lines changed: 3 additions & 3 deletions

@@ -8,10 +8,10 @@
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import tv_tensors
-from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
+from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor
 from torchvision.utils import _log_api_usage_once

-from .functional._utils import _get_kernel, is_cvcuda_tensor
+from .functional._utils import _get_kernel


 class Transform(nn.Module):

@@ -23,7 +23,7 @@ class Transform(nn.Module):

     # Class attribute defining transformed types. Other types are passed-through without any transformation
     # We support both Types and callables that are able to do further checks on the type of the input.
-    _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)
+    _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)

     def __init__(self) -> None:
         super().__init__()
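With this change the Transform base class no longer treats CV-CUDA tensors as transformable by default; transforms opt in per class as in _geometry.py above. For intuition, a minimal sketch of how a mixed tuple of types and predicate callables can be matched against an input (the real logic lives in torchvision.transforms.v2._utils.check_type and may differ in detail; the function name below is made up for illustration):

    def matches_transformed_types(obj, types_or_checks):
        # Each entry is either a type (checked with isinstance) or a callable predicate (called on the input).
        for entry in types_or_checks:
            if isinstance(entry, type):
                if isinstance(obj, entry):
                    return True
            elif entry(obj):
                return True
        return False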

torchvision/transforms/v2/functional/_geometry.py
Lines changed: 71 additions & 0 deletions

@@ -147,6 +147,14 @@ def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
     return horizontal_flip_image(video)


+def _horizontal_flip_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
+    return _import_cvcuda().flip(image, flipCode=1)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_cvcuda)
+
+
 def vertical_flip(inpt: torch.Tensor) -> torch.Tensor:
     """See :class:`~torchvision.transforms.v2.RandomVerticalFlip` for details."""
     if torch.jit.is_scripting():
@@ -243,6 +251,14 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
     return vertical_flip_image(video)


+def _vertical_flip_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
+    return _import_cvcuda().flip(image, flipCode=0)
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_cvcuda)
+
+
 # We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
 # prevalent and well understood. Thus, we just alias them without deprecating the old names.
 hflip = horizontal_flip
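Per the commit message, the flip kernels were adjusted after PR review: they are plain module-level functions registered once at import time, guarded by CVCUDA_AVAILABLE. A dispatch sketch (assumes CV-CUDA is installed; make_image_cvcuda is the test helper referenced above, and the size is illustrative):

    image = make_image_cvcuda((8, 8))        # cvcuda.Tensor
    flipped_h = F.horizontal_flip(image)     # routed to _horizontal_flip_cvcuda by the registered kernel
    flipped_v = F.vertical_flip(image)       # routed to _vertical_flip_cvcuda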
@@ -3016,6 +3032,29 @@ def five_crop_video(
     return five_crop_image(video, size)


+def _five_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    size: list[int],
+) -> tuple["cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor", "cvcuda.Tensor"]:
+    crop_height, crop_width = _parse_five_crop_size(size)
+    image_height, image_width = image.shape[-2:]
+
+    if crop_width > image_width or crop_height > image_height:
+        raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}")
+
+    tl = _crop_cvcuda(image, 0, 0, crop_height, crop_width)
+    tr = _crop_cvcuda(image, 0, image_width - crop_width, crop_height, crop_width)
+    bl = _crop_cvcuda(image, image_height - crop_height, 0, crop_height, crop_width)
+    br = _crop_cvcuda(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = _center_crop_cvcuda(image, [crop_height, crop_width])
+
+    return tl, tr, bl, br, center
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(five_crop, _import_cvcuda().Tensor)(_five_crop_cvcuda)
+
+
 def ten_crop(
     inpt: torch.Tensor, size: list[int], vertical_flip: bool = False
 ) -> tuple[
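Assuming _crop_cvcuda follows the same (top, left, height, width) convention as the other crop kernels, the corner anchors work out as in this worked example for a 100x100 image and size=(60, 40):

    crop_height, crop_width = 60, 40
    image_height, image_width = 100, 100
    anchors = {
        "tl": (0, 0),
        "tr": (0, image_width - crop_width),                           # (0, 60)
        "bl": (image_height - crop_height, 0),                         # (40, 0)
        "br": (image_height - crop_height, image_width - crop_width),  # (40, 60)
    }
    # Each corner crop is crop_height x crop_width; the fifth crop is the centered one.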
@@ -3111,3 +3150,35 @@ def ten_crop_video(
     torch.Tensor,
 ]:
     return ten_crop_image(video, size, vertical_flip=vertical_flip)
+
+
+def _ten_crop_cvcuda(
+    image: "cvcuda.Tensor",
+    size: list[int],
+    vertical_flip: bool = False,
+) -> tuple[
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+    "cvcuda.Tensor",
+]:
+    non_flipped = _five_crop_cvcuda(image, size)
+
+    if vertical_flip:
+        image = _vertical_flip_cvcuda(image)
+    else:
+        image = _horizontal_flip_cvcuda(image)
+
+    flipped = _five_crop_cvcuda(image, size)
+
+    return non_flipped + flipped
+
+
+if CVCUDA_AVAILABLE:
+    _register_kernel_internal(ten_crop, _import_cvcuda().Tensor)(_ten_crop_cvcuda)
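End to end, the registration above means the public F.ten_crop entry point accepts CV-CUDA tensors directly. A usage sketch (assumes CV-CUDA is installed; make_image_cvcuda is the test helper, and the sizes are illustrative):

    image = make_image_cvcuda((224, 224))
    crops = F.ten_crop(image, size=[112, 112])   # vertical_flip defaults to False
    assert len(crops) == 10                      # five crops of the original plus five of the flipped image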
