Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit b0bc934

Browse files
Merge pull request #333 from openclimatefix/issue/torch-channels
change fake data to give channels names for satellite and nwp
2 parents d823b3e + 249dcc5 commit b0bc934

File tree

1 file changed

+6
-5
lines changed
  • nowcasting_dataset/data_sources

1 file changed

+6
-5
lines changed

nowcasting_dataset/data_sources/fake.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import xarray as xr
88

9+
from nowcasting_dataset.consts import NWP_VARIABLE_NAMES, SAT_VARIABLE_NAMES
910
from nowcasting_dataset.data_sources.gsp.gsp_model import GSP
1011
from nowcasting_dataset.data_sources.metadata.metadata_model import Metadata
1112
from nowcasting_dataset.data_sources.nwp.nwp_model import NWP
@@ -64,7 +65,7 @@ def nwp_fake(
6465
create_image_array(
6566
seq_length_5=seq_length_5,
6667
image_size_pixels=image_size_pixels,
67-
number_channels=number_nwp_channels,
68+
channels=NWP_VARIABLE_NAMES[0:number_nwp_channels],
6869
)
6970
for _ in range(batch_size)
7071
]
@@ -107,7 +108,7 @@ def satellite_fake(
107108
create_image_array(
108109
seq_length_5=seq_length_5,
109110
image_size_pixels=satellite_image_size_pixels,
110-
number_channels=number_satellite_channels,
111+
channels=SAT_VARIABLE_NAMES[0:number_satellite_channels],
111112
)
112113
for _ in range(batch_size)
113114
]
@@ -163,14 +164,14 @@ def create_image_array(
163164
dims=("time", "x", "y", "channels"),
164165
seq_length_5=19,
165166
image_size_pixels=64,
166-
number_channels=7,
167+
channels=SAT_VARIABLE_NAMES,
167168
):
168169
""" Create Satellite or NWP fake image data"""
169170
ALL_COORDS = {
170171
"time": pd.date_range("2021-01-01", freq="5T", periods=seq_length_5),
171172
"x": np.random.randint(low=0, high=1000, size=image_size_pixels),
172173
"y": np.random.randint(low=0, high=1000, size=image_size_pixels),
173-
"channels": np.arange(number_channels),
174+
"channels": np.array(channels),
174175
}
175176
coords = [(dim, ALL_COORDS[dim]) for dim in dims]
176177
image_data_array = xr.DataArray(
@@ -179,7 +180,7 @@ def create_image_array(
179180
seq_length_5,
180181
image_size_pixels,
181182
image_size_pixels,
182-
number_channels,
183+
len(channels),
183184
)
184185
),
185186
coords=coords,

0 commit comments

Comments
 (0)