@@ -53,12 +53,9 @@ def get_t0_datetimes_across_all_data_sources(self, freq: str) -> pd.DatetimeInde
5353
5454 return t0_datetimes
5555
56- def sample_position_of_every_example_of_every_split (
57- self ,
58- t0_datetimes : pd .DatetimeIndex ,
59- split_method : SplitMethod ,
60- n_examples_per_split : dict [SplitName , int ],
61- ) -> dict [SplitName , pd .DataFrame ]:
56+ def sample_spatial_and_temporal_positions_for_examples (
57+ self , t0_datetimes : pd .DatetimeIndex , n_examples : int
58+ ) -> pd .DataFrame :
6259 """
6360 Computes the geospatial and temporal position of each training example.
6461
@@ -68,33 +65,21 @@ def sample_position_of_every_example_of_every_split(
6865 Args:
6966 t0_datetimes: All available t0 datetimes. Can be computed with
7067 `DataSourceList.get_t0_datetimes_across_all_data_sources()`
71- split_method: The method used to split data into train, validation, and test.
72- n_examples_per_split: The number of examples requested for each split.
68+ n_examples: The number of examples requested.
7369
7470 Returns:
75- A dict where the keys are a SplitName, and the values are a pd.DataFrame.
76- Each row of each DataFrame specifies the position of each example, using
71+ Each row of each the DataFrame specifies the position of each example, using
7772 columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'.
7873 """
79- # Split t0_datetimes into train, test and validation sets.
80- t0_datetimes_per_split = split_data (datetimes = t0_datetimes , method = split_method )
81- t0_datetimes_per_split = t0_datetimes_per_split ._asdict ()
82-
8374 data_source_which_defines_geo_position = self [0 ]
84-
85- positions_per_split : dict [SplitName , pd .DataFrame ] = {}
86- for split_name , t0_datetimes_for_split in t0_datetimes_per_split .items ():
87- n_examples = n_examples_per_split [split_name ]
88- shuffled_t0_datetimes = np .random .choice (t0_datetimes_for_split , shape = n_examples )
89- x_locations , y_locations = data_source_which_defines_geo_position .get_locations (
90- shuffled_t0_datetimes
91- )
92- positions_per_split [split_name ] = pd .DataFrame (
93- {
94- "t0_datetime_UTC" : shuffled_t0_datetimes ,
95- "x_center_OSGB" : x_locations ,
96- "y_center_OSGB" : y_locations ,
97- }
98- )
99-
100- return positions_per_split
75+ shuffled_t0_datetimes = np .random .choice (t0_datetimes , size = n_examples )
76+ x_locations , y_locations = data_source_which_defines_geo_position .get_locations (
77+ shuffled_t0_datetimes
78+ )
79+ return pd .DataFrame (
80+ {
81+ "t0_datetime_UTC" : shuffled_t0_datetimes ,
82+ "x_center_OSGB" : x_locations ,
83+ "y_center_OSGB" : y_locations ,
84+ }
85+ )
0 commit comments