Skip to content

Commit 18922e3

Browse files
committed
handle some comments from other prs review
1 parent 3582c58 commit 18922e3

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

test/test_transforms_v2.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3526,12 +3526,14 @@ def test_functional(self, make_input):
35263526
(F.crop_keypoints, tv_tensors.KeyPoints),
35273527
pytest.param(
35283528
F._geometry._crop_cvcuda,
3529-
_import_cvcuda().Tensor,
3529+
"cvcuda.Tensor",
35303530
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
35313531
),
35323532
],
35333533
)
35343534
def test_functional_signature(self, kernel, input_type):
3535+
if input_type == "cvcuda.Tensor":
3536+
input_type = _import_cvcuda().Tensor
35353537
check_functional_kernel_signature_match(F.crop, kernel=kernel, input_type=input_type)
35363538

35373539
@pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
@@ -4985,12 +4987,14 @@ def test_functional(self, make_input):
49854987
(F.center_crop_keypoints, tv_tensors.KeyPoints),
49864988
pytest.param(
49874989
F._geometry._center_crop_cvcuda,
4988-
_import_cvcuda().Tensor,
4990+
"cvcuda.Tensor",
49894991
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
49904992
),
49914993
],
49924994
)
49934995
def test_functional_signature(self, kernel, input_type):
4996+
if input_type == "cvcuda.Tensor":
4997+
input_type = _import_cvcuda().Tensor
49944998
check_functional_kernel_signature_match(F.center_crop, kernel=kernel, input_type=input_type)
49954999

49965000
@pytest.mark.parametrize(

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1965,7 +1965,7 @@ def _crop_cvcuda(
19651965

19661966

19671967
if CVCUDA_AVAILABLE:
1968-
_crop_cvcuda_registered = _register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda)
1968+
_register_kernel_internal(crop, _import_cvcuda().Tensor)(_crop_cvcuda)
19691969

19701970

19711971
def perspective(
@@ -2754,9 +2754,7 @@ def _center_crop_cvcuda(
27542754

27552755

27562756
if CVCUDA_AVAILABLE:
2757-
_center_crop_cvcuda_registered = _register_kernel_internal(center_crop, _import_cvcuda().Tensor)(
2758-
_center_crop_cvcuda
2759-
)
2757+
_register_kernel_internal(center_crop, _import_cvcuda().Tensor)(_center_crop_cvcuda)
27602758

27612759

27622760
def resized_crop(

0 commit comments

Comments
 (0)