Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 854c0bd

Browse files
Issue/624 pv types (#641)
* assert the types we want of pv data - TDD * change data types to float32 or int32 * tidy Co-authored-by: Jacob Bieker <jacob@bieker.tech>
1 parent 0711c52 commit 854c0bd

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

nowcasting_dataset/data_sources/pv/pv_data_source.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,14 +396,17 @@ def get_example(self, location: SpaceTimeLocation) -> xr.Dataset:
396396
data=pv_system_row_number,
397397
dims=["id"],
398398
)
399-
pv["x_osgb"] = x_coords
400-
pv["y_osgb"] = y_coords
399+
pv["x_osgb"] = x_coords.astype("float32")
400+
pv["y_osgb"] = y_coords.astype("float32")
401401
pv["pv_system_row_number"] = pv_system_row_number
402402

403403
# pad out so that there are always n_pv_systems_per_example, pad with zeros
404404
pad_n = self.n_pv_systems_per_example - len(pv.id)
405405
pv = pv.pad(id=(0, pad_n), power_mw=((0, 0), (0, pad_n)), constant_values=0)
406406

407+
# format id
408+
pv.__setitem__("id", pv.id.astype("int32"))
409+
407410
return pv
408411

409412
def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLocation]:

nowcasting_dataset/dataset/xr_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@ def join_list_dataset_to_batch_dataset(datasets: list[xr.Dataset]) -> xr.Dataset
1717
new_dataset = dataset.expand_dims(dim="example").assign_coords(example=("example", [i]))
1818
new_datasets.append(new_dataset)
1919

20-
return xr.concat(new_datasets, dim="example")
20+
joined_dataset = xr.concat(new_datasets, dim="example")
21+
22+
# format example index
23+
joined_dataset.__setitem__("example", joined_dataset.example.astype("int32"))
24+
25+
return joined_dataset
2126

2227

2328
def convert_coordinates_to_indexes_for_list_datasets(
@@ -43,7 +48,7 @@ def convert_coordinates_to_indexes(dataset: xr.Dataset) -> xr.Dataset:
4348

4449
for original_dim_name in original_dim_names:
4550
original_coords = dataset[original_dim_name]
46-
new_index_coords = np.arange(len(original_coords))
51+
new_index_coords = np.arange(len(original_coords)).astype("int32")
4752
new_index_dim_name = f"{original_dim_name}_index"
4853
dataset[original_dim_name] = new_index_coords
4954
dataset = dataset.rename({original_dim_name: new_index_dim_name})

tests/data_sources/pv/test_pv_data_source.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ def test_get_example_and_batch(): # noqa: D103
7171
# start at 6, to avoid some nans
7272
batch = pv_data_source.get_batch(locations=locations[6:16])
7373
assert batch.power_mw.shape == (10, 19, DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE)
74+
assert str(batch.x_osgb.dtype) == "float32"
75+
assert str(batch.y_osgb.dtype) == "float32"
76+
assert str(batch.id.dtype) == "int32"
77+
assert str(batch.example.dtype) == "int32"
78+
assert str(batch.id_index.dtype) == "int32"
79+
assert str(batch.time_index.dtype) == "int32"
7480

7581

7682
def test_drop_pv_systems_which_produce_overnight(): # noqa: D103

0 commit comments

Comments
 (0)