Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit f54f914

Browse files
authored
Merge pull request #284 from openclimatefix/jack/allow-input-data-fields-to-be-None
Allow InputData fields to be None. Also remove some unused imports
2 parents 974308e + d28bc32 commit f54f914

File tree

8 files changed

+63
-113
lines changed

8 files changed

+63
-113
lines changed

nowcasting_dataset/config/model.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class PV(DataSourceMixin):
7979
"gs://solar-pv-nowcasting-data/PV/PVOutput.org/UK_PV_metadata.csv",
8080
description="The CSV file describing each PV system.",
8181
)
82-
n_gsp_per_example: int = Field(
82+
n_pv_systems_per_example: int = Field(
8383
DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE,
8484
description="The number of PV systems samples per example. "
8585
"If there are less in the ROI then the data is padded with zeros. ",
@@ -160,12 +160,12 @@ class InputData(BaseModel):
160160
Input data model.
161161
"""
162162

163-
pv: PV = PV()
164-
satellite: Satellite = Satellite()
165-
nwp: NWP = NWP()
166-
gsp: GSP = GSP()
167-
topographic: Topographic = Topographic()
168-
sun: Sun = Sun()
163+
pv: Optional[PV] = None
164+
satellite: Optional[Satellite] = None
165+
nwp: Optional[NWP] = None
166+
gsp: Optional[GSP] = None
167+
topographic: Optional[Topographic] = None
168+
sun: Optional[Sun] = None
169169

170170
default_forecast_minutes: int = Field(
171171
60,
@@ -194,8 +194,14 @@ def set_forecast_and_history_minutes(cls, values):
194194
then set them to the default values
195195
"""
196196

197-
for data_source_name in ["pv", "nwp", "satellite", "gsp", "topographic", "sun"]:
197+
ALL_DATA_SOURCE_NAMES = ("pv", "satellite", "nwp", "gsp", "topographic", "sun")
198+
enabled_data_sources = [
199+
data_source_name
200+
for data_source_name in ALL_DATA_SOURCE_NAMES
201+
if values[data_source_name] is not None
202+
]
198203

204+
for data_source_name in enabled_data_sources:
199205
if values[data_source_name].forecast_minutes is None:
200206
values[data_source_name].forecast_minutes = values["default_forecast_minutes"]
201207

@@ -204,6 +210,21 @@ def set_forecast_and_history_minutes(cls, values):
204210

205211
return values
206212

213+
@classmethod
214+
def set_all_to_defaults(cls):
215+
"""Returns an InputData instance with all fields set to their default values.
216+
217+
Used for unittests.
218+
"""
219+
return cls(
220+
pv=PV(),
221+
satellite=Satellite(),
222+
nwp=NWP(),
223+
gsp=GSP(),
224+
topographic=Topographic(),
225+
sun=Sun(),
226+
)
227+
207228

208229
class OutputData(BaseModel):
209230
"""Output data model"""
Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,5 @@
11
""" Model for output of general/metadata data, useful for a batch """
2-
from typing import Union
3-
4-
from nowcasting_dataset.data_sources.datasource_output import (
5-
DataSourceOutput,
6-
)
7-
8-
from nowcasting_dataset.time import make_random_time_vectors
9-
10-
11-
# seems to be a pandas dataseries
2+
from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
123

134

145
class Metadata(DataSourceOutput):
@@ -17,4 +8,4 @@ class Metadata(DataSourceOutput):
178
__slots__ = ()
189
_expected_dimensions = ("t0_dt",)
1910

20-
# todo add validation here - https://github.com/openclimatefix/nowcasting_dataset/issues/233
11+
# TODO: Add validation here - https://github.com/openclimatefix/nowcasting_dataset/issues/233
Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,18 @@
11
""" Model for output of PV data """
2-
import logging
3-
4-
import numpy as np
52
from xarray.ufuncs import isnan, isinf
6-
from pydantic import Field, validator
7-
8-
from nowcasting_dataset.consts import (
9-
Array,
10-
PV_YIELD,
11-
PV_DATETIME_INDEX,
12-
PV_SYSTEM_Y_COORDS,
13-
PV_SYSTEM_X_COORDS,
14-
PV_SYSTEM_ROW_NUMBER,
15-
PV_SYSTEM_ID,
16-
)
17-
from nowcasting_dataset.data_sources.datasource_output import (
18-
DataSourceOutput,
19-
)
20-
from nowcasting_dataset.time import make_random_time_vectors
21-
22-
logger = logging.getLogger(__name__)
3+
from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
234

245

256
class PV(DataSourceOutput):
267
""" Class to store PV data as a xr.Dataset with some validation """
278

28-
# Use to store xr.Dataset data
29-
309
__slots__ = ()
3110
_expected_dimensions = ("time", "id")
3211

3312
@classmethod
3413
def model_validation(cls, v):
35-
""" Check that all values are non NaNs """
36-
assert (~isnan(v.data)).all(), f"Some pv data values are NaNs"
37-
assert (~isinf(v.data)).all(), f"Some pv data values are Infinite"
38-
39-
assert (v.data >= 0).all(), f"Some pv data values are below 0"
40-
14+
""" Check that all values are not NaN, Infinite, or < 0."""
15+
assert (~isnan(v.data)).all(), "Some pv data values are NaNs"
16+
assert (~isinf(v.data)).all(), "Some pv data values are Infinite"
17+
assert (v.data >= 0).all(), "Some pv data values are below 0"
4118
return v
Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,19 @@
11
""" Model for output of satellite data """
22
from __future__ import annotations
3-
4-
import logging
5-
6-
import numpy as np
73
from xarray.ufuncs import isnan, isinf
8-
from pydantic import Field
9-
10-
from nowcasting_dataset.consts import Array
11-
from nowcasting_dataset.data_sources.datasource_output import (
12-
DataSourceOutput,
13-
)
14-
from nowcasting_dataset.time import make_random_time_vectors
15-
16-
logger = logging.getLogger(__name__)
4+
from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
175

186

197
class Satellite(DataSourceOutput):
208
""" Class to store satellite data as a xr.Dataset with some validation """
219

22-
# Use to store xr.Dataset data
23-
2410
__slots__ = ()
2511
_expected_dimensions = ("time", "x", "y", "channels")
2612

2713
@classmethod
2814
def model_validation(cls, v):
29-
""" Check that all values are non negative """
30-
assert (~isnan(v.data)).all(), f"Some satellite data values are NaNs"
31-
assert (~isinf(v.data)).all(), f"Some satellite data values are Infinite"
32-
assert (v.data != -1).all(), f"Some satellite data values are -1's"
15+
""" Check that all values are not NaN, Infinite, or -1."""
16+
assert (~isnan(v.data)).all(), "Some satellite data values are NaNs"
17+
assert (~isinf(v.data)).all(), "Some satellite data values are Infinite"
18+
assert (v.data != -1).all(), "Some satellite data values are -1's"
3319
return v

nowcasting_dataset/data_sources/sun/sun_model.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,22 @@
11
""" Model for Sun features """
2-
import logging
3-
4-
import numpy as np
52
from xarray.ufuncs import isnan, isinf
6-
from pydantic import Field, validator
7-
8-
from nowcasting_dataset.consts import Array, SUN_AZIMUTH_ANGLE, SUN_ELEVATION_ANGLE
9-
from nowcasting_dataset.data_sources.datasource_output import (
10-
DataSourceOutput,
11-
)
12-
from nowcasting_dataset.time import make_random_time_vectors
13-
14-
logger = logging.getLogger(__name__)
3+
from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
154

165

176
class Sun(DataSourceOutput):
187
""" Class to store Sun data as a xr.Dataset with some validation """
198

20-
# Use to store xr.Dataset data
219
__slots__ = ()
2210
_expected_dimensions = ("time",)
2311

2412
@classmethod
2513
def model_validation(cls, v):
2614
""" Check that all values are non NaNs """
27-
assert (~isnan(v.elevation)).all(), f"Some elevation data values are NaNs"
28-
assert (~isinf(v.elevation)).all(), f"Some elevation data values are Infinite"
15+
assert (~isnan(v.elevation)).all(), "Some elevation data values are NaNs"
16+
assert (~isinf(v.elevation)).all(), "Some elevation data values are Infinite"
2917

30-
assert (~isnan(v.azimuth)).all(), f"Some azimuth data values are NaNs"
31-
assert (~isinf(v.azimuth)).all(), f"Some azimuth data values are Infinite"
18+
assert (~isnan(v.azimuth)).all(), "Some azimuth data values are NaNs"
19+
assert (~isinf(v.azimuth)).all(), "Some azimuth data values are Infinite"
3220

3321
assert (0 <= v.azimuth).all(), f"Some azimuth data values are lower 0, {v.azimuth.min()}"
3422
assert (
Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,17 @@
11
""" Model for Topographic features """
2-
import logging
3-
4-
import numpy as np
52
from xarray.ufuncs import isnan, isinf
6-
from pydantic import Field, validator
7-
8-
from nowcasting_dataset.consts import Array
9-
from nowcasting_dataset.consts import TOPOGRAPHIC_DATA
103
from nowcasting_dataset.data_sources.datasource_output import DataSourceOutput
114

12-
logger = logging.getLogger(__name__)
13-
145

156
class Topographic(DataSourceOutput):
167
""" Class to store topographic data as a xr.Dataset with some validation """
178

18-
# Use to store xr.Dataset data
199
__slots__ = ()
2010
_expected_dimensions = ("x", "y")
2111

2212
@classmethod
2313
def model_validation(cls, v):
2414
""" Check that all values are non NaNs """
25-
assert (~isnan(v.data)).all(), f"Some topological data values are NaNs"
26-
assert (~isinf(v.data)).all(), f"Some topological data values are Infinite"
15+
assert (~isnan(v.data)).all(), "Some topological data values are NaNs"
16+
assert (~isinf(v.data)).all(), "Some topological data values are Infinite"
2717
return v

nowcasting_dataset/dataset/batch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ def data_sources(self):
8787
]
8888

8989
@staticmethod
90-
def fake(configuration: Configuration = Configuration()):
90+
def fake(configuration: Configuration):
9191
""" Make fake batch object """
9292
batch_size = configuration.process.batch_size
93-
satellite_image_size_pixels = configuration.input_data.satellite.satellite_image_size_pixels
94-
nwp_image_size_pixels = configuration.input_data.nwp.nwp_image_size_pixels
93+
satellite_image_size_pixels = 64
94+
nwp_image_size_pixels = 64
9595

9696
return Batch(
9797
batch_size=batch_size,

tests/dataset/test_batch.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,32 @@
1+
import pytest
12
import tempfile
23
import os
3-
from nowcasting_dataset.config.model import Configuration
4+
from nowcasting_dataset.config.model import Configuration, InputData
45
from nowcasting_dataset.dataset.batch import Batch
56

67

7-
def test_model():
8-
8+
@pytest.fixture
9+
def configuration():
910
con = Configuration()
11+
con.input_data = InputData.set_all_to_defaults()
1012
con.process.batch_size = 4
11-
12-
_ = Batch.fake(configuration=con)
13+
return con
1314

1415

15-
def test_model_save_to_netcdf():
16+
def test_model(configuration):
17+
_ = Batch.fake(configuration=configuration)
1618

17-
con = Configuration()
18-
con.process.batch_size = 4
1919

20+
def test_model_save_to_netcdf(configuration):
2021
with tempfile.TemporaryDirectory() as dirpath:
21-
Batch.fake(configuration=con).save_netcdf(path=dirpath, batch_i=0)
22+
Batch.fake(configuration=configuration).save_netcdf(path=dirpath, batch_i=0)
2223

2324
assert os.path.exists(f"{dirpath}/satellite/0.nc")
2425

2526

26-
def test_model_load_from_netcdf():
27-
28-
con = Configuration()
29-
con.process.batch_size = 4
30-
27+
def test_model_load_from_netcdf(configuration):
3128
with tempfile.TemporaryDirectory() as dirpath:
32-
Batch.fake(configuration=con).save_netcdf(path=dirpath, batch_i=0)
29+
Batch.fake(configuration=configuration).save_netcdf(path=dirpath, batch_i=0)
3330

3431
batch = Batch.load_netcdf(batch_idx=0, local_netcdf_path=dirpath)
3532

0 commit comments

Comments
 (0)