|
10 | 10 |
|
11 | 11 | from nowcasting_dataset import consts |
12 | 12 | from nowcasting_dataset import data_sources |
13 | | -from nowcasting_dataset import time as nd_time |
14 | | -from nowcasting_dataset import utils |
15 | 13 | from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource |
16 | 14 | from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource |
17 | 15 | from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource |
18 | 16 | from nowcasting_dataset.dataset import datasets |
19 | 17 | from nowcasting_dataset.dataset.split.split import split_data, SplitMethod |
| 18 | +from nowcasting_dataset.data_sources.data_source_list import DataSourceList |
| 19 | + |
20 | 20 |
|
21 | 21 | with warnings.catch_warnings(): |
22 | 22 | warnings.filterwarnings("ignore", category=DeprecationWarning) |
@@ -207,6 +207,8 @@ def prepare_data(self) -> None: |
207 | 207 | ) |
208 | 208 | ) |
209 | 209 |
|
| 210 | + self.data_sources = DataSourceList(self.data_sources) |
| 211 | + |
210 | 212 | def setup(self, stage="fit"): |
211 | 213 | """Split data, etc. |
212 | 214 |
|
@@ -309,14 +311,18 @@ def _split_data(self): |
309 | 311 | logger.debug("Going to split data") |
310 | 312 |
|
311 | 313 | self._check_has_prepared_data() |
312 | | - self.t0_datetimes = self._get_t0_datetimes() |
| 314 | + self.t0_datetimes = self._get_t0_datetimes_across_all_data_sources() |
313 | 315 |
|
314 | 316 | logger.debug(f"Got all start times, there are {len(self.t0_datetimes):,d}") |
315 | 317 |
|
316 | | - self.train_t0_datetimes, self.val_t0_datetimes, self.test_t0_datetimes = split_data( |
| 318 | + data_after_splitting = split_data( |
317 | 319 | datetimes=self.t0_datetimes, method=self.split_method, seed=self.seed |
318 | 320 | ) |
319 | 321 |
|
| 322 | + self.train_t0_datetimes = data_after_splitting.train |
| 323 | + self.val_t0_datetimes = data_after_splitting.validation |
| 324 | + self.test_t0_datetimes = data_after_splitting.test |
| 325 | + |
320 | 326 | logger.debug( |
321 | 327 | f"Split data done, train has {len(self.train_t0_datetimes):,d}, " |
322 | 328 | f"validation has {len(self.val_t0_datetimes):,d}, " |
@@ -354,38 +360,15 @@ def _common_dataloader_params(self) -> Dict: |
354 | 360 | batch_sampler=None, |
355 | 361 | ) |
356 | 362 |
|
def _get_t0_datetimes_across_all_data_sources(self) -> pd.DatetimeIndex:
    """Delegate to ``DataSourceList.get_t0_datetimes_across_all_data_sources``.

    Thin wrapper kept for backwards compatibility; the heavy lifting
    (intersecting contiguous t0 time periods across every DataSource)
    now lives on the DataSourceList itself.

    This method will be deleted as part of implementing #213.
    """
    # Forward the datamodule's configured sampling frequency unchanged.
    freq = self.t0_datetime_freq
    return self.data_sources.get_t0_datetimes_across_all_data_sources(freq=freq)
def _check_has_prepared_data(self):
    """Guard: raise unless ``prepare_data()`` has already been run."""
    if self.has_prepared_data:
        return
    raise RuntimeError("Must run prepare_data() first!")
0 commit comments