66import pandas as pd
77import xarray as xr
88
9+ from nowcasting_dataset .consts import NWP_VARIABLE_NAMES , SAT_VARIABLE_NAMES
910from nowcasting_dataset .data_sources .gsp .gsp_model import GSP
1011from nowcasting_dataset .data_sources .metadata .metadata_model import Metadata
1112from nowcasting_dataset .data_sources .nwp .nwp_model import NWP
@@ -64,7 +65,7 @@ def nwp_fake(
6465 create_image_array (
6566 seq_length_5 = seq_length_5 ,
6667 image_size_pixels = image_size_pixels ,
67- number_channels = number_nwp_channels ,
68+ channels = NWP_VARIABLE_NAMES [ 0 : number_nwp_channels ] ,
6869 )
6970 for _ in range (batch_size )
7071 ]
@@ -107,7 +108,7 @@ def satellite_fake(
107108 create_image_array (
108109 seq_length_5 = seq_length_5 ,
109110 image_size_pixels = satellite_image_size_pixels ,
110- number_channels = number_satellite_channels ,
111+ channels = SAT_VARIABLE_NAMES [ 0 : number_satellite_channels ] ,
111112 )
112113 for _ in range (batch_size )
113114 ]
@@ -163,14 +164,14 @@ def create_image_array(
163164 dims = ("time" , "x" , "y" , "channels" ),
164165 seq_length_5 = 19 ,
165166 image_size_pixels = 64 ,
166- number_channels = 7 ,
167+ channels = SAT_VARIABLE_NAMES ,
167168):
168169 """ Create Satellite or NWP fake image data"""
169170 ALL_COORDS = {
170171 "time" : pd .date_range ("2021-01-01" , freq = "5T" , periods = seq_length_5 ),
171172 "x" : np .random .randint (low = 0 , high = 1000 , size = image_size_pixels ),
172173 "y" : np .random .randint (low = 0 , high = 1000 , size = image_size_pixels ),
173- "channels" : np .arange ( number_channels ),
174+ "channels" : np .array ( channels ),
174175 }
175176 coords = [(dim , ALL_COORDS [dim ]) for dim in dims ]
176177 image_data_array = xr .DataArray (
@@ -179,7 +180,7 @@ def create_image_array(
179180 seq_length_5 ,
180181 image_size_pixels ,
181182 image_size_pixels ,
182- number_channels ,
183+ len ( channels ) ,
183184 )
184185 ),
185186 coords = coords ,
0 commit comments