Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
44db71c
implement additional cvcuda infra for all branches to avoid duplicate…
justincdavis Nov 25, 2025
e3dd700
update make_image_cvcuda to have default batch dim
justincdavis Nov 25, 2025
c035df1
add stanardized setup to main for easier updating of PRs and branches
justincdavis Dec 2, 2025
98d7dfb
update is_cvcuda_tensor
justincdavis Dec 2, 2025
ddc116d
add cvcuda to pil compatible to transforms by default
justincdavis Dec 2, 2025
e51dc7e
remove cvcuda from transform class
justincdavis Dec 2, 2025
e14e210
merge with main
justincdavis Dec 4, 2025
4939355
resolve more formatting naming
justincdavis Dec 4, 2025
fbea584
update is cvcuda tensor impl
justincdavis Dec 4, 2025
511c169
adjust brightness done and tested
justincdavis Nov 25, 2025
54f3f4a
completed and tested adjust_contrast
justincdavis Nov 26, 2025
b11c38a
update brightness contrast tests plus add comment on mean calc for co…
justincdavis Dec 2, 2025
d379658
complete and tested adjust_hue
justincdavis Nov 26, 2025
310982c
merge brightness contrast and hue adjustment together
justincdavis Dec 2, 2025
e0392a0
wip adjust_saturation
justincdavis Nov 25, 2025
61b237c
adjust saturation complete and tested
justincdavis Nov 25, 2025
2c68fc3
add adjust saturation
justincdavis Dec 2, 2025
8564f0a
update to main standards
justincdavis Dec 4, 2025
aa4c6e7
add colorjitter transform support
justincdavis Dec 13, 2025
9d9515a
add tests
justincdavis Dec 13, 2025
2b7d1f3
Merge remote-tracking branch 'upstream/main' into brightness_contrast…
justincdavis Dec 19, 2025
13beb91
merge with main
justincdavis Dec 19, 2025
f010ff7
drop global import
justincdavis Dec 19, 2025
13f6e7c
fix cvcuda undefined in adjust saturation
justincdavis Dec 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 181 additions & 18 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -42,7 +43,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -2814,7 +2814,19 @@ class TestAdjustBrightness:
def test_kernel(self, kernel, make_input, dtype, device):
check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image_pil,
make_image,
make_video,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional(self, make_input):
check_functional(F.adjust_brightness, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)

Expand All @@ -2825,19 +2837,43 @@ def test_functional(self, make_input):
(F._color._adjust_brightness_image_pil, PIL.Image.Image),
(F.adjust_brightness_image, tv_tensors.Image),
(F.adjust_brightness_video, tv_tensors.Video),
pytest.param(
F._color._adjust_brightness_image_cvcuda,
None,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._color._adjust_brightness_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.needs_cvcuda,
),
],
)
@pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS)
def test_image_correctness(self, brightness_factor):
image = make_image(dtype=torch.uint8, device="cpu")
def test_image_correctness(self, make_input, brightness_factor):
image = make_input(dtype=torch.uint8, device="cpu")

actual = F.adjust_brightness(image, brightness_factor=brightness_factor)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(F.adjust_brightness(F.to_pil_image(image), brightness_factor=brightness_factor))

torch.testing.assert_close(actual, expected)
if make_input is make_image_cvcuda:
assert_close(actual, expected, rtol=0, atol=1)
else:
assert_close(actual, expected)


class TestCutMixMixUp:
Expand Down Expand Up @@ -6045,7 +6081,19 @@ def test_kernel_image(self, dtype, device):
def test_kernel_video(self):
check_kernel(F.adjust_contrast_video, make_video(), contrast_factor=0.5)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_image_pil,
make_video,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional(self, make_input):
check_functional(F.adjust_contrast, make_input(), contrast_factor=0.5)

Expand All @@ -6056,9 +6104,16 @@ def test_functional(self, make_input):
(F._color._adjust_contrast_image_pil, PIL.Image.Image),
(F.adjust_contrast_image, tv_tensors.Image),
(F.adjust_contrast_video, tv_tensors.Video),
pytest.param(
F._color._adjust_contrast_image_cvcuda,
None,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._color._adjust_contrast_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.adjust_contrast, kernel=kernel, input_type=input_type)

def test_functional_error(self):
Expand All @@ -6068,11 +6123,25 @@ def test_functional_error(self):
with pytest.raises(ValueError, match="is not non-negative"):
F.adjust_contrast(make_image(), contrast_factor=-1)

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.needs_cvcuda,
),
],
)
@pytest.mark.parametrize("contrast_factor", [0.1, 0.5, 1.0])
def test_correctness_image(self, contrast_factor):
image = make_image(dtype=torch.uint8, device="cpu")
def test_correctness_image(self, make_input, contrast_factor):
image = make_input(dtype=torch.uint8, device="cpu")

actual = F.adjust_contrast(image, contrast_factor=contrast_factor)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(F.adjust_contrast(F.to_pil_image(image), contrast_factor=contrast_factor))

assert_close(actual, expected, rtol=0, atol=1)
Expand Down Expand Up @@ -6127,7 +6196,19 @@ def test_kernel_image(self, dtype, device):
def test_kernel_video(self):
check_kernel(F.adjust_hue_video, make_video(), hue_factor=0.25)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_image_pil,
make_video,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional(self, make_input):
check_functional(F.adjust_hue, make_input(), hue_factor=0.25)

Expand All @@ -6138,9 +6219,16 @@ def test_functional(self, make_input):
(F._color._adjust_hue_image_pil, PIL.Image.Image),
(F.adjust_hue_image, tv_tensors.Image),
(F.adjust_hue_video, tv_tensors.Video),
pytest.param(
F._color._adjust_hue_image_cvcuda,
None,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._color._adjust_hue_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.adjust_hue, kernel=kernel, input_type=input_type)

def test_functional_error(self):
Expand All @@ -6151,11 +6239,26 @@ def test_functional_error(self):
with pytest.raises(ValueError, match=re.escape("is not in [-0.5, 0.5]")):
F.adjust_hue(make_image(), hue_factor=hue_factor)

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.needs_cvcuda,
),
],
)
@pytest.mark.parametrize("hue_factor", [-0.5, -0.3, 0.0, 0.2, 0.5])
def test_correctness_image(self, hue_factor):
image = make_image(dtype=torch.uint8, device="cpu")
def test_correctness_image(self, make_input, hue_factor):
image = make_input(dtype=torch.uint8, device="cpu")

actual = F.adjust_hue(image, hue_factor=hue_factor)

if make_input is make_image_cvcuda:
actual = F.cvcuda_to_tensor(actual)[0].cpu()
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(F.adjust_hue(F.to_pil_image(image), hue_factor=hue_factor))

mae = (actual.float() - expected.float()).abs().mean()
Expand All @@ -6171,7 +6274,19 @@ def test_kernel_image(self, dtype, device):
def test_kernel_video(self):
check_kernel(F.adjust_saturation_video, make_video(), saturation_factor=0.5)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_image_pil, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_image_pil,
make_video,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional(self, make_input):
check_functional(F.adjust_saturation, make_input(), saturation_factor=0.5)

Expand All @@ -6182,9 +6297,16 @@ def test_functional(self, make_input):
(F._color._adjust_saturation_image_pil, PIL.Image.Image),
(F.adjust_saturation_image, tv_tensors.Image),
(F.adjust_saturation_video, tv_tensors.Video),
pytest.param(
F._color._adjust_saturation_image_cvcuda,
None,
marks=pytest.mark.needs_cvcuda,
),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._color._adjust_saturation_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.adjust_saturation, kernel=kernel, input_type=input_type)

def test_functional_error(self):
Expand All @@ -6194,11 +6316,26 @@ def test_functional_error(self):
with pytest.raises(ValueError, match="is not non-negative"):
F.adjust_saturation(make_image(), saturation_factor=-1)

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.needs_cvcuda,
),
],
)
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize("saturation_factor", [0.1, 0.5, 1.0])
def test_correctness_image(self, saturation_factor):
image = make_image(dtype=torch.uint8, device="cpu")
def test_correctness_image(self, make_input, color_space, saturation_factor):
image = make_input(dtype=torch.uint8, color_space=color_space, device="cpu")

actual = F.adjust_saturation(image, saturation_factor=saturation_factor)

if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()

expected = F.to_image(F.adjust_saturation(F.to_pil_image(image), saturation_factor=saturation_factor))

assert_close(actual, expected, rtol=0, atol=1)
Expand Down Expand Up @@ -6331,7 +6468,16 @@ def test_correctness_image_ten_crop(self, fn_or_class, vertical_flip):
class TestColorJitter:
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_video],
[
make_image_tensor,
make_image_pil,
make_image,
make_video,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.needs_cvcuda,
),
],
)
@pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
@pytest.mark.parametrize("device", cpu_and_cuda())
Expand Down Expand Up @@ -6375,24 +6521,41 @@ def test_transform_error(self):
with pytest.raises(ValueError, match="values should be between"):
transforms.ColorJitter(hue=1)

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.needs_cvcuda,
),
],
)
@pytest.mark.parametrize("brightness", [None, 0.1, (0.2, 0.3)])
@pytest.mark.parametrize("contrast", [None, 0.4, (0.5, 0.6)])
@pytest.mark.parametrize("saturation", [None, 0.7, (0.8, 0.9)])
@pytest.mark.parametrize("hue", [None, 0.3, (-0.1, 0.2)])
def test_transform_correctness(self, brightness, contrast, saturation, hue):
image = make_image(dtype=torch.uint8, device="cpu")
def test_transform_correctness(self, make_input, brightness, contrast, saturation, hue):
image = make_input(dtype=torch.uint8, device="cpu")

transform = transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)

with freeze_rng_state():
torch.manual_seed(0)
actual = transform(image)

if make_input is make_image_cvcuda:
actual = F.cvcuda_to_tensor(actual)[0].cpu()
image = F.cvcuda_to_tensor(image)[0].cpu()

torch.manual_seed(0)
expected = F.to_image(transform(F.to_pil_image(image)))

mae = (actual.float() - expected.float()).abs().mean()
assert mae < 2
mae_threshold = 2
if make_input is make_image_cvcuda:
mae_threshold = 3
assert mae < mae_threshold, f"MAE: {mae}"


class TestRgbToGrayscale:
Expand Down
3 changes: 3 additions & 0 deletions torchvision/transforms/v2/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from torchvision import transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2.functional._utils import _is_cvcuda_tensor

from ._transform import _RandomApplyTransform
from ._utils import query_chw
Expand Down Expand Up @@ -96,6 +97,8 @@ class ColorJitter(Transform):

_v1_transform_cls = _transforms.ColorJitter

_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def _extract_params_for_v1_transform(self) -> dict[str, Any]:
return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()}

Expand Down
Loading