Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 3143ebb

Browse files
authored
Merge pull request #278 from openclimatefix/jack/compute_and_save_positions_of_each_example
Implement `DataSourceList.sample_spatial_and_temporal_locations_for_examples()`
2 parents cfe8977 + d9c0715 commit 3143ebb

File tree

14 files changed

+144
-102
lines changed

14 files changed

+144
-102
lines changed

nowcasting_dataset/data_sources/data_source.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from numbers import Number
66
from typing import List, Tuple, Iterable
77

8-
import numpy as np
98
import pandas as pd
109
import xarray as xr
1110

@@ -181,13 +180,10 @@ def get_contiguous_time_periods(self) -> pd.DataFrame:
181180
max_gap_duration=self.sample_period_duration,
182181
)
183182

184-
def get_locations_for_batch(
185-
self, t0_datetimes: pd.DatetimeIndex
186-
) -> Tuple[List[Number], List[Number]]:
183+
def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]:
187184
"""Find a valid geographical location for each t0_datetime.
188185
189-
Should be overridden by DataSources which may be used to define the locations
190-
for each batch.
186+
Should be overridden by DataSources which may be used to define the locations.
191187
192188
Returns: x_locations, y_locations. Each has one entry per t0_datetime.
193189
Locations are in OSGB coordinates.

nowcasting_dataset/data_sources/data_source_list.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""DataSourceList class."""
2+
3+
import numpy as np
4+
import pandas as pd
5+
import logging
6+
7+
import nowcasting_dataset.time as nd_time
8+
from nowcasting_dataset.dataset.split.split import SplitMethod, split_data, SplitName
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class DataSourceList(list):
14+
"""Hold a list of DataSource objects.
15+
16+
The first DataSource in the list is used to compute the geospatial locations of each example.
17+
"""
18+
19+
def get_t0_datetimes_across_all_data_sources(self, freq: str) -> pd.DatetimeIndex:
20+
"""
21+
Compute the intersection of the t0 datetimes available across all DataSources.
22+
23+
Args:
24+
freq: The Pandas frequency string. The returned DatetimeIndex will be at this frequency,
25+
and every datetime will be aligned to this frequency. For example, if
26+
freq='5 minutes' then every datetime will be at 00, 05, ..., 55 minutes
27+
past the hour.
28+
29+
Returns: Valid t0 datetimes, taking into consideration all DataSources,
30+
filtered by daylight hours (SatelliteDataSource.datetime_index() removes the night
31+
datetimes).
32+
"""
33+
logger.debug("Get the intersection of time periods across all DataSources.")
34+
35+
# Get the intersection of t0 time periods from all data sources.
36+
t0_time_periods_for_all_data_sources = []
37+
for data_source in self:
38+
logger.debug(f"Getting t0 time periods for {type(data_source).__name__}")
39+
try:
40+
t0_time_periods = data_source.get_contiguous_t0_time_periods()
41+
except NotImplementedError:
42+
pass # Skip data_sources with no concept of time.
43+
else:
44+
t0_time_periods_for_all_data_sources.append(t0_time_periods)
45+
46+
intersection_of_t0_time_periods = nd_time.intersection_of_multiple_dataframes_of_periods(
47+
t0_time_periods_for_all_data_sources
48+
)
49+
50+
t0_datetimes = nd_time.time_periods_to_datetime_index(
51+
time_periods=intersection_of_t0_time_periods, freq=freq
52+
)
53+
54+
return t0_datetimes
55+
56+
def sample_spatial_and_temporal_locations_for_examples(
57+
self, t0_datetimes: pd.DatetimeIndex, n_examples: int
58+
) -> pd.DataFrame:
59+
"""
60+
Computes the geospatial and temporal locations for each training example.
61+
62+
The first data_source in this DataSourceList defines the geospatial locations of
63+
each example.
64+
65+
Args:
66+
t0_datetimes: All available t0 datetimes. Can be computed with
67+
`DataSourceList.get_t0_datetimes_across_all_data_sources()`
68+
n_examples: The number of examples requested.
69+
70+
Returns:
71+
Each row of the DataFrame specifies the position of one example, using
72+
columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'.
73+
"""
74+
data_source_which_defines_geo_position = self[0]
75+
shuffled_t0_datetimes = np.random.choice(t0_datetimes, size=n_examples)
76+
x_locations, y_locations = data_source_which_defines_geo_position.get_locations(
77+
shuffled_t0_datetimes
78+
)
79+
return pd.DataFrame(
80+
{
81+
"t0_datetime_UTC": shuffled_t0_datetimes,
82+
"x_center_OSGB": x_locations,
83+
"y_center_OSGB": y_locations,
84+
}
85+
)

nowcasting_dataset/data_sources/datetime/datetime_data_source.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515
class DatetimeDataSource(DataSource):
1616
""" Add hour_of_day_{sin, cos} and day_of_year_{sin, cos} features. """
1717

18-
def __post_init__(self):
19-
""" Post init """
20-
super().__post_init__()
21-
2218
def get_example(
2319
self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number
2420
) -> Datetime:
@@ -44,13 +40,3 @@ def get_example(
4440
datetime_xr_dataset = make_dim_index(datetime_xr_dataset)
4541

4642
return Datetime(datetime_xr_dataset)
47-
48-
def get_locations_for_batch(
49-
self, t0_datetimes: pd.DatetimeIndex
50-
) -> Tuple[List[Number], List[Number]]:
51-
""" This method is not needed for DatetimeDataSource """
52-
raise NotImplementedError()
53-
54-
def datetime_index(self) -> pd.DatetimeIndex:
55-
""" This method is not needed for DatetimeDataSource """
56-
raise NotImplementedError()

nowcasting_dataset/data_sources/gsp/gsp_data_source.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,17 +108,15 @@ def datetime_index(self):
108108
"""
109109
return self.gsp_power.index
110110

111-
def get_locations_for_batch(
112-
self, t0_datetimes: pd.DatetimeIndex
113-
) -> Tuple[List[Number], List[Number]]:
111+
def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]:
114112
"""
115-
Get x and y locations for a batch. Assume that all data is available for all GSP.
113+
Get x and y locations. Assume that all data is available for all GSP.
116114
117115
Random GSP are taken, and the locations of them are returned. This is useful as other
118116
datasources need to know which x,y locations to get.
119117
120118
Args:
121-
t0_datetimes: list of datetimes that the batches locations have data for
119+
t0_datetimes: list of available t0 datetimes.
122120
123121
Returns: list of x and y locations
124122
@@ -266,7 +264,7 @@ def _get_central_gsp_id(
266264
logger.debug("Getting Central GSP")
267265

268266
# If x_meters_center and y_meters_center have been chosen
269-
# by {}.get_locations_for_batch() then we just have
267+
# by {}.get_locations() then we just have
270268
# to find the gsp_ids at that exact location. This is
271269
# super-fast (a few hundred microseconds). We use np.isclose
272270
# instead of the equality operator because floats.

nowcasting_dataset/data_sources/metadata/metadata_data_source.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,6 @@ class MetadataDataSource(DataSource):
1919

2020
object_at_center: str = "GSP"
2121

22-
def __post_init__(self):
23-
"""Post init"""
24-
super().__post_init__()
25-
2622
def get_example(
2723
self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number
2824
) -> Metadata:
@@ -44,6 +40,8 @@ def get_example(
4440
else:
4541
object_at_center_label = 0
4642

43+
# TODO: data_dict is unused in this function. Is that a bug?
44+
# https://github.com/openclimatefix/nowcasting_dataset/issues/279
4745
data_dict = dict(
4846
t0_dt=to_numpy(t0_dt), #: Shape: [batch_size,]
4947
x_meters_center=np.array(x_meters_center),
@@ -68,13 +66,3 @@ def get_example(
6866
data[v] = getattr(d, v)
6967

7068
return Metadata(data)
71-
72-
def get_locations_for_batch(
73-
self, t0_datetimes: pd.DatetimeIndex
74-
) -> Tuple[List[Number], List[Number]]:
75-
"""This method is not needed for MetadataDataSource"""
76-
raise NotImplementedError()
77-
78-
def datetime_index(self) -> pd.DatetimeIndex:
79-
"""This method is not needed for MetadataDataSource"""
80-
raise NotImplementedError()

nowcasting_dataset/data_sources/pv/pv_data_source.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,7 @@ def get_example(
276276

277277
return PV(pv)
278278

279-
def get_locations_for_batch(
280-
self, t0_datetimes: pd.DatetimeIndex
281-
) -> Tuple[List[Number], List[Number]]:
279+
def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]:
282280
"""Find a valid geographical location for each t0_datetime.
283281
284282
Returns: x_locations, y_locations. Each has one entry per t0_datetime.

nowcasting_dataset/data_sources/sun/sun_data_source.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,11 @@ def get_example(
7878
return Sun(sun)
7979

8080
def _load(self):
81-
8281
self.azimuth, self.elevation = load_from_zarr(
8382
filename=self.filename, start_dt=self.start_dt, end_dt=self.end_dt
8483
)
8584

86-
def get_locations_for_batch(
87-
self, t0_datetimes: pd.DatetimeIndex
88-
) -> Tuple[List[Number], List[Number]]:
85+
def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]:
8986
""" Sun data should not be used to get batch locations """
9087
raise NotImplementedError("Sun data should not be used to get batch locations")
9188

nowcasting_dataset/dataset/datamodule.py

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010

1111
from nowcasting_dataset import consts
1212
from nowcasting_dataset import data_sources
13-
from nowcasting_dataset import time as nd_time
14-
from nowcasting_dataset import utils
1513
from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource
1614
from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource
1715
from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource
1816
from nowcasting_dataset.dataset import datasets
1917
from nowcasting_dataset.dataset.split.split import split_data, SplitMethod
18+
from nowcasting_dataset.data_sources.data_source_list import DataSourceList
19+
2020

2121
with warnings.catch_warnings():
2222
warnings.filterwarnings("ignore", category=DeprecationWarning)
@@ -207,6 +207,8 @@ def prepare_data(self) -> None:
207207
)
208208
)
209209

210+
self.data_sources = DataSourceList(self.data_sources)
211+
210212
def setup(self, stage="fit"):
211213
"""Split data, etc.
212214
@@ -309,14 +311,18 @@ def _split_data(self):
309311
logger.debug("Going to split data")
310312

311313
self._check_has_prepared_data()
312-
self.t0_datetimes = self._get_t0_datetimes()
314+
self.t0_datetimes = self._get_t0_datetimes_across_all_data_sources()
313315

314316
logger.debug(f"Got all start times, there are {len(self.t0_datetimes):,d}")
315317

316-
self.train_t0_datetimes, self.val_t0_datetimes, self.test_t0_datetimes = split_data(
318+
data_after_splitting = split_data(
317319
datetimes=self.t0_datetimes, method=self.split_method, seed=self.seed
318320
)
319321

322+
self.train_t0_datetimes = data_after_splitting.train
323+
self.val_t0_datetimes = data_after_splitting.validation
324+
self.test_t0_datetimes = data_after_splitting.test
325+
320326
logger.debug(
321327
f"Split data done, train has {len(self.train_t0_datetimes):,d}, "
322328
f"validation has {len(self.val_t0_datetimes):,d}, "
@@ -354,38 +360,15 @@ def _common_dataloader_params(self) -> Dict:
354360
batch_sampler=None,
355361
)
356362

357-
def _get_t0_datetimes(self) -> pd.DatetimeIndex:
358-
"""
359-
Compute the intersection of the t0 datetimes available across all DataSources.
363+
def _get_t0_datetimes_across_all_data_sources(self) -> pd.DatetimeIndex:
364+
"""See DataSourceList.get_t0_datetimes_across_all_data_sources.
360365
361-
Returns the valid t0 datetimes, taking into consideration all DataSources,
362-
filtered by daylight hours (SatelliteDataSource.datetime_index() removes the night
363-
datetimes).
366+
This method will be deleted as part of implementing #213.
364367
"""
365-
logger.debug("Get the intersection of time periods across all DataSources.")
366-
self._check_has_prepared_data()
367-
368-
# Get the intersection of t0 time periods from all data sources.
369-
t0_time_periods_for_all_data_sources = []
370-
for data_source in self.data_sources:
371-
logger.debug(f"Getting t0 time periods for {type(data_source).__name__}")
372-
try:
373-
t0_time_periods = data_source.get_contiguous_t0_time_periods()
374-
except NotImplementedError:
375-
pass # Skip data_sources with no concept of time.
376-
else:
377-
t0_time_periods_for_all_data_sources.append(t0_time_periods)
378-
379-
intersection_of_t0_time_periods = nd_time.intersection_of_multiple_dataframes_of_periods(
380-
t0_time_periods_for_all_data_sources
381-
)
382-
383-
t0_datetimes = nd_time.time_periods_to_datetime_index(
384-
time_periods=intersection_of_t0_time_periods, freq=self.t0_datetime_freq
368+
return self.data_sources.get_t0_datetimes_across_all_data_sources(
369+
freq=self.t0_datetime_freq
385370
)
386371

387-
return t0_datetimes
388-
389372
def _check_has_prepared_data(self):
390373
if not self.has_prepared_data:
391374
raise RuntimeError("Must run prepare_data() first!")

nowcasting_dataset/dataset/datasets.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def _get_batch(self) -> Batch:
141141
return []
142142

143143
t0_datetimes = self._get_t0_datetimes_for_batch()
144-
x_locations, y_locations = self._get_locations_for_batch(t0_datetimes)
144+
x_locations, y_locations = self._get_locations(t0_datetimes)
145145

146146
examples = {}
147147
n_threads = len(self.data_sources)
@@ -179,10 +179,8 @@ def _get_t0_datetimes_for_batch(self) -> pd.DatetimeIndex:
179179
t0_datetimes = np.tile(t0_datetimes, reps=self.n_samples_per_timestep)
180180
return pd.DatetimeIndex(t0_datetimes)
181181

182-
def _get_locations_for_batch(
183-
self, t0_datetimes: pd.DatetimeIndex
184-
) -> Tuple[List[Number], List[Number]]:
185-
return self.data_sources[0].get_locations_for_batch(t0_datetimes)
182+
def _get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]:
183+
return self.data_sources[0].get_locations(t0_datetimes)
186184

187185

188186
def worker_init_fn(worker_id):

nowcasting_dataset/dataset/split/split.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
from enum import Enum
55
from typing import List, Tuple, Union, Optional
6+
from collections import namedtuple
67

78
import pandas as pd
89

@@ -30,6 +31,20 @@ class SplitMethod(Enum):
3031
DAY_RANDOM_TEST_DATE = "day_random_test_date"
3132

3233

34+
class SplitName(Enum):
35+
"""The name for each data split."""
36+
37+
TRAIN = "train"
38+
VALIDATION = "validation"
39+
TEST = "test"
40+
41+
42+
SplitData = namedtuple(
43+
typename="SplitData",
44+
field_names=[SplitName.TRAIN.value, SplitName.VALIDATION.value, SplitName.TEST.value],
45+
)
46+
47+
3348
def split_data(
3449
datetimes: Union[List[pd.Timestamp], pd.DatetimeIndex],
3550
method: SplitMethod,
@@ -39,7 +54,7 @@ def split_data(
3954
),
4055
train_validation_test_datetime_split: Optional[List[pd.Timestamp]] = None,
4156
seed: int = 1234,
42-
) -> (List[pd.Timestamp], List[pd.Timestamp], List[pd.Timestamp]):
57+
) -> SplitData:
4358
"""
4459
Split the data using various different methods
4560
@@ -165,4 +180,10 @@ def split_data(
165180
else:
166181
raise ValueError(f"{method} for splitting day is not implemented")
167182

168-
return train_datetimes, validation_datetimes, test_datetimes
183+
logger.debug(
184+
f"Split data done, train has {len(train_datetimes):,d}, "
185+
f"validation has {len(validation_datetimes):,d}, "
186+
f"test has {len(test_datetimes):,d} t0 datetimes."
187+
)
188+
189+
return SplitData(train=train_datetimes, validation=validation_datetimes, test=test_datetimes)

0 commit comments

Comments
 (0)