Skip to content

Commit 16d94ca

Browse files
committed
wip
1 parent 0c9a9c0 commit 16d94ca

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1523,6 +1523,30 @@ def rotate_video(
15231523
return rotate_image(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
15241524

15251525

1526+
if CVCUDA_AVAILABLE:
1527+
_cvcuda_interp = {
1528+
InterpolationMode.BILINEAR: cvcuda.Interp.LINEAR,
1529+
"bilinear": cvcuda.Interp.LINEAR,
1530+
"linear": cvcuda.Interp.LINEAR,
1531+
2: cvcuda.Interp.LINEAR,
1532+
InterpolationMode.BICUBIC: cvcuda.Interp.CUBIC,
1533+
"bicubic": cvcuda.Interp.CUBIC,
1534+
3: cvcuda.Interp.CUBIC,
1535+
InterpolationMode.NEAREST: cvcuda.Interp.NEAREST,
1536+
"nearest": cvcuda.Interp.NEAREST,
1537+
0: cvcuda.Interp.NEAREST,
1538+
InterpolationMode.BOX: cvcuda.Interp.BOX,
1539+
"box": cvcuda.Interp.BOX,
1540+
4: cvcuda.Interp.BOX,
1541+
InterpolationMode.HAMMING: cvcuda.Interp.HAMMING,
1542+
"hamming": cvcuda.Interp.HAMMING,
1543+
5: cvcuda.Interp.HAMMING,
1544+
InterpolationMode.LANCZOS: cvcuda.Interp.LANCZOS,
1545+
"lanczos": cvcuda.Interp.LANCZOS,
1546+
1: cvcuda.Interp.LANCZOS,
1547+
}
1548+
1549+
15261550
def _rotate_cvcuda(
15271551
inpt: "cvcuda.Tensor",
15281552
angle: float,
@@ -1532,11 +1556,16 @@ def _rotate_cvcuda(
15321556
fill: _FillTypeJIT = None,
15331557
) -> "cvcuda.Tensor":
15341558
cvcuda = _import_cvcuda()
1535-
return cvcuda.rotate(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
15361559

1560+
interp = _cvcuda_interp.get(interpolation)
1561+
if interp is None:
1562+
raise ValueError(f"Interpolation mode {interpolation} is not supported with CV-CUDA")
1563+
1564+
return cvcuda.rotate(inpt, angle_deg=angle, shift=(0.0, 0.0), interpolation=interpolation)
15371565

1538-
if _CVCUDA_AVAILABLE:
1539-
_register_kernel_internal(rotate, _import_cvcuda().Tensor)(rotate_cvcuda)
1566+
1567+
if CVCUDA_AVAILABLE:
1568+
_register_kernel_internal(rotate, _import_cvcuda().Tensor)(_rotate_cvcuda)
15401569

15411570

15421571
def pad(

0 commit comments

Comments
 (0)