22
33Wanted to keep this out of the testing frame works, as other repos, might want to use this
44"""
5+ from typing import List
6+
57import numpy as np
68import pandas as pd
79import xarray as xr
1517from nowcasting_dataset .data_sources .sun .sun_model import Sun
1618from nowcasting_dataset .data_sources .topographic .topographic_model import Topographic
1719from nowcasting_dataset .dataset .xr_utils import (
18- convert_data_array_to_dataset ,
19- join_list_data_array_to_batch_dataset ,
20+ convert_coordinates_to_indexes ,
21+ convert_coordinates_to_indexes_for_list_datasets ,
2022 join_list_dataset_to_batch_dataset ,
2123)
2224
@@ -28,7 +30,7 @@ def gsp_fake(
2830):
2931 """Create fake data"""
3032 # make batch of arrays
31- xr_arrays = [
33+ xr_datasets = [
3234 create_gsp_pv_dataset (
3335 seq_length = seq_length_30 ,
3436 freq = "30T" ,
@@ -37,8 +39,11 @@ def gsp_fake(
3739 for _ in range (batch_size )
3840 ]
3941
42+ # change dimensions to dimension indexes
43+ xr_datasets = convert_coordinates_to_indexes_for_list_datasets (xr_datasets )
44+
4045 # make dataset
41- xr_dataset = join_list_dataset_to_batch_dataset (xr_arrays )
46+ xr_dataset = join_list_dataset_to_batch_dataset (xr_datasets )
4247
4348 return GSP (xr_dataset )
4449
@@ -47,6 +52,9 @@ def metadata_fake(batch_size):
4752 """Make a xr dataset"""
4853 xr_arrays = [create_metadata_dataset () for _ in range (batch_size )]
4954
55+ # change to indexes
56+ xr_arrays = [convert_coordinates_to_indexes (xr_array ) for xr_array in xr_arrays ]
57+
5058 # make dataset
5159 xr_dataset = join_list_dataset_to_batch_dataset (xr_arrays )
5260
@@ -81,7 +89,7 @@ def nwp_fake(
8189def pv_fake (batch_size , seq_length_5 , n_pv_systems_per_batch ):
8290 """Create fake data"""
8391 # make batch of arrays
84- xr_arrays = [
92+ xr_datasets = [
8593 create_gsp_pv_dataset (
8694 seq_length = seq_length_5 ,
8795 freq = "5T" ,
@@ -90,8 +98,11 @@ def pv_fake(batch_size, seq_length_5, n_pv_systems_per_batch):
9098 for _ in range (batch_size )
9199 ]
92100
101+ # change dimensions to dimension indexes
102+ xr_datasets = convert_coordinates_to_indexes_for_list_datasets (xr_datasets )
103+
93104 # make dataset
94- xr_dataset = join_list_dataset_to_batch_dataset (xr_arrays )
105+ xr_dataset = join_list_dataset_to_batch_dataset (xr_datasets )
95106
96107 return PV (xr_dataset )
97108
@@ -150,6 +161,7 @@ def topographic_fake(batch_size, image_size_pixels):
150161 x = np .sort (np .random .randn (image_size_pixels )),
151162 y = np .sort (np .random .randn (image_size_pixels ))[::- 1 ].copy (),
152163 ),
164+ name = "data" ,
153165 )
154166 for _ in range (batch_size )
155167 ]
@@ -184,6 +196,7 @@ def create_image_array(
184196 )
185197 ),
186198 coords = coords ,
199+ name = "data" ,
187200 ) # Fake data for testing!
188201 return image_data_array
189202
@@ -197,7 +210,7 @@ def create_gsp_pv_dataset(
197210 """Create gsp or pv fake dataset"""
198211 ALL_COORDS = {
199212 "time" : pd .date_range ("2021-01-01" , freq = freq , periods = seq_length ),
200- "id" : np .random .randint ( low = 0 , high = 1000 , size = number_of_systems ),
213+ "id" : np .random .choice ( range ( 1000 ), number_of_systems , replace = False ),
201214 }
202215 coords = [(dim , ALL_COORDS [dim ]) for dim in dims ]
203216 data_array = xr .DataArray (
@@ -208,22 +221,20 @@ def create_gsp_pv_dataset(
208221 coords = coords ,
209222 ) # Fake data for testing!
210223
211- data = convert_data_array_to_dataset ( data_array )
224+ data = data_array . to_dataset ( name = "data" )
212225
213226 x_coords = xr .DataArray (
214- data = np .sort (np .random .randn (number_of_systems )),
215- dims = ["id_index" ],
216- coords = dict (
217- id_index = range (number_of_systems ),
227+ data = np .sort (
228+ np .random .choice (range (2 * number_of_systems ), number_of_systems , replace = False )
218229 ),
230+ dims = ["id" ],
219231 )
220232
221233 y_coords = xr .DataArray (
222- data = np .sort (np .random .randn (number_of_systems )),
223- dims = ["id_index" ],
224- coords = dict (
225- id_index = range (number_of_systems ),
234+ data = np .sort (
235+ np .random .choice (range (2 * number_of_systems ), number_of_systems , replace = False )
226236 ),
237+ dims = ["id" ],
227238 )
228239
229240 data ["x_coords" ] = x_coords
@@ -265,13 +276,14 @@ def create_sun_dataset(
265276 coords = coords ,
266277 ) # Fake data for testing!
267278
268- data = convert_data_array_to_dataset (data_array )
269- sun = data .rename ({"data" : "elevation" })
270- sun ["azimuth" ] = data .data
279+ sun = data_array .to_dataset (name = "elevation" )
280+ sun ["azimuth" ] = sun .elevation
271281
272282 sun .__setitem__ ("azimuth" , sun .azimuth .clip (min = 0 , max = 360 ))
273283 sun .__setitem__ ("elevation" , sun .elevation .clip (min = - 90 , max = 90 ))
274284
285+ sun = convert_coordinates_to_indexes (sun )
286+
275287 return sun
276288
277289
@@ -282,11 +294,11 @@ def create_metadata_dataset() -> xr.Dataset:
282294 "data" : pd .date_range ("2021-01-01" , freq = "5T" , periods = 1 ) + pd .Timedelta ("30T" ),
283295 }
284296
285- data = convert_data_array_to_dataset (xr .DataArray .from_dict (d ))
297+ data = (xr .DataArray .from_dict (d )). to_dataset ( name = "data" )
286298
287299 for v in ["x_meters_center" , "y_meters_center" , "object_at_center_label" ]:
288300 d : dict = {"dims" : ("t0_dt" ,), "data" : [np .random .randint (0 , 1000 )]}
289- d : xr .Dataset = convert_data_array_to_dataset (xr .DataArray .from_dict (d )).rename ({ "data" : v } )
301+ d : xr .Dataset = (xr .DataArray .from_dict (d )).to_dataset ( name = v )
290302 data [v ] = getattr (d , v )
291303
292304 return data
@@ -307,11 +319,20 @@ def create_datetime_dataset(
307319 coords = coords ,
308320 ) # Fake data
309321
310- data = convert_data_array_to_dataset ( data_array )
322+ data = data_array . to_dataset ( )
311323
312324 ds = data .rename ({"data" : "day_of_year_cos" })
313325 ds ["day_of_year_sin" ] = data .rename ({"data" : "day_of_year_sin" }).day_of_year_sin
314326 ds ["hour_of_day_cos" ] = data .rename ({"data" : "hour_of_day_cos" }).hour_of_day_cos
315327 ds ["hour_of_day_sin" ] = data .rename ({"data" : "hour_of_day_sin" }).hour_of_day_sin
316328
317329 return data
330+
331+
332+ def join_list_data_array_to_batch_dataset (data_arrays : List [xr .DataArray ]) -> xr .Dataset :
333+ """Join a list of xr.DataArrays into an xr.Dataset by concatenating on the example dim."""
334+ datasets = [
335+ convert_coordinates_to_indexes (data_arrays [i ].to_dataset ()) for i in range (len (data_arrays ))
336+ ]
337+
338+ return join_list_dataset_to_batch_dataset (datasets )
0 commit comments