11""" batch functions """
22import logging
3+ import os
34from pathlib import Path
4- from typing import List , Optional , Union
5+ from typing import List , Optional , Union , Dict
56
67import xarray as xr
78from pydantic import BaseModel , Field
89
10+ from nowcasting_dataset .filesystem .utils import make_folder
11+
912from nowcasting_dataset .config .model import Configuration
1013
1114from nowcasting_dataset .data_sources .datetime .datetime_model import Datetime
@@ -72,21 +75,23 @@ class Batch(Example):
7275 "then this item stores one data item" ,
7376 )
7477
75- def batch_to_dataset (self ) -> xr .Dataset :
78+ def batch_to_dict_dataset (self ) -> Dict [ str , xr .Dataset ] :
7679 """Change batch to xr.Dataset so it can be saved and compressed"""
77- return batch_to_dataset (batch = self )
80+ return batch_to_dict_dataset (batch = self )
7881
7982 @staticmethod
80- def load_batch_from_dataset (xr_dataset : xr .Dataset ):
81- """Change xr.Datatset to Batch object"""
83+ def load_batch_from_dict_dataset (xr_dataset : Dict [ str , xr .Dataset ] ):
84+ """Change dictionary of xr.Datatset to Batch object"""
8285 # get a list of data sources
8386 data_sources_names = Example .__fields__ .keys ()
8487
8588 # collect data sources
8689 data_sources_dict = {}
8790 for data_source_name in data_sources_names :
8891 cls = Example .__fields__ [data_source_name ].type_
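+            # pydantic's `.type_` gives the model class behind this field; each
+            # data source class rebuilds itself from its own xr.Dataset via
+            # `from_xr_dataset` (used just below)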
-            data_sources_dict[data_source_name] = cls.from_xr_dataset(xr_dataset=xr_dataset)
+            data_sources_dict[data_source_name] = cls.from_xr_dataset(
+                xr_dataset=xr_dataset[data_source_name]
+            )
 
         data_sources_dict["batch_size"] = data_sources_dict["metadata"].batch_size
 
@@ -168,43 +173,57 @@ def save_netcdf(self, batch_i: int, path: Path):
             path: the path where it will be saved. This can be local or in the cloud.
 
         """
-        batch_xr = self.batch_to_dataset()
+        batch_xr = self.batch_to_dict_dataset()
 
-        encoding = {name: {"compression": "lzf"} for name in batch_xr.data_vars}
-        filename = get_netcdf_filename(batch_i)
-        local_filename = path / filename
-        batch_xr.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding)
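+        # one netcdf file is written per data source; the layout (see
+        # load_netcdf below) is <path>/<data_source_name>/<batch_i>.nc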
+        for data_source in self.data_sources:
+            xr_dataset = batch_xr[data_source.get_name()]
+            data_source.save_netcdf(batch_i=batch_i, path=path, xr_dataset=xr_dataset)
 
     @staticmethod
-    def load_netcdf(local_netcdf_filename: Path):
-        """Load batch from netcdf file"""
-        netcdf_batch = xr.load_dataset(local_netcdf_filename)
+    def load_netcdf(local_netcdf_path: Union[Path, str], batch_idx: int):
+        """Load batch from per-data-source netcdf files"""
+        data_sources_names = Example.__fields__.keys()
 
-        return Batch.load_batch_from_dataset(netcdf_batch)
+        # collect data sources
+        batch_dict = {}
+        for data_source_name in data_sources_names:
 
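+            # one file per data source, mirroring save_netcdf:
+            # <local_netcdf_path>/<data_source_name>/<batch_idx>.nc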
+            local_netcdf_filename = os.path.join(
+                local_netcdf_path, data_source_name, f"{batch_idx}.nc"
+            )
+            xr_dataset = xr.load_dataset(local_netcdf_filename)
 
-def batch_to_dataset(batch: Batch) -> xr.Dataset:
-    """Concat all the individual fields in an Example into a single Dataset.
+            batch_dict[data_source_name] = xr_dataset
+
+        return Batch.load_batch_from_dict_dataset(batch_dict)
+
+
+def batch_to_dict_dataset(batch: Batch) -> Dict[str, xr.Dataset]:
+    """Concat all the individual fields in an Example into a dictionary of Datasets.
 
     Args:
-        batch: List of Example objects, which together constitute a single batch.
+        batch: Batch object whose fields together constitute a single batch.
     """
-    datasets = []
+    individual_datasets = {}
+    split_batch = batch.split()
+
+    # loop over each data source
+    for data_source in split_batch[0].data_sources:
 
-    # loop over each item in the batch
-    for i, example in enumerate(batch.split()):
+        datasets = []
+        name = data_source.get_name()
 
-        individual_datasets = []
+        # loop over each item in the batch
+        for i, example in enumerate(split_batch):
 
-        for data_source in example.data_sources:
             if data_source is not None:
-                individual_datasets.append(data_source.to_xr_dataset(i))
+                datasets.append(getattr(split_batch[i], name).to_xr_dataset(i))
 
-        # Merge
-        merged_ds = xr.merge(individual_datasets)
-        datasets.append(merged_ds)
+        # concat this data source's examples along a new "example" dimension
+        merged_ds = xr.concat(datasets, dim="example")
+        individual_datasets[name] = merged_ds
 
-    return xr.concat(datasets, dim="example")
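+    # the returned dict maps data source name -> batched Dataset, e.g.
+    # {"metadata": <xr.Dataset>, ...} (keys depend on the fields of Example)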
+    return individual_datasets
 
 
 def write_batch_locally(batch: Union[Batch, dict], batch_i: int, path: Path):
@@ -219,8 +238,4 @@ def write_batch_locally(batch: Union[Batch, dict], batch_i: int, path: Path):
     if isinstance(batch, dict):
         batch = Batch(**batch)
 
-    dataset = batch.batch_to_dataset()
-    encoding = {name: {"compression": "lzf"} for name in dataset.data_vars}
-    filename = get_netcdf_filename(batch_i)
-    local_filename = path / filename
-    dataset.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding)
+    batch.save_netcdf(batch_i=batch_i, path=path)
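
A minimal round-trip sketch of the new per-data-source layout (illustrative
only: assumes a populated `batch` and a writable local `path`):

    batch.save_netcdf(batch_i=0, path=Path("train"))
    # writes train/<data_source_name>/0.nc, one file per data source
    batch_2 = Batch.load_netcdf(local_netcdf_path="train", batch_idx=0)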