From aac1684b4907782919a77eff7534025ef095f2a6 Mon Sep 17 00:00:00 2001
From: Luca Wehrstedt
Date: Thu, 11 Dec 2025 17:22:57 +0100
Subject: [PATCH 1/2] Avoid torch.compile graph-break in functional normalize fn

The data-dependent check on the standard deviation (which raises if any of
its elements is zero) caused a graph break when using torch.compile.
Instead, this check can be replaced by an in-graph assert, which can even
become an on-device assert in order to support CUDA graphs.
---
 torchvision/transforms/_functional_tensor.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py
index 71409c40af3..8ef402ea62b 100644
--- a/torchvision/transforms/_functional_tensor.py
+++ b/torchvision/transforms/_functional_tensor.py
@@ -919,8 +919,9 @@ def normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool
     dtype = tensor.dtype
     mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
     std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
-    if (std == 0).any():
-        raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
+    std = torch.ops.aten._functional_assert_async.msg(
+        (std == 0).any(), f"std evaluated to zero after conversion to {dtype}, leading to division by zero.", std
+    )
     if mean.ndim == 1:
         mean = mean.view(-1, 1, 1)
     if std.ndim == 1:

From 89c1fd6cd44d48ddf4da4b69a401a45bd477c6b8 Mon Sep 17 00:00:00 2001
From: Luca Wehrstedt
Date: Thu, 11 Dec 2025 17:23:23 +0000
Subject: [PATCH 2/2] Update

---
 torchvision/transforms/_functional_tensor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py
index 8ef402ea62b..c0d31a244d6 100644
--- a/torchvision/transforms/_functional_tensor.py
+++ b/torchvision/transforms/_functional_tensor.py
@@ -920,7 +920,7 @@ def normalize(tensor: Tensor, mean: list[float], std: list[float], inplace: bool
     mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
     std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
     std = torch.ops.aten._functional_assert_async.msg(
-        (std == 0).any(), f"std evaluated to zero after conversion to {dtype}, leading to division by zero.", std
+        (std != 0).all(), f"std evaluated to zero after conversion to {dtype}, leading to division by zero.", std
     )
     if mean.ndim == 1:
         mean = mean.view(-1, 1, 1)
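
For context, below is a minimal, self-contained sketch (not part of the patch, and not the
torchvision code itself) contrasting the data-dependent Python check with an in-graph assert
under torch.compile. The function names are made up for illustration, and the sketch uses the
public torch._assert_async rather than the private aten._functional_assert_async.msg op from
the diff, on the assumption that it plays the same role of keeping the check inside the
compiled graph.

import torch


def normalize_eager_check(x: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # Converting `(std == 0).any()` to a Python bool is data-dependent, so
    # torch.compile has to break the graph here (or fail under fullgraph=True).
    if (std == 0).any():
        raise ValueError("std contains zeros")
    return x / std


def normalize_in_graph_check(x: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # torch._assert_async records the check as an operator in the graph and
    # raises if the condition tensor is false; on CUDA the check runs on-device.
    torch._assert_async((std != 0).all())
    return x / std


x = torch.randn(3, 4)
std = torch.tensor([1.0, 2.0, 0.5, 4.0])

compiled = torch.compile(normalize_in_graph_check, fullgraph=True)
print(compiled(x, std).shape)  # captured as a single graph, no break

# torch.compile(normalize_eager_check, fullgraph=True)(x, std) would instead
# error out, because the `if` on a tensor value cannot stay inside one graph.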