Skip to content

Commit 579cec5

Browse files
lukas-folle-snkeosericspodKumoLiu
authored
added ReduceTrait and FlattenSequence (#8531)
Fixes #8528. ### Description This PR adds the `FlatttenSequence` transform (a flavor of the also newly added `ReduceTrait`) which can flatten a nested data structure by one level. This way, #8528 can be tackled without the need to change the `apply_transform` of `Compose` significantly. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Lukas Folle <lukas.folle@snke.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent d260b78 commit 579cec5

File tree

7 files changed

+137
-8
lines changed

7 files changed

+137
-8
lines changed

docs/source/transforms.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ Generic Interfaces
3737
.. autoclass:: MultiSampleTrait
3838
:members:
3939

40+
`ReduceTrait`
41+
^^^^^^^^^^^^^^^^^^
42+
.. autoclass:: ReduceTrait
43+
:members:
44+
4045
`Randomizable`
4146
^^^^^^^^^^^^^^
4247
.. autoclass:: Randomizable
@@ -1252,6 +1257,12 @@ Utility
12521257
:members:
12531258
:special-members: __call__
12541259

1260+
`FlattenSequence`
1261+
""""""""""""""""""""""""
1262+
.. autoclass:: FlattenSequence
1263+
:members:
1264+
:special-members: __call__
1265+
12551266
Dictionary Transforms
12561267
---------------------
12571268

@@ -2337,6 +2348,12 @@ Utility (Dict)
23372348
:members:
23382349
:special-members: __call__
23392350

2351+
`FlattenSequenced`
2352+
"""""""""""""""""""""""""
2353+
.. autoclass:: FlattenSequenced
2354+
:members:
2355+
:special-members: __call__
2356+
23402357

23412358
MetaTensor
23422359
^^^^^^^^^^

monai/transforms/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@
506506
ZoomDict,
507507
)
508508
from .spatial.functional import spatial_resample
509-
from .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ThreadUnsafe
509+
from .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ReduceTrait, ThreadUnsafe
510510
from .transform import LazyTransform, MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform
511511
from .utility.array import (
512512
AddCoordinateChannels,
@@ -521,6 +521,7 @@
521521
EnsureChannelFirst,
522522
EnsureType,
523523
FgBgToIndices,
524+
FlattenSequence,
524525
Identity,
525526
ImageFilter,
526527
IntensityStats,
@@ -593,6 +594,9 @@
593594
FgBgToIndicesd,
594595
FgBgToIndicesD,
595596
FgBgToIndicesDict,
597+
FlattenSequenced,
598+
FlattenSequenceD,
599+
FlattenSequenceDict,
596600
FlattenSubKeysd,
597601
FlattenSubKeysD,
598602
FlattenSubKeysDict,

monai/transforms/traits.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
__all__ = ["LazyTrait", "InvertibleTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe"]
17+
__all__ = ["LazyTrait", "InvertibleTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe", "ReduceTrait"]
1818

1919
from typing import Any
2020

@@ -99,3 +99,14 @@ class ThreadUnsafe:
9999
"""
100100

101101
pass
102+
103+
104+
class ReduceTrait:
105+
"""
106+
An interface to indicate that the transform has the capability to reduce multiple samples
107+
into a single sample.
108+
This interface can be extended from by people adapting transforms to the MONAI framework as well
109+
as by implementors of MONAI transforms.
110+
"""
111+
112+
pass

monai/transforms/transform.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from monai import config, transforms
2626
from monai.config import KeysCollection
2727
from monai.data.meta_tensor import MetaTensor
28-
from monai.transforms.traits import LazyTrait, RandomizableTrait, ThreadUnsafe
28+
from monai.transforms.traits import LazyTrait, RandomizableTrait, ReduceTrait, ThreadUnsafe
2929
from monai.utils import MAX_SEED, ensure_tuple, first
3030
from monai.utils.enums import TransformBackends
3131
from monai.utils.misc import MONAIEnvVars
@@ -142,7 +142,7 @@ def apply_transform(
142142
"""
143143
try:
144144
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
145-
if isinstance(data, (list, tuple)) and map_items_ > 0:
145+
if isinstance(data, (list, tuple)) and map_items_ > 0 and not isinstance(transform, ReduceTrait):
146146
return [
147147
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
148148
for item in data
@@ -482,8 +482,7 @@ def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable
482482
yield (key,) + tuple(_ex_iters) if extra_iterables else key
483483
elif not self.allow_missing_keys:
484484
raise KeyError(
485-
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data"
486-
" and allow_missing_keys==False."
485+
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data and allow_missing_keys==False."
487486
)
488487

489488
def first_key(self, data: dict[Hashable, Any]):

monai/transforms/utility/array.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
median_filter,
4444
)
4545
from monai.transforms.inverse import InvertibleTransform, TraceableTransform
46-
from monai.transforms.traits import MultiSampleTrait
46+
from monai.transforms.traits import MultiSampleTrait, ReduceTrait
4747
from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform
4848
from monai.transforms.utils import (
4949
apply_affine_to_points,
@@ -110,6 +110,7 @@
110110
"ImageFilter",
111111
"RandImageFilter",
112112
"ApplyTransformToPoints",
113+
"FlattenSequence",
113114
]
114115

115116

@@ -1950,3 +1951,39 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
19501951
data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"])
19511952

19521953
return data
1954+
1955+
1956+
class FlattenSequence(Transform, ReduceTrait):
1957+
"""
1958+
Flatten a nested sequence (list or tuple) by one level.
1959+
If the input is a sequence of sequences, it will flatten them into a single sequence.
1960+
Non-nested sequences and other data types are returned unchanged.
1961+
1962+
For example:
1963+
1964+
.. code-block:: python
1965+
1966+
flatten = FlattenSequence()
1967+
data = [[1, 2], [3, 4], [5, 6]]
1968+
print(flatten(data))
1969+
[1, 2, 3, 4, 5, 6]
1970+
1971+
"""
1972+
1973+
def __init__(self):
1974+
super().__init__()
1975+
1976+
def __call__(self, data: list | tuple | Any) -> list | tuple | Any:
1977+
"""
1978+
Flatten a nested sequence by one level.
1979+
Args:
1980+
data: Input data, can be a nested sequence.
1981+
Returns:
1982+
Flattened list if input is a nested sequence, otherwise returns data unchanged.
1983+
"""
1984+
if isinstance(data, (list, tuple)):
1985+
if len(data) == 0:
1986+
return data
1987+
if all(isinstance(item, (list, tuple)) for item in data):
1988+
return [item for sublist in data for item in sublist]
1989+
return data

monai/transforms/utility/dictionary.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from monai.data.meta_tensor import MetaObj, MetaTensor
3131
from monai.data.utils import no_collation
3232
from monai.transforms.inverse import InvertibleTransform
33-
from monai.transforms.traits import MultiSampleTrait, RandomizableTrait
33+
from monai.transforms.traits import MultiSampleTrait, RandomizableTrait, ReduceTrait
3434
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
3535
from monai.transforms.utility.array import (
3636
AddCoordinateChannels,
@@ -45,6 +45,7 @@
4545
EnsureChannelFirst,
4646
EnsureType,
4747
FgBgToIndices,
48+
FlattenSequence,
4849
Identity,
4950
ImageFilter,
5051
IntensityStats,
@@ -191,6 +192,9 @@
191192
"ApplyTransformToPointsd",
192193
"ApplyTransformToPointsD",
193194
"ApplyTransformToPointsDict",
195+
"FlattenSequenced",
196+
"FlattenSequenceD",
197+
"FlattenSequenceDict",
194198
]
195199

196200
DEFAULT_POST_FIX = PostFix.meta()
@@ -1906,6 +1910,28 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
19061910
return d
19071911

19081912

1913+
class FlattenSequenced(MapTransform, ReduceTrait):
1914+
"""
1915+
Dictionary-based wrapper of :py:class:`monai.transforms.FlattenSequence`.
1916+
1917+
Args:
1918+
keys: keys of the corresponding items to be transformed.
1919+
See also: monai.transforms.MapTransform
1920+
allow_missing_keys:
1921+
Don't raise exception if key is missing.
1922+
"""
1923+
1924+
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, **kwargs) -> None:
1925+
super().__init__(keys, allow_missing_keys)
1926+
self.flatten_sequence = FlattenSequence(**kwargs)
1927+
1928+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
1929+
d = dict(data)
1930+
for key in self.key_iterator(d):
1931+
d[key] = self.flatten_sequence(d[key]) # type: ignore[assignment]
1932+
return d
1933+
1934+
19091935
RandImageFilterD = RandImageFilterDict = RandImageFilterd
19101936
ImageFilterD = ImageFilterDict = ImageFilterd
19111937
IdentityD = IdentityDict = Identityd
@@ -1949,3 +1975,4 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
19491975
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
19501976
FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
19511977
ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd
1978+
FlattenSequenceD = FlattenSequenceDict = FlattenSequenced

tests/transforms/compose/test_compose.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,40 @@ def test_flatten_and_len(self):
282282
def test_backwards_compatible_imports(self):
283283
from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401
284284

285+
def test_list_extend_multi_sample_trait(self):
286+
center_crop = mt.CenterSpatialCrop([128, 128])
287+
multi_sample_transform = mt.RandSpatialCropSamples([64, 64], 1)
288+
flatten_sequence_transform = mt.FlattenSequence()
289+
290+
img = torch.zeros([1, 512, 512])
291+
292+
self.assertEqual(execute_compose(img, [center_crop]).shape, torch.Size([1, 128, 128]))
293+
single_multi_sample_trait_result = execute_compose(
294+
img, [multi_sample_transform, center_crop, flatten_sequence_transform]
295+
)
296+
self.assertIsInstance(single_multi_sample_trait_result, list)
297+
self.assertEqual(len(single_multi_sample_trait_result), 1)
298+
self.assertEqual(single_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64]))
299+
300+
double_multi_sample_trait_result = execute_compose(
301+
img, [multi_sample_transform, multi_sample_transform, flatten_sequence_transform, center_crop]
302+
)
303+
self.assertIsInstance(double_multi_sample_trait_result, list)
304+
self.assertEqual(len(double_multi_sample_trait_result), 1)
305+
self.assertEqual(double_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64]))
306+
307+
def test_multi_sample_trait_cardinality(self):
308+
img = torch.zeros([1, 128, 128])
309+
t2 = mt.RandSpatialCropSamples([32, 32], num_samples=2)
310+
flatten_sequence_transform = mt.FlattenSequence()
311+
312+
# chaining should multiply counts: 2 x 2 = 4, flattened
313+
res = execute_compose(img, [t2, t2, flatten_sequence_transform])
314+
self.assertIsInstance(res, list)
315+
self.assertEqual(len(res), 4)
316+
for r in res:
317+
self.assertEqual(r.shape, torch.Size([1, 32, 32]))
318+
285319

286320
TEST_COMPOSE_EXECUTE_TEST_CASES = [
287321
[None, tuple()],

0 commit comments

Comments
 (0)