diff --git a/monai/data/utils.py b/monai/data/utils.py index ca7d5c9d9e..4e5a3bd7f6 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -597,11 +597,12 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): type(batch).__module__ == "numpy" and not isinstance(batch, Iterable) ): return batch + # if scalar tensor/array, return the item itself. + if getattr(batch, "ndim", -1) == 0 and hasattr(batch, "item"): + return batch.item() if detach else batch if isinstance(batch, torch.Tensor): if detach: batch = batch.detach() - if batch.ndim == 0: - return batch.item() if detach else batch out_list = torch.unbind(batch, dim=0) # if of type MetaObj, decollate the metadata if isinstance(batch, MetaObj):