Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 2d7d977

Browse files
Merge pull request #216 from openclimatefix/issue/202-modality-files
Issue/202 modality files
2 parents e5b5d88 + 0046e21 commit 2d7d977

File tree

18 files changed

+153
-60
lines changed

18 files changed

+153
-60
lines changed

nowcasting_dataset/data_sources/data_source.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ def get_batch(
121121
output.to_numpy()
122122
examples.append(output)
123123

124+
# could add option here, to save each data source using
125+
# 1. # DataSourceOutput.to_xr_dataset() to make it a dataset
126+
# 2. DataSourceOutput.save_netcdf(), save to netcdf
124127
return DataSourceOutput.create_batch_from_examples(examples)
125128

126129
def datetime_index(self) -> pd.DatetimeIndex:

nowcasting_dataset/data_sources/datasource_output.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
""" General Data Source output pydantic class. """
22
from __future__ import annotations
3+
import os
4+
from nowcasting_dataset.filesystem.utils import make_folder
5+
from nowcasting_dataset.utils import get_netcdf_filename
6+
7+
from pathlib import Path
38
from pydantic import BaseModel, Field
49
import pandas as pd
510
import xarray as xr
@@ -32,6 +37,10 @@ class Config:
3237
"then this item stores one data item i.e Example",
3338
)
3439

40+
def get_name(self) -> str:
41+
""" Get the name of the class """
42+
return self.__class__.__name__.lower()
43+
3544
def to_numpy(self):
3645
"""Change to numpy"""
3746
for k, v in self.dict().items():
@@ -93,6 +102,31 @@ def get_datetime_index(self):
93102
""" Datetime index for the data """
94103
pass
95104

105+
def save_netcdf(self, batch_i: int, path: Path, xr_dataset: xr.Dataset):
106+
"""
107+
Save batch to netcdf file
108+
109+
Args:
110+
batch_i: the batch id, used to make the filename
111+
path: the path where it will be saved. This can be local or in the cloud.
112+
xr_dataset: xr dataset that has batch information in it
113+
"""
114+
filename = get_netcdf_filename(batch_i)
115+
116+
name = self.get_name()
117+
118+
# make folder
119+
folder = os.path.join(path, name)
120+
if batch_i == 0:
121+
# only need to make the folder once, or check that the folder is there once
122+
make_folder(path=folder)
123+
124+
# make file
125+
local_filename = os.path.join(folder, filename)
126+
127+
encoding = {name: {"compression": "lzf"} for name in xr_dataset.data_vars}
128+
xr_dataset.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding)
129+
96130
def select_time_period(
97131
self,
98132
keys: List[str],

nowcasting_dataset/dataset/batch.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
""" batch functions """
22
import logging
3+
import os
34
from pathlib import Path
4-
from typing import List, Optional, Union
5+
from typing import List, Optional, Union, Dict
56

67
import xarray as xr
78
from pydantic import BaseModel, Field
89

10+
from nowcasting_dataset.filesystem.utils import make_folder
11+
912
from nowcasting_dataset.config.model import Configuration
1013

1114
from nowcasting_dataset.data_sources.datetime.datetime_model import Datetime
@@ -72,21 +75,23 @@ class Batch(Example):
7275
"then this item stores one data item",
7376
)
7477

75-
def batch_to_dataset(self) -> xr.Dataset:
78+
def batch_to_dict_dataset(self) -> Dict[str, xr.Dataset]:
7679
"""Change batch to xr.Dataset so it can be saved and compressed"""
77-
return batch_to_dataset(batch=self)
80+
return batch_to_dict_dataset(batch=self)
7881

7982
@staticmethod
80-
def load_batch_from_dataset(xr_dataset: xr.Dataset):
81-
"""Change xr.Datatset to Batch object"""
83+
def load_batch_from_dict_dataset(xr_dataset: Dict[str, xr.Dataset]):
84+
"""Change dictionary of xr.Dataset to Batch object"""
8285
# get a list of data sources
8386
data_sources_names = Example.__fields__.keys()
8487

8588
# collect data sources
8689
data_sources_dict = {}
8790
for data_source_name in data_sources_names:
8891
cls = Example.__fields__[data_source_name].type_
89-
data_sources_dict[data_source_name] = cls.from_xr_dataset(xr_dataset=xr_dataset)
92+
data_sources_dict[data_source_name] = cls.from_xr_dataset(
93+
xr_dataset=xr_dataset[data_source_name]
94+
)
9095

9196
data_sources_dict["batch_size"] = data_sources_dict["metadata"].batch_size
9297

@@ -168,43 +173,57 @@ def save_netcdf(self, batch_i: int, path: Path):
168173
path: the path where it will be saved. This can be local or in the cloud.
169174
170175
"""
171-
batch_xr = self.batch_to_dataset()
176+
batch_xr = self.batch_to_dict_dataset()
172177

173-
encoding = {name: {"compression": "lzf"} for name in batch_xr.data_vars}
174-
filename = get_netcdf_filename(batch_i)
175-
local_filename = path / filename
176-
batch_xr.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding)
178+
for data_source in self.data_sources:
179+
xr_dataset = batch_xr[data_source.get_name()]
180+
data_source.save_netcdf(batch_i=batch_i, path=path, xr_dataset=xr_dataset)
177181

178182
@staticmethod
179-
def load_netcdf(local_netcdf_filename: Path):
183+
def load_netcdf(local_netcdf_path: Union[Path, str], batch_idx: int):
180184
"""Load batch from netcdf file"""
181-
netcdf_batch = xr.load_dataset(local_netcdf_filename)
185+
data_sources_names = Example.__fields__.keys()
182186

183-
return Batch.load_batch_from_dataset(netcdf_batch)
187+
# collect data sources
188+
batch_dict = {}
189+
for data_source_name in data_sources_names:
184190

191+
local_netcdf_filename = os.path.join(
192+
local_netcdf_path, data_source_name, f"{batch_idx}.nc"
193+
)
194+
xr_dataset = xr.load_dataset(local_netcdf_filename)
185195

186-
def batch_to_dataset(batch: Batch) -> xr.Dataset:
187-
"""Concat all the individual fields in an Example into a single Dataset.
196+
batch_dict[data_source_name] = xr_dataset
197+
198+
return Batch.load_batch_from_dict_dataset(batch_dict)
199+
200+
201+
def batch_to_dict_dataset(batch: Batch) -> Dict[str, xr.Dataset]:
202+
"""Concat all the individual fields in an Example into a dictionary of Datasets.
188203
189204
Args:
190205
batch: List of Example objects, which together constitute a single batch.
191206
"""
192-
datasets = []
207+
individual_datasets = {}
208+
split_batch = batch.split()
209+
210+
# loop over each data source
211+
for data_source in split_batch[0].data_sources:
193212

194-
# loop over each item in the batch
195-
for i, example in enumerate(batch.split()):
213+
datasets = []
214+
name = data_source.get_name()
196215

197-
individual_datasets = []
216+
# loop over each item in the batch
217+
for i, example in enumerate(split_batch):
198218

199-
for data_source in example.data_sources:
200219
if data_source is not None:
201-
individual_datasets.append(data_source.to_xr_dataset(i))
220+
datasets.append(getattr(split_batch[i], name).to_xr_dataset(i))
202221

203222
# Merge
204-
merged_ds = xr.merge(individual_datasets)
205-
datasets.append(merged_ds)
223+
merged_ds = xr.concat(datasets, dim="example")
224+
individual_datasets[name] = merged_ds
206225

207-
return xr.concat(datasets, dim="example")
226+
return individual_datasets
208227

209228

210229
def write_batch_locally(batch: Union[Batch, dict], batch_i: int, path: Path):
@@ -219,8 +238,4 @@ def write_batch_locally(batch: Union[Batch, dict], batch_i: int, path: Path):
219238
if type(batch):
220239
batch = Batch(**batch)
221240

222-
dataset = batch.batch_to_dataset()
223-
encoding = {name: {"compression": "lzf"} for name in dataset.data_vars}
224-
filename = get_netcdf_filename(batch_i)
225-
local_filename = path / filename
226-
dataset.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding)
241+
batch.save_netcdf(batch_i=batch_i, path=path)

nowcasting_dataset/dataset/datasets.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from nowcasting_dataset import data_sources
1717
from nowcasting_dataset import utils as nd_utils
18-
from nowcasting_dataset.filesystem.utils import download_to_local
18+
from nowcasting_dataset.filesystem.utils import download_to_local, delete_all_files_in_temp_path
1919
from nowcasting_dataset.config.model import Configuration
2020
from nowcasting_dataset.consts import (
2121
GSP_YIELD,
@@ -185,21 +185,24 @@ def __getitem__(self, batch_idx: int) -> Batch:
185185
"batch_idx must be in the range" f" [0, {self.n_batches}), not {batch_idx}!"
186186
)
187187
netcdf_filename = nd_utils.get_netcdf_filename(batch_idx)
188-
remote_netcdf_filename = os.path.join(self.src_path, netcdf_filename)
189-
local_netcdf_filename = os.path.join(self.tmp_path, netcdf_filename)
188+
# remote_netcdf_folder = os.path.join(self.src_path, netcdf_filename)
189+
# local_netcdf_filename = os.path.join(self.tmp_path, netcdf_filename)
190190

191191
if self.cloud in ["gcp", "aws"]:
192+
# TODO check this works for multiple files
192193
download_to_local(
193-
remote_filename=remote_netcdf_filename,
194-
local_filename=local_netcdf_filename,
194+
remote_filename=self.src_path,
195+
local_filename=self.tmp_path,
195196
)
197+
local_netcdf_folder = self.tmp_path
196198
else:
197-
local_netcdf_filename = remote_netcdf_filename
199+
local_netcdf_folder = self.src_path
198200

199-
batch = Batch.load_netcdf(local_netcdf_filename)
201+
batch = Batch.load_netcdf(local_netcdf_folder, batch_idx=batch_idx)
200202
# netcdf_batch = xr.load_dataset(local_netcdf_filename)
201203
if self.cloud != "local":
202-
os.remove(local_netcdf_filename)
204+
# remove files in a folder, but not the folder itself
205+
delete_all_files_in_temp_path(self.src_path)
203206

204207
# batch = example.xr_to_example(batch_xr=netcdf_batch, required_keys=self.required_keys)
205208

nowcasting_dataset/filesystem/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,9 @@ def upload_one_file(
127127
"""
128128
filesystem = fsspec.open(remote_filename).fs
129129
filesystem.put(local_filename, remote_filename)
130+
131+
132+
def make_folder(path: Union[str, Path]):
133+
""" Make folder """
134+
filesystem = fsspec.open(path).fs
135+
filesystem.mkdir(path)

scripts/generate_data_for_tests/get_test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,4 +148,4 @@
148148
c.process.sat_channels = c.process.sat_channels[0:1]
149149

150150
f = Batch.fake(configuration=c)
151-
f.save_netcdf(batch_i=0, path=Path(f"{local_path}/tests/data"))
151+
f.save_netcdf(batch_i=0, path=Path(f"{local_path}/tests/data/batch"))

tests/data/batch/datetime/0.nc

20.9 KB
Binary file not shown.

tests/data/batch/gsp/0.nc

27 KB
Binary file not shown.

tests/data/batch/metadata/0.nc

17.3 KB
Binary file not shown.

tests/data/batch/nwp/0.nc

2.3 MB
Binary file not shown.

0 commit comments

Comments
 (0)