Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit e6ef521

Browse files
committed
Implemented test for DataSourceList.sample_spatial_and_temporal_positions_for_examples
1 parent ec8b811 commit e6ef521

File tree

3 files changed

+22
-34
lines changed

3 files changed

+22
-34
lines changed

nowcasting_dataset/data_source_list.py renamed to nowcasting_dataset/data_sources/data_source_list.py

Lines changed: 16 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,9 @@ def get_t0_datetimes_across_all_data_sources(self, freq: str) -> pd.DatetimeInde
5353

5454
return t0_datetimes
5555

56-
def sample_position_of_every_example_of_every_split(
57-
self,
58-
t0_datetimes: pd.DatetimeIndex,
59-
split_method: SplitMethod,
60-
n_examples_per_split: dict[SplitName, int],
61-
) -> dict[SplitName, pd.DataFrame]:
56+
def sample_spatial_and_temporal_positions_for_examples(
57+
self, t0_datetimes: pd.DatetimeIndex, n_examples: int
58+
) -> pd.DataFrame:
6259
"""
6360
Computes the geospatial and temporal position of each training example.
6461
@@ -68,33 +65,21 @@ def sample_position_of_every_example_of_every_split(
6865
Args:
6966
t0_datetimes: All available t0 datetimes. Can be computed with
7067
`DataSourceList.get_t0_datetimes_across_all_data_sources()`
71-
split_method: The method used to split data into train, validation, and test.
72-
n_examples_per_split: The number of examples requested for each split.
68+
n_examples: The number of examples requested.
7369
7470
Returns:
75-
A dict where the keys are a SplitName, and the values are a pd.DataFrame.
76-
Each row of each DataFrame specifies the position of each example, using
71+
Each row of each the DataFrame specifies the position of each example, using
7772
columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'.
7873
"""
79-
# Split t0_datetimes into train, test and validation sets.
80-
t0_datetimes_per_split = split_data(datetimes=t0_datetimes, method=split_method)
81-
t0_datetimes_per_split = t0_datetimes_per_split._asdict()
82-
8374
data_source_which_defines_geo_position = self[0]
84-
85-
positions_per_split: dict[SplitName, pd.DataFrame] = {}
86-
for split_name, t0_datetimes_for_split in t0_datetimes_per_split.items():
87-
n_examples = n_examples_per_split[split_name]
88-
shuffled_t0_datetimes = np.random.choice(t0_datetimes_for_split, shape=n_examples)
89-
x_locations, y_locations = data_source_which_defines_geo_position.get_locations(
90-
shuffled_t0_datetimes
91-
)
92-
positions_per_split[split_name] = pd.DataFrame(
93-
{
94-
"t0_datetime_UTC": shuffled_t0_datetimes,
95-
"x_center_OSGB": x_locations,
96-
"y_center_OSGB": y_locations,
97-
}
98-
)
99-
100-
return positions_per_split
75+
shuffled_t0_datetimes = np.random.choice(t0_datetimes, size=n_examples)
76+
x_locations, y_locations = data_source_which_defines_geo_position.get_locations(
77+
shuffled_t0_datetimes
78+
)
79+
return pd.DataFrame(
80+
{
81+
"t0_datetime_UTC": shuffled_t0_datetimes,
82+
"x_center_OSGB": x_locations,
83+
"y_center_OSGB": y_locations,
84+
}
85+
)

nowcasting_dataset/dataset/datamodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource
1515
from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource
1616
from nowcasting_dataset.dataset import datasets
17-
from nowcasting_dataset.dataset.split.split import split_data, SplitMethod, SplitName
18-
from nowcasting_dataset.data_source_list import DataSourceList
17+
from nowcasting_dataset.dataset.split.split import split_data, SplitMethod
18+
from nowcasting_dataset.data_sources.data_source_list import DataSourceList
1919

2020

2121
with warnings.catch_warnings():

nowcasting_dataset/dataset/split/split.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ class SplitName(Enum):
3939
TEST = "test"
4040

4141

42-
SplitData = namedtuple(typename="SplitData", field_names=["train", "validation", "test"])
42+
SplitData = namedtuple(
43+
typename="SplitData",
44+
field_names=[SplitName.TRAIN.value, SplitName.VALIDATION.value, SplitName.TEST.value],
45+
)
4346

4447

4548
def split_data(

0 commit comments

Comments
 (0)