Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 6de1337

Browse files
committed
speed up loading, and speed up tests
1 parent 91813dc commit 6de1337

File tree

2 files changed

+49
-17
lines changed

2 files changed

+49
-17
lines changed

nowcasting_dataset/dataset/batch.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
from pathlib import Path
77
from typing import Optional, Union
8+
from concurrent import futures
89

910
import xarray as xr
1011
from pydantic import BaseModel, Field
@@ -130,26 +131,45 @@ def save_netcdf(self, batch_i: int, path: Path):
130131
path: the path where it will be saved. This can be local or in the cloud.
131132
132133
"""
133-
for data_source in self.data_sources:
134-
if data_source is not None:
135-
data_source.save_netcdf(batch_i=batch_i, path=path)
134+
135+
with futures.ThreadPoolExecutor() as executor:
136+
# Submit tasks to the executor.
137+
for data_source in self.data_sources:
138+
if data_source is not None:
139+
_ = executor.submit(
140+
data_source.save_netcdf,
141+
batch_i=batch_i,
142+
path=path,
143+
)
144+
# data_source.save_netcdf(batch_i=batch_i, path=path)
136145

137146
@staticmethod
138147
def load_netcdf(local_netcdf_path: Union[Path, str], batch_idx: int):
139148
"""Load batch from netcdf file"""
140149
data_sources_names = Example.__fields__.keys()
141150

142-
# collect data sources
151+
# set up futures executor
143152
batch_dict = {}
144-
for data_source_name in data_sources_names:
145-
146-
local_netcdf_filename = os.path.join(
147-
local_netcdf_path, data_source_name, f"{batch_idx}.nc"
148-
)
149-
if os.path.exists(local_netcdf_filename):
150-
xr_dataset = xr.load_dataset(local_netcdf_filename)
151-
else:
152-
xr_dataset = None
153+
with futures.ThreadPoolExecutor() as executor:
154+
future_examples_per_source = []
155+
156+
# loop over data sources
157+
for data_source_name in data_sources_names:
158+
159+
local_netcdf_filename = os.path.join(
160+
local_netcdf_path, data_source_name, f"{batch_idx}.nc"
161+
)
162+
163+
# submit task
164+
future_examples = executor.submit(
165+
xr.load_dataset,
166+
filename_or_obj=local_netcdf_filename,
167+
)
168+
future_examples_per_source.append([data_source_name, future_examples])
169+
170+
# Collect results from each thread.
171+
for data_source_name, future_examples in future_examples_per_source:
172+
xr_dataset = future_examples.result()
153173

154174
batch_dict[data_source_name] = xr_dataset
155175

tests/dataset/test_batch.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,30 @@
1010

1111
def test_model():
1212

13-
_ = Batch.fake()
13+
con = Configuration()
14+
con.process.batch_size = 4
15+
16+
_ = Batch.fake(configuration=con)
1417

1518

1619
def test_model_save_to_netcdf():
1720

21+
con = Configuration()
22+
con.process.batch_size = 4
23+
1824
with tempfile.TemporaryDirectory() as dirpath:
19-
Batch.fake().save_netcdf(path=dirpath, batch_i=0)
25+
Batch.fake(configuration=con).save_netcdf(path=dirpath, batch_i=0)
2026

2127
assert os.path.exists(f"{dirpath}/satellite/0.nc")
2228

2329

2430
def test_model_load_from_netcdf():
2531

32+
con = Configuration()
33+
con.process.batch_size = 4
34+
2635
with tempfile.TemporaryDirectory() as dirpath:
27-
Batch.fake().save_netcdf(path=dirpath, batch_i=0)
36+
Batch.fake(configuration=con).save_netcdf(path=dirpath, batch_i=0)
2837

2938
batch = Batch.load_netcdf(batch_idx=0, local_netcdf_path=dirpath)
3039

@@ -33,7 +42,10 @@ def test_model_load_from_netcdf():
3342

3443
def test_batch_to_batch_ml():
3544

36-
_ = BatchML.from_batch(batch=Batch.fake())
45+
con = Configuration()
46+
con.process.batch_size = 4
47+
48+
_ = BatchML.from_batch(batch=Batch.fake(configuration=con))
3749

3850

3951
def test_fake_dataset():

0 commit comments

Comments
 (0)