1515import datetime as dt
1616import logging
1717from functools import partial
18- from typing import Any , Optional , Tuple , Union
18+ from typing import Any , Callable , Iterable , List , Optional , Tuple , Union
1919
2020import numpy as np
21+ from geopandas import GeoDataFrame
2122
22- from eolearn .core import EOPatch , EOTask , FeatureType , MapFeatureTask
23+ from eolearn .core import EOPatch , EOTask , FeatureType , FeatureTypeSet , MapFeatureTask
2324
2425from .utils import ResizeLib , ResizeMethod , spatially_resize_image
2526
@@ -34,51 +35,57 @@ class SimpleFilterTask(EOTask):
3435 A filter_func is a callable which takes a numpy array and returns a bool.
3536 """
3637
37- def __init__ (self , feature , filter_func , filter_features = ...):
38+ def __init__ (
39+ self ,
40+ feature : Union [Tuple [FeatureType , str ], FeatureType ],
41+ filter_func : Union [Callable [[np .ndarray ], bool ], Callable [[dt .datetime ], bool ]],
42+ filter_features : Any = ...,
43+ ):
3844 """
3945 :param feature: Feature in the EOPatch , e.g. feature=(FeatureType.DATA, 'bands')
40- :type feature: (FeatureType, str)
4146 :param filter_func: A callable that takes a numpy evaluates to bool.
42- :type filter_func: object
43- :param filter_features: A collection of features which will be filtered
44- :type filter_features: dict(FeatureType: set(str))
47+ :param filter_features: A collection of features which will be filtered into a new EOPatch
4548 """
46- self .feature = self .parse_feature (feature )
49+ self .feature = self .parse_feature (
50+ feature , allowed_feature_types = FeatureTypeSet .TEMPORAL_TYPES .difference ([FeatureType .VECTOR ])
51+ )
4752 self .filter_func = filter_func
4853 self .filter_features_parser = self .get_feature_parser (filter_features )
4954
50- def _get_filtered_indices (self , feature_data ):
55+ def _get_filtered_indices (self , feature_data : Iterable ) -> List [int ]:
56+ """Get valid time indices from either a numpy array or a list of timestamps."""
5157 return [idx for idx , img in enumerate (feature_data ) if self .filter_func (img )]
5258
53- def _update_other_data (self , eopatch ):
54- pass
59+ @staticmethod
60+ def _filter_vector_feature (gdf : GeoDataFrame , good_idxs : List [int ], timestamps : List [dt .datetime ]) -> GeoDataFrame :
61+ """Filters rows that don't match with the timestamps that will be kept."""
62+ timestamps_to_keep = set (timestamps [idx ] for idx in good_idxs )
63+ return gdf [gdf .TIMESTAMP .isin (timestamps_to_keep )]
5564
56- def execute (self , eopatch ) :
65+ def execute (self , eopatch : EOPatch ) -> EOPatch :
5766 """
58- :param eopatch: Input EOPatch.
59- :type eopatch: EOPatch
60- :return: Transformed eo patch
61- :rtype: EOPatch
67+ :param eopatch: An input EOPatch.
68+ :return: A new EOPatch with filtered features.
6269 """
70+ filtered_eopatch = EOPatch ()
6371 good_idxs = self ._get_filtered_indices (eopatch [self .feature ])
64- if not good_idxs :
65- raise RuntimeError ("EOPatch has no good indices after filtering with given filter function" )
66-
67- for feature_type , feature_name in self .filter_features_parser .get_features (eopatch ):
68- if feature_type .is_temporal ():
69- if feature_type .has_dict ():
70- if feature_type .contains_ndarrays ():
71- eopatch [feature_type ][feature_name ] = np .asarray (
72- [eopatch [feature_type ][feature_name ][idx ] for idx in good_idxs ]
73- )
74- # else:
75- # NotImplemented
72+
73+ for feature in self .filter_features_parser .get_features (eopatch ):
74+ feature_type , _ = feature
75+ data = eopatch [feature ]
76+
77+ if feature_type is FeatureType .TIMESTAMP :
78+ data = [data [idx ] for idx in good_idxs ]
79+
80+ elif feature_type .is_temporal ():
81+ if feature_type .is_raster ():
82+ data = data [good_idxs ]
7683 else :
77- eopatch [ feature_type ] = [ eopatch [ feature_type ][ idx ] for idx in good_idxs ]
84+ data = self . _filter_vector_feature ( data , good_idxs , eopatch . timestamp )
7885
79- self . _update_other_data ( eopatch )
86+ filtered_eopatch [ feature ] = data
8087
81- return eopatch
88+ return filtered_eopatch
8289
8390
8491class FilterTimeSeriesTask (SimpleFilterTask ):
0 commit comments