Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 2435644

Browse files
committed
Merge branch 'issue/n-batches-manager-dyanmic'
2 parents 0a89243 + 0c22ec8 commit 2435644

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

nowcasting_dataset/manager/manager_live.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ def create_files_specifying_spatial_and_temporal_locations_of_each_example(
7676

7777
n_batches_requested = int(np.ceil(n_gsps / self.config.process.batch_size))
7878

79-
n_examples = n_batches_requested * self.config.process.batch_size
8079
logger.debug(
8180
f"Creating {n_batches_requested:,d} batches x {self.config.process.batch_size:,d}"
82-
f" examples per batch = {n_examples:,d} examples for {split_name}."
81+
f" examples per batch = {n_batches_requested:,d} examples for {split_name}."
8382
)
8483

8584
locations = self.sample_spatial_and_temporal_locations_for_examples(
86-
t0_datetime=datetimes_for_split[0],
85+
t0_datetime=datetimes_for_split[0], n_examples=n_gsps
8786
)
87+
8888
metadata = Metadata(
8989
batch_size=self.config.process.batch_size, space_time_locations=locations
9090
)
@@ -95,7 +95,7 @@ def create_files_specifying_spatial_and_temporal_locations_of_each_example(
9595
metadata.save_to_csv(path_for_csv)
9696

9797
def sample_spatial_and_temporal_locations_for_examples(
98-
self, t0_datetime: datetime
98+
self, t0_datetime: datetime, n_examples: Optional[int] = None
9999
) -> List[SpaceTimeLocation]:
100100
"""
101101
Computes the geospatial and temporal locations for each training example.
@@ -122,6 +122,10 @@ def sample_spatial_and_temporal_locations_for_examples(
122122
t0_datetimes_utc=pd.DatetimeIndex([t0_datetime])
123123
)
124124

125+
# reduce locations to n_examples
126+
if n_examples is not None:
127+
locations = locations[:n_examples]
128+
125129
# find out the number of examples in the last batch,
126130
# we may need to duplicate the last example in order to get a full batch
127131
n_examples_last_batch = len(locations) % self.config.process.batch_size

tests/manager/test_manager_live.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,42 @@ def test_create_files_specifying_spatial_and_temporal_locations_of_each_example(
6363
assert len(locations_df) == batch_size
6464

6565

66+
def test_create_files_locations_of_each_example_reduced(
67+
test_configuration_filename,
68+
):
69+
"""Test to create locations, for a reduced number of n_gsps"""
70+
71+
manager = ManagerLive()
72+
manager.load_yaml_configuration(filename=test_configuration_filename)
73+
manager.config.process.batch_size = 5
74+
manager.config.process.split_method = SplitMethod.SAME
75+
manager.initialize_data_sources()
76+
t0_datetime = datetime(2021, 4, 1)
77+
78+
batch_size = manager.config.process.batch_size
79+
80+
with tempfile.TemporaryDirectory() as local_temp_path, tempfile.TemporaryDirectory() as dst_path: # noqa 101
81+
82+
manager.config.output_data.filepath = Path(dst_path)
83+
manager.local_temp_path = Path(local_temp_path)
84+
85+
manager.create_files_specifying_spatial_and_temporal_locations_of_each_example(
86+
t0_datetime=t0_datetime, n_gsps=7
87+
) # noqa 101
88+
89+
live_file = f"{dst_path}/live/{SPATIAL_AND_TEMPORAL_LOCATIONS_OF_EACH_EXAMPLE_FILENAME}"
90+
91+
assert os.path.exists(live_file)
92+
locations_df = pd.read_csv(live_file)
93+
# we've asked for 7 examples, but the batchsize is 5, so we get 10 examples,
94+
# the last 3 being filled in
95+
assert len(locations_df) == batch_size * 2
96+
# check last 3 are filled in from the first one
97+
assert locations_df.iloc[-3].id == 1
98+
assert locations_df.iloc[-2].id == 1
99+
assert locations_df.iloc[-1].id == 1
100+
101+
66102
def test_batches(test_configuration_filename, sat, gsp):
67103
"""Test that batches can be made"""
68104

0 commit comments

Comments
 (0)