Skip to content

Commit f69b86f

Browse files
authored
Merge pull request #438 from sentinel-hub/feat/simple-filter-task
Updated SimpleFilterTask and wrote additional unit tests
2 parents 2f035d1 + 2ea5e97 commit f69b86f

File tree

2 files changed

+73
-40
lines changed

2 files changed

+73
-40
lines changed

features/eolearn/features/feature_manipulation.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
import datetime as dt
1616
import logging
1717
from functools import partial
18-
from typing import Any, Optional, Tuple, Union
18+
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
1919

2020
import numpy as np
21+
from geopandas import GeoDataFrame
2122

22-
from eolearn.core import EOPatch, EOTask, FeatureType, MapFeatureTask
23+
from eolearn.core import EOPatch, EOTask, FeatureType, FeatureTypeSet, MapFeatureTask
2324

2425
from .utils import ResizeLib, ResizeMethod, spatially_resize_image
2526

@@ -34,51 +35,57 @@ class SimpleFilterTask(EOTask):
3435
A filter_func is a callable which takes a numpy array and returns a bool.
3536
"""
3637

37-
def __init__(self, feature, filter_func, filter_features=...):
38+
def __init__(
39+
self,
40+
feature: Union[Tuple[FeatureType, str], FeatureType],
41+
filter_func: Union[Callable[[np.ndarray], bool], Callable[[dt.datetime], bool]],
42+
filter_features: Any = ...,
43+
):
3844
"""
3945
:param feature: Feature in the EOPatch , e.g. feature=(FeatureType.DATA, 'bands')
40-
:type feature: (FeatureType, str)
4146
:param filter_func: A callable that takes a numpy evaluates to bool.
42-
:type filter_func: object
43-
:param filter_features: A collection of features which will be filtered
44-
:type filter_features: dict(FeatureType: set(str))
47+
:param filter_features: A collection of features which will be filtered into a new EOPatch
4548
"""
46-
self.feature = self.parse_feature(feature)
49+
self.feature = self.parse_feature(
50+
feature, allowed_feature_types=FeatureTypeSet.TEMPORAL_TYPES.difference([FeatureType.VECTOR])
51+
)
4752
self.filter_func = filter_func
4853
self.filter_features_parser = self.get_feature_parser(filter_features)
4954

50-
def _get_filtered_indices(self, feature_data):
55+
def _get_filtered_indices(self, feature_data: Iterable) -> List[int]:
56+
"""Get valid time indices from either a numpy array or a list of timestamps."""
5157
return [idx for idx, img in enumerate(feature_data) if self.filter_func(img)]
5258

53-
def _update_other_data(self, eopatch):
54-
pass
59+
@staticmethod
60+
def _filter_vector_feature(gdf: GeoDataFrame, good_idxs: List[int], timestamps: List[dt.datetime]) -> GeoDataFrame:
61+
"""Filters rows that don't match with the timestamps that will be kept."""
62+
timestamps_to_keep = set(timestamps[idx] for idx in good_idxs)
63+
return gdf[gdf.TIMESTAMP.isin(timestamps_to_keep)]
5564

56-
def execute(self, eopatch):
65+
def execute(self, eopatch: EOPatch) -> EOPatch:
5766
"""
58-
:param eopatch: Input EOPatch.
59-
:type eopatch: EOPatch
60-
:return: Transformed eo patch
61-
:rtype: EOPatch
67+
:param eopatch: An input EOPatch.
68+
:return: A new EOPatch with filtered features.
6269
"""
70+
filtered_eopatch = EOPatch()
6371
good_idxs = self._get_filtered_indices(eopatch[self.feature])
64-
if not good_idxs:
65-
raise RuntimeError("EOPatch has no good indices after filtering with given filter function")
66-
67-
for feature_type, feature_name in self.filter_features_parser.get_features(eopatch):
68-
if feature_type.is_temporal():
69-
if feature_type.has_dict():
70-
if feature_type.contains_ndarrays():
71-
eopatch[feature_type][feature_name] = np.asarray(
72-
[eopatch[feature_type][feature_name][idx] for idx in good_idxs]
73-
)
74-
# else:
75-
# NotImplemented
72+
73+
for feature in self.filter_features_parser.get_features(eopatch):
74+
feature_type, _ = feature
75+
data = eopatch[feature]
76+
77+
if feature_type is FeatureType.TIMESTAMP:
78+
data = [data[idx] for idx in good_idxs]
79+
80+
elif feature_type.is_temporal():
81+
if feature_type.is_raster():
82+
data = data[good_idxs]
7683
else:
77-
eopatch[feature_type] = [eopatch[feature_type][idx] for idx in good_idxs]
84+
data = self._filter_vector_feature(data, good_idxs, eopatch.timestamp)
7885

79-
self._update_other_data(eopatch)
86+
filtered_eopatch[feature] = data
8087

81-
return eopatch
88+
return filtered_eopatch
8289

8390

8491
class FilterTimeSeriesTask(SimpleFilterTask):

features/eolearn/tests/test_feature_manipulation.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,39 @@
1616
from numpy.testing import assert_allclose, assert_array_equal
1717

1818
from eolearn.core import EOPatch, FeatureType
19-
from eolearn.features import FilterTimeSeriesTask, LinearFunctionTask, ValueFilloutTask
19+
from eolearn.features import FilterTimeSeriesTask, LinearFunctionTask, SimpleFilterTask, ValueFilloutTask
2020
from eolearn.features.feature_manipulation import SpatialResizeTask
2121

2222

23-
def test_content_after_timefilter():
23+
@pytest.mark.parametrize(
24+
"feature", [(FeatureType.DATA, "BANDS-S2-L1C"), FeatureType.TIMESTAMP, (FeatureType.LABEL, "IS_CLOUDLESS")]
25+
)
26+
def test_simple_filter_task_filter_all(example_eopatch: EOPatch, feature):
27+
filter_all_task = SimpleFilterTask(feature, filter_func=lambda _: False)
28+
filtered_eopatch = filter_all_task.execute(example_eopatch)
29+
30+
assert filtered_eopatch is not example_eopatch
31+
assert filtered_eopatch.data["CLP"].shape == (0, 101, 100, 1)
32+
assert filtered_eopatch.scalar["CLOUD_COVERAGE"].shape == (0, 1)
33+
assert len(filtered_eopatch.vector["CLM_VECTOR"]) == 0
34+
assert np.array_equal(filtered_eopatch.mask_timeless["LULC"], example_eopatch.mask_timeless["LULC"])
35+
assert filtered_eopatch.timestamp == []
36+
37+
38+
@pytest.mark.parametrize(
39+
"feature", [(FeatureType.MASK, "CLM"), FeatureType.TIMESTAMP, (FeatureType.SCALAR, "CLOUD_COVERAGE")]
40+
)
41+
def test_simple_filter_task_filter_nothing(example_eopatch: EOPatch, feature):
42+
del example_eopatch.data["REFERENCE_SCENES"] # Wrong size of time dimension
43+
44+
filter_all_task = SimpleFilterTask(feature, filter_func=lambda _: True)
45+
filtered_eopatch = filter_all_task.execute(example_eopatch)
46+
47+
assert filtered_eopatch is not example_eopatch
48+
assert filtered_eopatch == example_eopatch
49+
50+
51+
def test_content_after_time_filter():
2452
timestamps = [
2553
datetime.datetime(2017, 1, 1, 10, 4, 7),
2654
datetime.datetime(2017, 1, 4, 10, 14, 5),
@@ -37,16 +65,14 @@ def test_content_after_timefilter():
3765

3866
new_start, new_end = 4, -3
3967

40-
new_interval = (timestamps[new_start], timestamps[new_end])
41-
42-
new_timestamps = timestamps[new_start : new_end + 1]
43-
4468
eop = EOPatch(timestamp=timestamps, data={"data": data})
4569

46-
filter_task = FilterTimeSeriesTask(start_date=new_interval[0], end_date=new_interval[1])
47-
filter_task.execute(eop)
70+
filter_task = FilterTimeSeriesTask(start_date=timestamps[new_start], end_date=timestamps[new_end])
71+
filtered_eop = filter_task.execute(eop)
4872

49-
assert new_timestamps == eop.timestamp
73+
assert filtered_eop is not eop
74+
assert filtered_eop.timestamp == timestamps[new_start : new_end + 1]
75+
assert np.array_equal(filtered_eop.data["data"], data[new_start : new_end + 1, ...])
5076

5177

5278
def test_fill():

0 commit comments

Comments
 (0)