This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit d577bb0

Merge pull request #383 from openclimatefix/issue/318-dims-index
Issue/318 dims index
2 parents 91c971b + a30e506, commit d577bb0

File tree: 12 files changed (+137, -116 lines)


nowcasting_dataset/data_sources/README.md

Lines changed: 10 additions & 9 deletions

@@ -15,9 +15,11 @@ and the geospatial shape of each GSP region).
 # data_source.py
 
 General class used for making a data source. It has the following functions
-- get_batch: gets a whole batch of data for that data source
+- get_batch: gets a whole batch of data for that data source. The list of 'xr.Dataset' examples is converted to
+one xr.Dataset by changing the coordinates to indexes, and then joining the examples along an extra dimension.
 - datetime_index: gets all the available datetimes of the source
 - get_example: gets one "example" (a single consecutive sequence). Each batch is made up of multiple examples.
+Each example is a 'xr.Dataset'
 - get_locations_for_batch: Samples the geospatial x,y location for each example in a batch. This is useful because,
 typically, we want a single DataSource to dictate the geospatial locations of the examples (for example,
 we want each example to be centered on the centroid of the grid supply point region). All the other
@@ -27,20 +29,19 @@ General class used for making a data source. It has the following functions
 # datasource_output.py
 
 General pydantic model of output of the data source. Contains the following methods
-- to_numpy: changes all data points to numpy objects
-- split: converts a batch to a list of items
-- join: joins list of items to one
-- to_xr_dataset: changes data items to xarrays and returns a dataset
-- from_xr_dataset: loads from an xarray dataset
-- select_time_period: subselect data, depending on a time period
+- save_netcdf: save to netcdf file
+- check_nan_and_inf: check if any values are nans or infinite
+- check_dataset_greater_than_or_equal_to: check values are >= a value
+- check_dataset_less_than_or_equal_to: check values are <= a value
+- check_dataset_not_equal: check values are != a value
+- check_data_var_dim: check the dimensions of a data variable
 
 # <X> Data Source folder
 
 Roughly each of the data source folders follows this pattern
 - A class which defines how to load the data source, how to select for batches etc. This inherits from 'data_source.DataSource',
-- A class which contains the output model of the data source, built from an xarray Dataset. This is the information used in the batches.
+- A class which contains the output model of the data source, built from a xarray Dataset. This is the information used in the batches.
 This inherits from 'datasource_output.DataSourceOutput'.
-- A second class (pydantic) which moves the xarray Dataset to tensor fields. This will be used for training in ML models
 
 
 # fake
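
The get_batch description above is the heart of this PR: examples stop being model instances and become plain xr.Dataset objects that are index-coded and then concatenated. Here is a minimal sketch of that joining behaviour in plain xarray; the repo's xr_utils helpers are not shown in this diff, so to_indexed_coords below is a hypothetical stand-in for convert_coordinates_to_indexes:

    import numpy as np
    import pandas as pd
    import xarray as xr

    def to_indexed_coords(ds: xr.Dataset) -> xr.Dataset:
        # Replace each dimension's coordinate values with integer indexes
        # 0..n-1, so examples whose coordinate values differ still line up.
        return ds.assign_coords({dim: np.arange(size) for dim, size in ds.sizes.items()})

    # Two fake examples, each a small xr.Dataset with real datetime coords.
    examples = [
        xr.Dataset(
            {"data": ("time", np.random.rand(4))},
            coords={"time": pd.date_range("2021-01-01", periods=4, freq="5T")},
        )
        for _ in range(2)
    ]

    # Index the coordinates, then join along a new "example" dimension.
    batch = xr.concat([to_indexed_coords(ds) for ds in examples], dim="example")
    assert dict(batch.sizes) == {"example": 2, "time": 4}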

nowcasting_dataset/data_sources/data_source.py

Lines changed: 12 additions & 5 deletions

@@ -16,7 +16,10 @@
 from nowcasting_dataset import square
 from nowcasting_dataset.consts import SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES
 from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
-from nowcasting_dataset.dataset.xr_utils import join_list_dataset_to_batch_dataset, make_dim_index
+from nowcasting_dataset.dataset.xr_utils import (
+    convert_coordinates_to_indexes_for_list_datasets,
+    join_list_dataset_to_batch_dataset,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -257,10 +260,10 @@ def get_batch(
         examples = [future_example.result() for future_example in future_examples]
 
         # Get the DataSource class, this could be one of the data sources like Sun
-        cls = examples[0].__class__
+        cls = self.get_data_model_for_batch()
 
         # Set the coords to be indices before joining into a batch
-        examples = [make_dim_index(example) for example in examples]
+        examples = convert_coordinates_to_indexes_for_list_datasets(examples)
 
         # join the examples together, and cast them to the cls, so that validation can occur
         return cls(join_list_dataset_to_batch_dataset(examples))
@@ -271,6 +274,10 @@ def datetime_index(self) -> pd.DatetimeIndex:
         # of a list of datetimes (e.g. for DatetimeDataSource).
         raise NotImplementedError()
 
+    def get_data_model_for_batch(self):
+        """Get the model that is used in the batch"""
+        raise NotImplementedError()
+
     def get_contiguous_time_periods(self) -> pd.DataFrame:
         """Get all the time periods for which this DataSource has contiguous data.
 
@@ -378,7 +385,7 @@ def data(self):
 
     def get_example(
         self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number
-    ) -> DataSourceOutput:
+    ) -> xr.Dataset:
        """
         Get Example data
 
@@ -419,7 +426,7 @@
             f"actual shape {selected_data.shape}"
         )
 
-        return selected_data.load()
+        return selected_data.load().to_dataset(name="data")
 
     def geospatial_border(self) -> List[Tuple[Number, Number]]:
         """

nowcasting_dataset/data_sources/datasource_output.py

Lines changed: 4 additions & 2 deletions

@@ -57,13 +57,15 @@ def check_nan_and_inf(self, data: xr.Dataset, variable_name: str = None):
 
         if np.isnan(data).any():
             message = f"Some {self.__class__.__name__} data values are NaNs"
-            message += f" ({variable_name})" if variable_name is not None else None
+            if variable_name is not None:
+                message += f" ({variable_name})"
             logger.error(message)
             raise Exception(message)
 
         if np.isinf(data).any():
             message = f"Some {self.__class__.__name__} data values are Infinite"
-            message += f" ({variable_name})" if variable_name is not None else None
+            if variable_name is not None:
+                message += f" ({variable_name})"
             logger.error(message)
             raise Exception(message)
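
This two-line change fixes a genuine bug, not just style: in the old form the conditional expression is the entire right-hand side of +=, so when variable_name is None, Python evaluates message += None and raises a TypeError before the intended exception is ever built. A standalone demonstration:

    message = "Some GSP data values are NaNs"
    variable_name = None

    # Old form: `x if cond else None` binds to the whole right-hand side,
    # so this attempts `message += None` when variable_name is None.
    try:
        message += f" ({variable_name})" if variable_name is not None else None
    except TypeError as err:
        print(err)  # can only concatenate str (not "NoneType") to str

    # New form from the diff: append only when there is something to append.
    if variable_name is not None:
        message += f" ({variable_name})"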

nowcasting_dataset/data_sources/fake.py

Lines changed: 43 additions & 22 deletions

@@ -2,6 +2,8 @@
 
 Wanted to keep this out of the testing frameworks, as other repos might want to use this
 """
+from typing import List
+
 import numpy as np
 import pandas as pd
 import xarray as xr
@@ -15,8 +17,8 @@
 from nowcasting_dataset.data_sources.sun.sun_model import Sun
 from nowcasting_dataset.data_sources.topographic.topographic_model import Topographic
 from nowcasting_dataset.dataset.xr_utils import (
-    convert_data_array_to_dataset,
-    join_list_data_array_to_batch_dataset,
+    convert_coordinates_to_indexes,
+    convert_coordinates_to_indexes_for_list_datasets,
     join_list_dataset_to_batch_dataset,
 )
 
@@ -28,7 +30,7 @@ def gsp_fake(
 ):
     """Create fake data"""
     # make batch of arrays
-    xr_arrays = [
+    xr_datasets = [
         create_gsp_pv_dataset(
             seq_length=seq_length_30,
             freq="30T",
@@ -37,8 +39,11 @@
         for _ in range(batch_size)
     ]
 
+    # change dimensions to dimension indexes
+    xr_datasets = convert_coordinates_to_indexes_for_list_datasets(xr_datasets)
+
     # make dataset
-    xr_dataset = join_list_dataset_to_batch_dataset(xr_arrays)
+    xr_dataset = join_list_dataset_to_batch_dataset(xr_datasets)
 
     return GSP(xr_dataset)
 
@@ -47,6 +52,9 @@ def metadata_fake(batch_size):
     """Make a xr dataset"""
     xr_arrays = [create_metadata_dataset() for _ in range(batch_size)]
 
+    # change to indexes
+    xr_arrays = [convert_coordinates_to_indexes(xr_array) for xr_array in xr_arrays]
+
     # make dataset
     xr_dataset = join_list_dataset_to_batch_dataset(xr_arrays)
 
@@ -81,7 +89,7 @@
 def pv_fake(batch_size, seq_length_5, n_pv_systems_per_batch):
     """Create fake data"""
     # make batch of arrays
-    xr_arrays = [
+    xr_datasets = [
         create_gsp_pv_dataset(
             seq_length=seq_length_5,
             freq="5T",
@@ -90,8 +98,11 @@ def pv_fake(batch_size, seq_length_5, n_pv_systems_per_batch):
         for _ in range(batch_size)
     ]
 
+    # change dimensions to dimension indexes
+    xr_datasets = convert_coordinates_to_indexes_for_list_datasets(xr_datasets)
+
     # make dataset
-    xr_dataset = join_list_dataset_to_batch_dataset(xr_arrays)
+    xr_dataset = join_list_dataset_to_batch_dataset(xr_datasets)
 
     return PV(xr_dataset)
 
@@ -150,6 +161,7 @@ def topographic_fake(batch_size, image_size_pixels):
                 x=np.sort(np.random.randn(image_size_pixels)),
                 y=np.sort(np.random.randn(image_size_pixels))[::-1].copy(),
             ),
+            name="data",
         )
         for _ in range(batch_size)
     ]
@@ -184,6 +196,7 @@ def create_image_array(
             )
         ),
         coords=coords,
+        name="data",
     )  # Fake data for testing!
     return image_data_array
 
@@ -197,7 +210,7 @@
    """Create gsp or pv fake dataset"""
    ALL_COORDS = {
        "time": pd.date_range("2021-01-01", freq=freq, periods=seq_length),
-        "id": np.random.randint(low=0, high=1000, size=number_of_systems),
+        "id": np.random.choice(range(1000), number_of_systems, replace=False),
    }
    coords = [(dim, ALL_COORDS[dim]) for dim in dims]
    data_array = xr.DataArray(
@@ -208,22 +221,20 @@
        coords=coords,
    )  # Fake data for testing!
 
-    data = convert_data_array_to_dataset(data_array)
+    data = data_array.to_dataset(name="data")
 
    x_coords = xr.DataArray(
-        data=np.sort(np.random.randn(number_of_systems)),
-        dims=["id_index"],
-        coords=dict(
-            id_index=range(number_of_systems),
+        data=np.sort(
+            np.random.choice(range(2 * number_of_systems), number_of_systems, replace=False)
        ),
+        dims=["id"],
    )
 
    y_coords = xr.DataArray(
-        data=np.sort(np.random.randn(number_of_systems)),
-        dims=["id_index"],
-        coords=dict(
-            id_index=range(number_of_systems),
+        data=np.sort(
+            np.random.choice(range(2 * number_of_systems), number_of_systems, replace=False)
        ),
+        dims=["id"],
    )
 
    data["x_coords"] = x_coords
@@ -265,13 +276,14 @@ def create_sun_dataset(
        coords=coords,
    )  # Fake data for testing!
 
-    data = convert_data_array_to_dataset(data_array)
-    sun = data.rename({"data": "elevation"})
-    sun["azimuth"] = data.data
+    sun = data_array.to_dataset(name="elevation")
+    sun["azimuth"] = sun.elevation
 
    sun.__setitem__("azimuth", sun.azimuth.clip(min=0, max=360))
    sun.__setitem__("elevation", sun.elevation.clip(min=-90, max=90))
 
+    sun = convert_coordinates_to_indexes(sun)
+
    return sun
 
 
@@ -282,11 +294,11 @@ def create_metadata_dataset() -> xr.Dataset:
        "data": pd.date_range("2021-01-01", freq="5T", periods=1) + pd.Timedelta("30T"),
    }
 
-    data = convert_data_array_to_dataset(xr.DataArray.from_dict(d))
+    data = (xr.DataArray.from_dict(d)).to_dataset(name="data")
 
    for v in ["x_meters_center", "y_meters_center", "object_at_center_label"]:
        d: dict = {"dims": ("t0_dt",), "data": [np.random.randint(0, 1000)]}
-        d: xr.Dataset = convert_data_array_to_dataset(xr.DataArray.from_dict(d)).rename({"data": v})
+        d: xr.Dataset = (xr.DataArray.from_dict(d)).to_dataset(name=v)
        data[v] = getattr(d, v)
 
    return data
@@ -307,11 +319,20 @@ def create_datetime_dataset(
        coords=coords,
    )  # Fake data
 
-    data = convert_data_array_to_dataset(data_array)
+    data = data_array.to_dataset()
 
    ds = data.rename({"data": "day_of_year_cos"})
    ds["day_of_year_sin"] = data.rename({"data": "day_of_year_sin"}).day_of_year_sin
    ds["hour_of_day_cos"] = data.rename({"data": "hour_of_day_cos"}).hour_of_day_cos
    ds["hour_of_day_sin"] = data.rename({"data": "hour_of_day_sin"}).hour_of_day_sin
 
    return data
+
+
+def join_list_data_array_to_batch_dataset(data_arrays: List[xr.DataArray]) -> xr.Dataset:
+    """Join a list of xr.DataArrays into an xr.Dataset by concatenating on the example dim."""
+    datasets = [
+        convert_coordinates_to_indexes(data_arrays[i].to_dataset()) for i in range(len(data_arrays))
+    ]
+
+    return join_list_dataset_to_batch_dataset(datasets)
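
An easy-to-miss detail in create_gsp_pv_dataset: the fake "id" values switch from np.random.randint to np.random.choice(..., replace=False). Presumably this is because randint samples with replacement and can therefore produce duplicate ids, and duplicate coordinate labels are ambiguous once the id coordinate is used as an index. A quick illustration:

    import numpy as np

    number_of_systems = 10

    # randint samples with replacement, so repeated ids are possible:
    maybe_duplicates = np.random.randint(low=0, high=1000, size=number_of_systems)

    # choice with replace=False guarantees distinct ids:
    unique_ids = np.random.choice(range(1000), number_of_systems, replace=False)
    assert len(set(unique_ids)) == number_of_systems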

nowcasting_dataset/data_sources/gsp/gsp_data_source.py

Lines changed: 12 additions & 19 deletions

@@ -18,7 +18,6 @@
 from nowcasting_dataset.data_sources.data_source import ImageDataSource
 from nowcasting_dataset.data_sources.gsp.eso import get_gsp_metadata_from_eso
 from nowcasting_dataset.data_sources.gsp.gsp_model import GSP
-from nowcasting_dataset.dataset.xr_utils import convert_data_array_to_dataset
 from nowcasting_dataset.geospatial import lat_lon_to_osgb
 from nowcasting_dataset.square import get_bounding_box_mask
 from nowcasting_dataset.utils import scale_to_0_to_1
@@ -73,6 +72,10 @@ def sample_period_minutes(self) -> int:
         """Override the default sample minutes"""
         return 30
 
+    def get_data_model_for_batch(self):
+        """Get the model that is used in the batch"""
+        return GSP
+
     def load(self):
         """
         Load the meta data and load the GSP power data
@@ -153,7 +156,7 @@ def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], L
 
     def get_example(
         self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number
-    ) -> GSP:
+    ) -> xr.Dataset:
        """
         Get data example from one time point (t0_dt) and for x and y coords.
 
@@ -201,41 +204,31 @@
         da = xr.DataArray(
             data=selected_gsp_power.values,
             dims=["time", "id"],
-            coords=dict(
-                id=all_gsp_ids.values.astype(int),
-                time=selected_gsp_power.index.values,
-            ),
         )
 
         # convert to dataset
-        gsp = convert_data_array_to_dataset(da)
+        gsp = da.to_dataset(name="data")
 
         # add gsp x coords
         gsp_x_coords = xr.DataArray(
             data=gsp_x_coords.values,
-            dims=["id_index"],
-            coords=dict(
-                id_index=range(len(all_gsp_ids.values)),
-            ),
+            dims=["id"],
         )
 
         gsp_y_coords = xr.DataArray(
             data=gsp_y_coords.values,
-            dims=["id_index"],
-            coords=dict(
-                id_index=range(len(all_gsp_ids.values)),
-            ),
+            dims=["id"],
         )
         gsp["x_coords"] = gsp_x_coords
         gsp["y_coords"] = gsp_y_coords
 
         # pad out so that there are always 32 gsp, fill with 0
-        pad_n = self.n_gsp_per_example - len(gsp.id_index)
-        gsp = gsp.pad(id_index=(0, pad_n), data=((0, 0), (0, pad_n)), constant_values=0)
+        pad_n = self.n_gsp_per_example - len(gsp.id)
+        gsp = gsp.pad(id=(0, pad_n), data=((0, 0), (0, pad_n)), constant_values=0)
 
-        gsp.__setitem__("id_index", range(self.n_gsp_per_example))
+        gsp.__setitem__("id", range(self.n_gsp_per_example))
 
-        return GSP(gsp)
+        return gsp
 
     def _get_central_gsp_id(
         self,
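
The padding step at the end of get_example now works on the id dimension directly (previously id_index). Below is a minimal sketch of what that step does, with a hard-coded n_gsp_per_example standing in for the data source's configured value:

    import numpy as np
    import xarray as xr

    n_gsp_per_example = 32  # illustrative; the real value comes from the data source config

    # A fake example whose power data covers only 5 GSPs.
    gsp = xr.Dataset({"data": (("time", "id"), np.random.rand(4, 5))})

    # Pad "id" with zeros so every example has the same shape, then reset
    # the id coordinate to a plain 0..n-1 range, as the diff does.
    pad_n = n_gsp_per_example - gsp.sizes["id"]
    gsp = gsp.pad(id=(0, pad_n), constant_values=0)
    gsp["id"] = range(n_gsp_per_example)

    assert gsp.sizes["id"] == n_gsp_per_example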
