diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index d2585daf63..2d5d452dc0 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -37,6 +37,11 @@ Generic Interfaces .. autoclass:: MultiSampleTrait :members: +`ReduceTrait` +^^^^^^^^^^^^^^^^^^ +.. autoclass:: ReduceTrait + :members: + `Randomizable` ^^^^^^^^^^^^^^ .. autoclass:: Randomizable @@ -1252,6 +1257,12 @@ Utility :members: :special-members: __call__ +`FlattenSequence` +"""""""""""""""""""""""" +.. autoclass:: FlattenSequence + :members: + :special-members: __call__ + Dictionary Transforms --------------------- @@ -2337,6 +2348,12 @@ Utility (Dict) :members: :special-members: __call__ +`FlattenSequenced` +""""""""""""""""""""""""" +.. autoclass:: FlattenSequenced + :members: + :special-members: __call__ + MetaTensor ^^^^^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index d15042181b..0ab9fe63d5 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -506,7 +506,7 @@ ZoomDict, ) from .spatial.functional import spatial_resample -from .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ThreadUnsafe +from .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ReduceTrait, ThreadUnsafe from .transform import LazyTransform, MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform from .utility.array import ( AddCoordinateChannels, @@ -521,6 +521,7 @@ EnsureChannelFirst, EnsureType, FgBgToIndices, + FlattenSequence, Identity, ImageFilter, IntensityStats, @@ -593,6 +594,9 @@ FgBgToIndicesd, FgBgToIndicesD, FgBgToIndicesDict, + FlattenSequenced, + FlattenSequenceD, + FlattenSequenceDict, FlattenSubKeysd, FlattenSubKeysD, FlattenSubKeysDict, diff --git a/monai/transforms/traits.py b/monai/transforms/traits.py index 016effc59d..45d081f2e6 100644 --- a/monai/transforms/traits.py +++ b/monai/transforms/traits.py @@ -14,7 +14,7 @@ from __future__ import annotations -__all__ = ["LazyTrait", "InvertibleTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe"] +__all__ = ["LazyTrait", "InvertibleTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe", "ReduceTrait"] from typing import Any @@ -99,3 +99,14 @@ class ThreadUnsafe: """ pass + + +class ReduceTrait: + """ + An interface to indicate that the transform has the capability to reduce multiple samples + into a single sample. + This interface can be extended from by people adapting transforms to the MONAI framework as well + as by implementors of MONAI transforms. + """ + + pass diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 1a365b8d8e..1eedc7c333 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -25,7 +25,7 @@ from monai import config, transforms from monai.config import KeysCollection from monai.data.meta_tensor import MetaTensor -from monai.transforms.traits import LazyTrait, RandomizableTrait, ThreadUnsafe +from monai.transforms.traits import LazyTrait, RandomizableTrait, ReduceTrait, ThreadUnsafe from monai.utils import MAX_SEED, ensure_tuple, first from monai.utils.enums import TransformBackends from monai.utils.misc import MONAIEnvVars @@ -142,7 +142,7 @@ def apply_transform( """ try: map_items_ = int(map_items) if isinstance(map_items, bool) else map_items - if isinstance(data, (list, tuple)) and map_items_ > 0: + if isinstance(data, (list, tuple)) and map_items_ > 0 and not isinstance(transform, ReduceTrait): return [ apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides) for item in data @@ -482,8 +482,7 @@ def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable yield (key,) + tuple(_ex_iters) if extra_iterables else key elif not self.allow_missing_keys: raise KeyError( - f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data" - " and allow_missing_keys==False." + f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data and allow_missing_keys==False." ) def first_key(self, data: dict[Hashable, Any]): diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 18a0f7f32f..2ac37f2f81 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -43,7 +43,7 @@ median_filter, ) from monai.transforms.inverse import InvertibleTransform, TraceableTransform -from monai.transforms.traits import MultiSampleTrait +from monai.transforms.traits import MultiSampleTrait, ReduceTrait from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform from monai.transforms.utils import ( apply_affine_to_points, @@ -110,6 +110,7 @@ "ImageFilter", "RandImageFilter", "ApplyTransformToPoints", + "FlattenSequence", ] @@ -1950,3 +1951,39 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"]) return data + + +class FlattenSequence(Transform, ReduceTrait): + """ + Flatten a nested sequence (list or tuple) by one level. + If the input is a sequence of sequences, it will flatten them into a single sequence. + Non-nested sequences and other data types are returned unchanged. + + For example: + + .. code-block:: python + + flatten = FlattenSequence() + data = [[1, 2], [3, 4], [5, 6]] + print(flatten(data)) + [1, 2, 3, 4, 5, 6] + + """ + + def __init__(self): + super().__init__() + + def __call__(self, data: list | tuple | Any) -> list | tuple | Any: + """ + Flatten a nested sequence by one level. + Args: + data: Input data, can be a nested sequence. + Returns: + Flattened list if input is a nested sequence, otherwise returns data unchanged. + """ + if isinstance(data, (list, tuple)): + if len(data) == 0: + return data + if all(isinstance(item, (list, tuple)) for item in data): + return [item for sublist in data for item in sublist] + return data diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7dd2397a74..95c59e07bc 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -30,7 +30,7 @@ from monai.data.meta_tensor import MetaObj, MetaTensor from monai.data.utils import no_collation from monai.transforms.inverse import InvertibleTransform -from monai.transforms.traits import MultiSampleTrait, RandomizableTrait +from monai.transforms.traits import MultiSampleTrait, RandomizableTrait, ReduceTrait from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform from monai.transforms.utility.array import ( AddCoordinateChannels, @@ -45,6 +45,7 @@ EnsureChannelFirst, EnsureType, FgBgToIndices, + FlattenSequence, Identity, ImageFilter, IntensityStats, @@ -191,6 +192,9 @@ "ApplyTransformToPointsd", "ApplyTransformToPointsD", "ApplyTransformToPointsDict", + "FlattenSequenced", + "FlattenSequenceD", + "FlattenSequenceDict", ] DEFAULT_POST_FIX = PostFix.meta() @@ -1906,6 +1910,28 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch return d +class FlattenSequenced(MapTransform, ReduceTrait): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.FlattenSequence`. + + Args: + keys: keys of the corresponding items to be transformed. + See also: monai.transforms.MapTransform + allow_missing_keys: + Don't raise exception if key is missing. + """ + + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, **kwargs) -> None: + super().__init__(keys, allow_missing_keys) + self.flatten_sequence = FlattenSequence(**kwargs) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.flatten_sequence(d[key]) # type: ignore[assignment] + return d + + RandImageFilterD = RandImageFilterDict = RandImageFilterd ImageFilterD = ImageFilterDict = ImageFilterd IdentityD = IdentityDict = Identityd @@ -1949,3 +1975,4 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd +FlattenSequenceD = FlattenSequenceDict = FlattenSequenced diff --git a/tests/transforms/compose/test_compose.py b/tests/transforms/compose/test_compose.py index e6727c976f..12547f9ec2 100644 --- a/tests/transforms/compose/test_compose.py +++ b/tests/transforms/compose/test_compose.py @@ -282,6 +282,40 @@ def test_flatten_and_len(self): def test_backwards_compatible_imports(self): from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401 + def test_list_extend_multi_sample_trait(self): + center_crop = mt.CenterSpatialCrop([128, 128]) + multi_sample_transform = mt.RandSpatialCropSamples([64, 64], 1) + flatten_sequence_transform = mt.FlattenSequence() + + img = torch.zeros([1, 512, 512]) + + self.assertEqual(execute_compose(img, [center_crop]).shape, torch.Size([1, 128, 128])) + single_multi_sample_trait_result = execute_compose( + img, [multi_sample_transform, center_crop, flatten_sequence_transform] + ) + self.assertIsInstance(single_multi_sample_trait_result, list) + self.assertEqual(len(single_multi_sample_trait_result), 1) + self.assertEqual(single_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64])) + + double_multi_sample_trait_result = execute_compose( + img, [multi_sample_transform, multi_sample_transform, flatten_sequence_transform, center_crop] + ) + self.assertIsInstance(double_multi_sample_trait_result, list) + self.assertEqual(len(double_multi_sample_trait_result), 1) + self.assertEqual(double_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64])) + + def test_multi_sample_trait_cardinality(self): + img = torch.zeros([1, 128, 128]) + t2 = mt.RandSpatialCropSamples([32, 32], num_samples=2) + flatten_sequence_transform = mt.FlattenSequence() + + # chaining should multiply counts: 2 x 2 = 4, flattened + res = execute_compose(img, [t2, t2, flatten_sequence_transform]) + self.assertIsInstance(res, list) + self.assertEqual(len(res), 4) + for r in res: + self.assertEqual(r.shape, torch.Size([1, 32, 32])) + TEST_COMPOSE_EXECUTE_TEST_CASES = [ [None, tuple()],