Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 3d9d1c2

Browse files
Merge pull request #424 from openclimatefix/issue/390-split
update to split on training in one year, test after a date
2 parents 9a79c03 + 88b7088 commit 3d9d1c2

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

nowcasting_dataset/config/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ class Process(BaseModel):
292292
),
293293
)
294294
split_method: split.SplitMethod = Field(
295-
split.SplitMethod.DAY,
295+
split.SplitMethod.DAY_RANDOM_TEST_DATE,
296296
description=(
297297
"The method used to split the t0 datetimes into train, validation and test sets."
298298
),

nowcasting_dataset/dataset/split/split.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class SplitName(Enum):
4949
def split_data(
5050
datetimes: Union[List[pd.Timestamp], pd.DatetimeIndex],
5151
method: SplitMethod,
52-
train_test_validation_split: Tuple[int] = (3, 1, 1),
52+
train_test_validation_split: Tuple[int, int, int] = (3, 1, 1),
5353
train_test_validation_specific: TrainValidationTestSpecific = (
5454
default_train_test_validation_specific
5555
),

nowcasting_dataset/manager.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,15 @@ def create_files_specifying_spatial_and_temporal_locations_of_each_example_if_ne
166166
t0_datetimes = self.get_t0_datetimes_across_all_data_sources(
167167
freq=self.config.process.t0_datetime_frequency
168168
)
169+
# TODO: move hard code values to config file #426
169170
split_t0_datetimes = split.split_data(
170-
datetimes=t0_datetimes, method=self.config.process.split_method
171+
datetimes=t0_datetimes,
172+
method=self.config.process.split_method,
173+
train_test_validation_split=(3, 0, 1),
174+
train_validation_test_datetime_split=[
175+
pd.Timestamp("2020-01-01"),
176+
pd.Timestamp("2021-01-01"),
177+
],
171178
)
172179
for split_name, datetimes_for_split in split_t0_datetimes._asdict().items():
173180
n_batches = self._get_n_batches_for_split_name(split_name)

0 commit comments

Comments
 (0)