11""" General Data Source output pydantic class. """
22from __future__ import annotations
3- import os
4- from nowcasting_dataset .filesystem .utils import make_folder
5- from nowcasting_dataset .utils import get_netcdf_filename
63
4+ import logging
5+ import os
76from pathlib import Path
8- from pydantic import BaseModel , Field
9- import pandas as pd
10- import xarray as xr
7+ from typing import List
8+
119import numpy as np
12- from typing import List , Union
13- import logging
14- from datetime import datetime
10+ from pydantic import BaseModel , Field
1511
16- from nowcasting_dataset .utils import to_numpy
12+ from nowcasting_dataset .dataset .xr_utils import PydanticXArrayDataSet
13+ from nowcasting_dataset .filesystem .utils import make_folder
14+ from nowcasting_dataset .utils import get_netcdf_filename
1715
1816logger = logging .getLogger (__name__ )
1917
2018
class DataSourceOutput(PydanticXArrayDataSet):
    """General Data Source output pydantic class.

    Data source output classes should inherit from this class.

    Subclassing `PydanticXArrayDataSet` means instances behave as xarray
    Datasets while still getting pydantic-style validation.
    """

    # Required for xarray Dataset subclasses: prevents accidental creation of
    # a per-instance __dict__ (xarray accessors rely on __slots__).
    __slots__ = []

    def get_name(self) -> str:
        """Return the lowercased class name, e.g. used as a folder/key name."""
        return self.__class__.__name__.lower()
4330
44- def to_numpy (self ):
45- """Change to numpy"""
46- for k , v in self .dict ().items ():
47- self .__setattr__ (k , to_numpy (v ))
48-
49- def to_xr_data_array (self ):
50- """ Change to xr DataArray"""
51- raise NotImplementedError ()
52-
53- @staticmethod
54- def create_batch_from_examples (data ):
55- """
56- Join a list of data source items to a batch.
57-
58- Note that this only works for numpy objects, so objects are changed into numpy
59- """
60- _ = [d .to_numpy () for d in data ]
61-
62- # use the first item in the list, and then update each item
63- batch = data [0 ]
64- for k in batch .dict ().keys ():
65-
66- # set batch size to the list of the items
67- if k == "batch_size" :
68- batch .batch_size = len (data )
69- else :
70-
71- # get list of one variable from the list of data items.
72- one_variable_list = [d .__getattribute__ (k ) for d in data ]
73- batch .__setattr__ (k , np .stack (one_variable_list , axis = 0 ))
74-
75- return batch
76-
77- def split (self ) -> List [DataSourceOutput ]:
78- """
79- Split the datasource from a batch to a list of items
80-
81- Returns: List of single data source items
82- """
83- cls = self .__class__
84-
85- items = []
86- for batch_idx in range (self .batch_size ):
87- d = {k : v [batch_idx ] for k , v in self .dict ().items () if k != "batch_size" }
88- d ["batch_size" ] = 0
89- items .append (cls (** d ))
90-
91- return items
92-
93- def to_xr_dataset (self , ** kwargs ):
94- """ Make a xr dataset. Each data source needs to define this """
95- raise NotImplementedError
96-
97- def from_xr_dataset (self ):
98- """ Load from xr dataset. Each data source needs to define this """
99- raise NotImplementedError
100-
101- def get_datetime_index (self ):
102- """ Datetime index for the data """
103- pass
104-
105- def save_netcdf (self , batch_i : int , path : Path , xr_dataset : xr .Dataset ):
31+ def save_netcdf (self , batch_i : int , path : Path ):
10632 """
10733 Save batch to netcdf file
10834
10935 Args:
11036 batch_i: the batch id, used to make the filename
11137 path: the path where it will be saved. This can be local or in the cloud.
112- xr_dataset: xr dataset that has batch information in it
11338 """
11439 filename = get_netcdf_filename (batch_i )
11540
@@ -124,77 +49,46 @@ def save_netcdf(self, batch_i: int, path: Path, xr_dataset: xr.Dataset):
12449 # make file
12550 local_filename = os .path .join (folder , filename )
12651
127- encoding = {name : {"compression" : "lzf" } for name in xr_dataset .data_vars }
128- xr_dataset .to_netcdf (local_filename , engine = "h5netcdf" , mode = "w" , encoding = encoding )
129-
130- def select_time_period (
131- self ,
132- keys : List [str ],
133- history_minutes : int ,
134- forecast_minutes : int ,
135- t0_dt_of_first_example : Union [datetime , pd .Timestamp ],
136- ):
137- """
138- Selects a subset of data between the indicies of [start, end] for each key in keys
139-
140- Note that class is edited so nothing is returned.
141-
142- Args:
143- keys: Keys in batch to use
144- t0_dt_of_first_example: datetime of the current time (t0) in the first example of the batch
145- history_minutes: How many minutes of history to use
146- forecast_minutes: How many minutes of future data to use for forecasting
147-
148- """
149- logger .debug (
150- f"Taking a sub-selection of the batch data based on a history minutes of { history_minutes } "
151- f"and forecast minutes of { forecast_minutes } "
152- )
52+ encoding = {name : {"compression" : "lzf" } for name in self .data_vars }
53+ self .to_netcdf (local_filename , engine = "h5netcdf" , mode = "w" , encoding = encoding )
15354
154- start_time_of_first_batch = t0_dt_of_first_example - pd .to_timedelta (
155- f"{ history_minutes } minute 30 second"
156- )
157- end_time_of_first_example = t0_dt_of_first_example + pd .to_timedelta (
158- f"{ forecast_minutes } minute 30 second"
159- )
16055
161- logger . debug ( f"New start time for first batch is { start_time_of_first_batch } " )
162- logger . debug ( f"New end time for first batch is { end_time_of_first_example } " )
class DataSourceOutputML(BaseModel):
    """General Data Source output pydantic class for ML data.

    Data source output classes should inherit from this class.
    """

    class Config:
        """Allowed classes e.g. tensor.Tensor"""

        # TODO maybe there is a better way to do this
        arbitrary_types_allowed = True

    batch_size: int = Field(
        0,
        ge=0,
        description="The size of this batch. If the batch size is 0, "
        "then this item stores one data item i.e Example",
    )

    def get_name(self) -> str:
        """Return the lowercased class name, e.g. used as a folder/key name."""
        return self.__class__.__name__.lower()

    def get_datetime_index(self):
        """Datetime index for the data.

        Base implementation returns None; subclasses that carry time
        coordinates are expected to override this.
        """
        pass
18882
18983
def pad_nans(array, pad_width) -> np.ndarray:
    """Pad an array with NaNs.

    Args:
        array: array-like of numbers; cast to float32 so NaN is representable
            even for integer input.
        pad_width: pad specification passed straight to `np.pad`
            (e.g. an int, or a (before, after) tuple per axis).

    Returns:
        A float32 array padded with NaN values.
    """
    array = array.astype(np.float32)
    # np.nan (not the np.NaN alias, which was removed in NumPy 2.0)
    return np.pad(array, pad_width, constant_values=np.nan)
19488
19589
19690def pad_data (
197- data : DataSourceOutput ,
91+ data : DataSourceOutputML ,
19892 pad_size : int ,
19993 one_dimensional_arrays : List [str ],
20094 two_dimensional_arrays : List [str ],
0 commit comments