This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 1cfd82b

Move BatchML/Datasets/Datamodule to Nowcasting-dataloader (#243)
* Move models to nowcasting-dataloader #major
* Move tests
* Fix import error
* Move SatelliteML test
* Fix import error
* Remove batch test data
* Move test data generation script
* Move test generation script back
* Update get_test_data.py
1 parent 5e482ce commit 1cfd82b
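
For downstream code, the visible effect of this move is an import-path change. A rough before/after sketch follows; the receiving package's module layout is an assumption here, not something this diff shows:

# Before this commit, the ML batch models lived in nowcasting_dataset:
from nowcasting_dataset.data_sources.datasource_output import DataSourceOutputML

# After the move, the equivalent import would come from nowcasting-dataloader.
# NOTE: hypothetical path; the real layout of that package is not part of this diff.
# from nowcasting_dataloader.data_sources.datasource_output import DataSourceOutputML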

File tree

28 files changed (+22 −1564 lines)

.github/workflows/python-app.yml

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@

 name: Python Tests

-on: [push, pull_request]
+on: [push]

 jobs:
   build:

nowcasting_dataset/data_sources/datasource_output.py

Lines changed: 0 additions & 62 deletions
@@ -53,69 +53,7 @@ def save_netcdf(self, batch_i: int, path: Path):
         self.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding)


-class DataSourceOutputML(BaseModel):
-    """General Data Source output pydantic class.
-
-    Data source output classes should inherit from this class
-    """
-
-    class Config:
-        """Allowed classes e.g. tensor.Tensor"""
-
-        # TODO maybe there is a better way to do this
-        arbitrary_types_allowed = True
-
-    batch_size: int = Field(
-        0,
-        ge=0,
-        description="The size of this batch. If the batch size is 0, "
-        "then this item stores one data item, i.e. an Example",
-    )
-
-    def get_name(self) -> str:
-        """Get the name of the class"""
-        return self.__class__.__name__.lower()
-
-    def get_datetime_index(self):
-        """Datetime index for the data"""
-        pass
-
-
 def pad_nans(array, pad_width) -> np.ndarray:
     """Pad nans with nans"""
     array = array.astype(np.float32)
     return np.pad(array, pad_width, constant_values=np.NaN)
-
-
-def pad_data(
-    data: DataSourceOutputML,
-    pad_size: int,
-    one_dimensional_arrays: List[str],
-    two_dimensional_arrays: List[str],
-):
-    """
-    Pad (if necessary) so returned arrays are always of size
-
-    data has two types of arrays in it: one-dimensional arrays and two-dimensional arrays.
-    The one-dimensional arrays are padded in that dimension;
-    the two-dimensional arrays are padded in the second dimension.
-
-    Note that the object is edited in place, so nothing is returned.
-
-    Args:
-        data: typed dictionary of data objects
-        pad_size: the amount that should be padded
-        one_dimensional_arrays: list of data items that should be padded in one dimension
-        two_dimensional_arrays: list of data items that should be padded in the second dimension
-
-    """
-    # Pad (if necessary) so returned arrays are always of size
-    pad_shape = (0, pad_size)  # (before, after)
-
-    for name in one_dimensional_arrays:
-        data.__setattr__(name, pad_nans(data.__getattribute__(name), pad_width=pad_shape))
-
-    for variable in two_dimensional_arrays:
-        data.__setattr__(
-            variable, pad_nans(data.__getattribute__(variable), pad_width=((0, 0), pad_shape))
-        )  # (axis0, axis1)
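
Of the helpers above, pad_nans stays behind while DataSourceOutputML and pad_data move out. A minimal standalone sketch of what the retained function does (the repo spells the constant np.NaN; np.nan is used here so the snippet also runs on newer NumPy):

import numpy as np

def pad_nans(array, pad_width) -> np.ndarray:
    """Pad the array with NaNs (mirrors the retained helper in datasource_output.py)."""
    array = array.astype(np.float32)
    return np.pad(array, pad_width, constant_values=np.nan)

x = np.array([1.0, 2.0, 3.0])
print(pad_nans(x, (0, 2)))  # [ 1.  2.  3. nan nan], from pad_shape = (before, after)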
Lines changed: 0 additions & 99 deletions
@@ -1,11 +1,5 @@
 """ Model for output of datetime data """
-import numpy as np
-import xarray as xr
-from pydantic import validator
-
-from nowcasting_dataset.consts import Array, DATETIME_FEATURE_NAMES
 from nowcasting_dataset.data_sources.datasource_output import (
-    DataSourceOutputML,
     DataSourceOutput,
 )

@@ -21,96 +15,3 @@ class Datetime(DataSourceOutput):

     # todo add validation here - https://github.com/openclimatefix/nowcasting_dataset/issues/233
     _expected_dimensions = ("time",)
-
-
-class DatetimeML(DataSourceOutputML):
-    """ Model for output of datetime data """
-
-    hour_of_day_sin: Array  #: Shape: [batch_size,] seq_length
-    hour_of_day_cos: Array
-    day_of_year_sin: Array
-    day_of_year_cos: Array
-    datetime_index: Array
-
-    @property
-    def sequence_length(self):
-        """The sequence length of the datetime data"""
-        return self.hour_of_day_sin.shape[-1]
-
-    @validator("hour_of_day_cos")
-    def v_hour_of_day_cos(cls, v, values):
-        """ Validate 'hour_of_day_cos' """
-        assert v.shape[-1] == values["hour_of_day_sin"].shape[-1]
-        return v
-
-    @validator("day_of_year_sin")
-    def v_day_of_year_sin(cls, v, values):
-        """ Validate 'day_of_year_sin' """
-        assert v.shape[-1] == values["hour_of_day_sin"].shape[-1]
-        return v
-
-    @validator("day_of_year_cos")
-    def v_day_of_year_cos(cls, v, values):
-        """ Validate 'day_of_year_cos' """
-        assert v.shape[-1] == values["hour_of_day_sin"].shape[-1]
-        return v
-
-    @staticmethod
-    def fake(batch_size, seq_length_5):
-        """ Make a fake Datetime object """
-        return DatetimeML(
-            batch_size=batch_size,
-            hour_of_day_sin=np.random.randn(
-                batch_size,
-                seq_length_5,
-            ),
-            hour_of_day_cos=np.random.randn(
-                batch_size,
-                seq_length_5,
-            ),
-            day_of_year_sin=np.random.randn(
-                batch_size,
-                seq_length_5,
-            ),
-            day_of_year_cos=np.random.randn(
-                batch_size,
-                seq_length_5,
-            ),
-            datetime_index=np.sort(np.random.randn(batch_size, seq_length_5))[:, ::-1].copy(),
-            # copy is needed as torch does not support negative strides
-        )
-
-    def to_xr_dataset(self, _):
-        """ Make a xr dataset """
-        individual_datasets = []
-        for name in DATETIME_FEATURE_NAMES:
-
-            var = self.__getattribute__(name)
-
-            data = xr.DataArray(
-                var,
-                dims=["time"],
-                coords={"time": self.datetime_index},
-                name=name,
-            )
-
-            ds = data.to_dataset()
-            ds = coord_to_range(ds, "time", prefix=None)
-            individual_datasets.append(ds)
-
-        return xr.merge(individual_datasets)
-
-    @staticmethod
-    def from_xr_dataset(xr_dataset):
-        """ Change xr dataset to model. If data does not exist, then return None """
-        if "hour_of_day_sin" in xr_dataset.keys():
-            return DatetimeML(
-                batch_size=xr_dataset["hour_of_day_sin"].shape[0],
-                hour_of_day_sin=xr_dataset["hour_of_day_sin"],
-                hour_of_day_cos=xr_dataset["hour_of_day_cos"],
-                day_of_year_sin=xr_dataset["day_of_year_sin"],
-                day_of_year_cos=xr_dataset["day_of_year_cos"],
-                datetime_index=xr_dataset["hour_of_day_sin"].time,
-            )
-        else:
-            return None
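
The removed DatetimeML class carried sin/cos-encoded time features (hour_of_day_sin, day_of_year_cos, and friends). A sketch of how such cyclical encodings are commonly computed, as an illustration rather than the nowcasting-dataloader implementation:

import numpy as np
import pandas as pd

times = pd.date_range("2021-06-01", periods=12, freq="5min")  # 5-minutely steps
hour = times.hour + times.minute / 60

# Map hour of day onto the unit circle, so 23:55 and 00:05 come out close together.
hour_of_day_sin = np.sin(2 * np.pi * hour / 24)
hour_of_day_cos = np.cos(2 * np.pi * hour / 24)

# Same idea for day of year.
day_of_year = times.dayofyear
day_of_year_sin = np.sin(2 * np.pi * day_of_year / 365)
day_of_year_cos = np.cos(2 * np.pi * day_of_year / 365)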
Lines changed: 0 additions & 118 deletions
@@ -1,19 +1,7 @@
 """ Model for output of GSP data """
 import logging

-import numpy as np
-from pydantic import Field, validator
-
-from nowcasting_dataset.consts import Array
-from nowcasting_dataset.consts import (
-    GSP_ID,
-    GSP_YIELD,
-    GSP_X_COORDS,
-    GSP_Y_COORDS,
-    GSP_DATETIME_INDEX,
-)
 from nowcasting_dataset.data_sources.datasource_output import (
-    DataSourceOutputML,
     DataSourceOutput,
 )
 from nowcasting_dataset.time import make_random_time_vectors

@@ -28,109 +16,3 @@ class GSP(DataSourceOutput):
     _expected_dimensions = ("time", "id")

     # todo add validation here - https://github.com/openclimatefix/nowcasting_dataset/issues/233
-
-
-class GSPML(DataSourceOutputML):
-    """ Model for output of GSP data """
-
-    # Shape: [batch_size,] seq_length, width, height, channel
-    gsp_yield: Array = Field(
-        ...,
-        description="GSP yield from all GSPs in the region of interest (ROI). "
-        "Includes the central GSP system, which will always be the first entry. "
-        "shape = [batch_size, ] seq_length, n_gsp_per_example",
-    )
-
-    #: GSP identification.
-    #: shape = [batch_size, ] n_gsp_per_example
-    gsp_id: Array = Field(..., description="gsp id from NG")
-
-    gsp_datetime_index: Array = Field(
-        ...,
-        description="The datetime associated with the gsp data. shape = [batch_size, ] sequence length,",
-    )
-
-    gsp_x_coords: Array = Field(
-        ...,
-        description="The x (OSGB geo-spatial) coordinates of the gsp. "
-        "This is in fact the x centroid of the GSP region. "
-        "Shape: [batch_size,] n_gsp_per_example",
-    )
-    gsp_y_coords: Array = Field(
-        ...,
-        description="The y (OSGB geo-spatial) coordinates of the gsp. "
-        "This is in fact the y centroid of the GSP region. "
-        "Shape: [batch_size,] n_gsp_per_example",
-    )
-
-    @property
-    def number_of_gsp(self):
-        """The number of Grid Supply Points in this example"""
-        return self.gsp_yield.shape[-1]
-
-    @property
-    def sequence_length(self):
-        """The sequence length of the GSP PV power timeseries data"""
-        return self.gsp_yield.shape[-2]
-
-    @validator("gsp_yield")
-    def gsp_yield_shape(cls, v, values):
-        """ Validate 'gsp_yield' """
-        if values["batch_size"] > 0:
-            assert len(v.shape) == 3
-        else:
-            assert len(v.shape) == 2
-        return v
-
-    @validator("gsp_x_coords")
-    def x_coordinates_shape(cls, v, values):
-        """ Validate 'gsp_x_coords' """
-        assert v.shape[-1] == values["gsp_yield"].shape[-1]
-        return v
-
-    @validator("gsp_y_coords")
-    def y_coordinates_shape(cls, v, values):
-        """ Validate 'gsp_y_coords' """
-        assert v.shape[-1] == values["gsp_yield"].shape[-1]
-        return v
-
-    @staticmethod
-    def fake(batch_size, seq_length_30, n_gsp_per_batch, time_30=None):
-        """ Make a fake GSP object """
-        if time_30 is None:
-            _, _, time_30 = make_random_time_vectors(
-                batch_size=batch_size, seq_length_5_minutes=0, seq_length_30_minutes=seq_length_30
-            )
-
-        return GSPML(
-            batch_size=batch_size,
-            gsp_yield=np.random.randn(
-                batch_size,
-                seq_length_30,
-                n_gsp_per_batch,
-            ),
-            gsp_id=np.sort(np.random.randint(0, 340, (batch_size, n_gsp_per_batch))),
-            gsp_datetime_index=time_30,
-            gsp_x_coords=np.sort(np.random.randn(batch_size, n_gsp_per_batch)),
-            gsp_y_coords=np.sort(np.random.randn(batch_size, n_gsp_per_batch))[:, ::-1].copy(),
-        )
-        # copy is needed as torch does not support negative strides
-
-    def get_datetime_index(self) -> Array:
-        """ Get the datetime index of this data """
-        return self.gsp_datetime_index
-
-    @staticmethod
-    def from_xr_dataset(xr_dataset):
-        """ Change xr dataset to model. If data does not exist, then return None """
-        if "gsp_yield" in xr_dataset.keys():
-            return GSPML(
-                batch_size=xr_dataset["gsp_yield"].shape[0],
-                gsp_yield=xr_dataset[GSP_YIELD],
-                gsp_id=xr_dataset[GSP_ID],
-                gsp_datetime_index=xr_dataset[GSP_DATETIME_INDEX],
-                gsp_x_coords=xr_dataset[GSP_X_COORDS],
-                gsp_y_coords=xr_dataset[GSP_Y_COORDS],
-            )
-        else:
-            return None
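
The GSPML validators above follow a standard pydantic-v1 pattern: each coordinate array must match the trailing (per-GSP) dimension of gsp_yield. A self-contained sketch of that pattern, assuming pydantic v1 (where a validator receives a values dict of the already-validated fields):

import numpy as np
from pydantic import BaseModel, validator

class ShapeChecked(BaseModel):
    """Minimal stand-in for GSPML's shape validation -- illustrative only."""

    class Config:
        arbitrary_types_allowed = True  # allow np.ndarray fields

    gsp_yield: np.ndarray     # shape: (seq_length, n_gsp)
    gsp_x_coords: np.ndarray  # shape: (n_gsp,)

    @validator("gsp_x_coords")
    def x_matches_yield(cls, v, values):
        # values holds the fields validated so far, so gsp_yield is available here
        assert v.shape[-1] == values["gsp_yield"].shape[-1]
        return v

ShapeChecked(gsp_yield=np.zeros((4, 3)), gsp_x_coords=np.zeros(3))  # passes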
Lines changed: 0 additions & 53 deletions
@@ -1,13 +1,7 @@
 """ Model for output of general/metadata data, useful for a batch """
 from typing import Union

-import numpy as np
-import torch
-import xarray as xr
-from pydantic import Field
-
 from nowcasting_dataset.data_sources.datasource_output import (
-    DataSourceOutputML,
     DataSourceOutput,
 )

@@ -24,50 +18,3 @@ class Metadata(DataSourceOutput):
     _expected_dimensions = ("t0_dt",)

     # todo add validation here - https://github.com/openclimatefix/nowcasting_dataset/issues/233
-
-
-class MetadataML(DataSourceOutputML):
-    """Model for output of general/metadata data"""
-
-    # TODO add descriptions
-    t0_dt: Union[xr.DataArray, np.ndarray, torch.Tensor, int]  #: Shape: [batch_size,]
-    x_meters_center: Union[xr.DataArray, np.ndarray, torch.Tensor, int]
-    y_meters_center: Union[xr.DataArray, np.ndarray, torch.Tensor, int]
-    object_at_center_label: Union[xr.DataArray, np.ndarray, torch.Tensor, int] = Field(
-        ...,
-        description="What object is at the center of the batch data. "
-        "0: Nothing at the center, "
-        "1: GSP system, "
-        "2: PV system",
-    )
-
-    @staticmethod
-    def fake(batch_size, t0_dt=None):
-        """Make a fake MetadataML object"""
-        if t0_dt is None:
-            t0_dt, _, _ = make_random_time_vectors(
-                batch_size=batch_size, seq_length_5_minutes=0, seq_length_30_minutes=0
-            )
-
-        return MetadataML(
-            batch_size=batch_size,
-            t0_dt=t0_dt,
-            x_meters_center=np.random.randn(
-                batch_size,
-            ),
-            y_meters_center=np.random.randn(
-                batch_size,
-            ),
-            object_at_center_label=np.array([1] * batch_size),
-        )
-
-    @staticmethod
-    def from_xr_dataset(xr_dataset):
-        """Change xr dataset to model. If data does not exist, then return None"""
-        return MetadataML(
-            batch_size=xr_dataset.t0_dt.shape[0],
-            t0_dt=xr_dataset.t0_dt.values,
-            x_meters_center=xr_dataset.x_meters_center.values,
-            y_meters_center=xr_dataset.y_meters_center.values,
-            object_at_center_label=xr_dataset.object_at_center_label.values,
-        )
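
Each of the removed ML classes shares the fake() idiom: build random data with consistent shapes so tests need no real batches. A stripped-down sketch of the idiom; fake_metadata_batch is a hypothetical helper, not part of the moved code:

import numpy as np

def fake_metadata_batch(batch_size: int) -> dict:
    """Random metadata with the shapes MetadataML expects -- illustrative only."""
    return {
        "t0_dt": np.random.randn(batch_size),
        "x_meters_center": np.random.randn(batch_size),
        "y_meters_center": np.random.randn(batch_size),
        "object_at_center_label": np.array([1] * batch_size),  # 1 == GSP system
    }

batch = fake_metadata_batch(batch_size=8)
assert batch["x_meters_center"].shape == (8,)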
