Skip to content

Commit 2b7a99e

Browse files
committed
Merge branch 'feat/adjust_hue_cvcuda' into feat/brightness_contrast_cvcuda
2 parents 5e07d08 + 86881a4 commit 2b7a99e

File tree

2 files changed

+62
-3
lines changed

2 files changed

+62
-3
lines changed

test/test_transforms_v2.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6149,7 +6149,18 @@ def test_kernel_image(self, dtype, device):
61496149
def test_kernel_video(self):
61506150
check_kernel(F.adjust_hue_video, make_video(), hue_factor=0.25)
61516151

6152-
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
6152+
@pytest.mark.parametrize(
6153+
"make_input",
6154+
[
6155+
make_image_tensor,
6156+
make_image,
6157+
make_image_pil,
6158+
make_video,
6159+
pytest.param(
6160+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
6161+
),
6162+
],
6163+
)
61536164
def test_functional(self, make_input):
61546165
check_functional(F.adjust_hue, make_input(), hue_factor=0.25)
61556166

@@ -6160,9 +6171,16 @@ def test_functional(self, make_input):
61606171
(F._color._adjust_hue_image_pil, PIL.Image.Image),
61616172
(F.adjust_hue_image, tv_tensors.Image),
61626173
(F.adjust_hue_video, tv_tensors.Video),
6174+
pytest.param(
6175+
F._color._adjust_hue_cvcuda,
6176+
"cvcuda.Tensor",
6177+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available"),
6178+
),
61636179
],
61646180
)
61656181
def test_functional_signature(self, kernel, input_type):
6182+
if input_type == "cvcuda.Tensor":
6183+
input_type = _import_cvcuda().Tensor
61666184
check_functional_kernel_signature_match(F.adjust_hue, kernel=kernel, input_type=input_type)
61676185

61686186
def test_functional_error(self):
@@ -6173,11 +6191,27 @@ def test_functional_error(self):
61736191
with pytest.raises(ValueError, match=re.escape("is not in [-0.5, 0.5]")):
61746192
F.adjust_hue(make_image(), hue_factor=hue_factor)
61756193

6194+
@pytest.mark.parametrize(
6195+
"make_input",
6196+
[
6197+
make_image,
6198+
pytest.param(
6199+
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA not available")
6200+
),
6201+
],
6202+
)
61766203
@pytest.mark.parametrize("hue_factor", [-0.5, -0.3, 0.0, 0.2, 0.5])
6177-
def test_correctness_image(self, hue_factor):
6178-
image = make_image(dtype=torch.uint8, device="cpu")
6204+
def test_correctness_image(self, make_input, hue_factor):
6205+
image = make_input(dtype=torch.uint8, device="cpu")
61796206

61806207
actual = F.adjust_hue(image, hue_factor=hue_factor)
6208+
6209+
if make_input is make_image_cvcuda:
6210+
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
6211+
actual = actual.squeeze(0)
6212+
image = F.cvcuda_to_tensor(image)
6213+
image = image.squeeze(0)
6214+
61816215
expected = F.to_image(F.adjust_hue(F.to_pil_image(image), hue_factor=hue_factor))
61826216

61836217
mae = (actual.float() - expected.float()).abs().mean()

torchvision/transforms/v2/functional/_color.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,31 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
457457
return adjust_hue_image(video, hue_factor=hue_factor)
458458

459459

460+
def _adjust_hue_cvcuda(image: "cvcuda.Tensor", hue_factor: float) -> "cvcuda.Tensor":
461+
cvcuda = _import_cvcuda()
462+
463+
if not (-0.5 <= hue_factor <= 0.5):
464+
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
465+
466+
c = image.shape[3]
467+
if c not in [1, 3, 4]:
468+
raise TypeError(f"Input image tensor permitted channel values are 1, 3, or 4, but found {c}")
469+
470+
if c == 1: # Match PIL behaviour
471+
return image
472+
473+
# no native adjust_hue, use CV-CUDA for color converison, use torch for elementwise operations
474+
hsv = cvcuda.cvtcolor(image, cvcuda.ColorConversion.RGB2HSV)
475+
hsv_torch = torch.as_tensor(hsv.cuda()).float()
476+
hsv_torch[..., 0] = (hsv_torch[..., 0] + hue_factor * 180) % 180
477+
hsv_modified = cvcuda.as_tensor(hsv_torch.to(torch.uint8), "NHWC")
478+
return cvcuda.cvtcolor(hsv_modified, cvcuda.ColorConversion.HSV2RGB)
479+
480+
481+
if CVCUDA_AVAILABLE:
482+
_register_kernel_internal(adjust_hue, _import_cvcuda().Tensor)(_adjust_hue_cvcuda)
483+
484+
460485
def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
461486
"""Adjust gamma."""
462487
if torch.jit.is_scripting():

0 commit comments

Comments
 (0)