Skip to content

Commit 163e770

Browse files
committed
support decollate for numpy scalars
fix linter Signed-off-by: Arthur Dujardin <arthurdujardin.dev@gmail.com> fix numpy decollate multi arrays Signed-off-by: Arthur Dujardin <arthurdujardin.dev@gmail.com>
1 parent c3a317d commit 163e770

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

monai/data/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,13 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
644644
if out_list[0].ndim == 0 and detach:
645645
return [t.item() for t in out_list]
646646
return list(out_list)
647+
if isinstance(batch, np.ndarray):
648+
if batch.ndim == 0:
649+
return batch.item()
650+
out_list = list(batch)
651+
if out_list[0].ndim == 0 and detach:
652+
return [t.item() for t in out_list]
653+
return out_list
647654

648655
b, non_iterable, deco = _non_zipping_check(batch, detach, pad, fill_value)
649656
if b <= 0: # all non-iterable, single item "batch"? {"image": 1, "label": 1}

0 commit comments

Comments
 (0)