
Commit 63e2fef

StoreFactory refactor (#191)
* Add storage.py and update storage config references
* Remove final_store; implement use of the factory
* Refactor template_utils.write_metadata to take a StoreFactory
* Remove outdated comment
1 parent 5e1d01f commit 63e2fef

24 files changed: +319 −191 lines changed

src/reformatters/__main__.py

Lines changed: 6 additions & 6 deletions
@@ -9,10 +9,8 @@
 import reformatters.noaa.gefs.forecast_35_day.cli as noaa_gefs_forecast_35_day
 from reformatters.common import deploy
 from reformatters.common.config import Config
-from reformatters.common.dynamical_dataset import (
-    DynamicalDataset,
-    DynamicalDatasetStorageConfig,
-)
+from reformatters.common.dynamical_dataset import DynamicalDataset
+from reformatters.common.storage import DatasetFormat, StorageConfig
 from reformatters.contrib.noaa.ndvi_cdr.analysis import (
     NoaaNdviCdrAnalysisDataset,
 )
@@ -24,21 +22,23 @@
 )


-class SourceCoopDatasetStorageConfig(DynamicalDatasetStorageConfig):
+class SourceCoopDatasetStorageConfig(StorageConfig):
     """Configuration for the storage of a SourceCoop dataset."""

     base_path: str = "s3://us-west-2.opendata.source.coop/dynamical"
     k8s_secret_names: Sequence[str] = ["source-coop-key"]
+    format: DatasetFormat = DatasetFormat.ZARR3


-class UpstreamGriddedZarrsDatasetStorageConfig(DynamicalDatasetStorageConfig):
+class UpstreamGriddedZarrsDatasetStorageConfig(StorageConfig):
     """Configuration for storage in the Upstream gridded zarrs bucket."""

     # This bucket is actually an R2 bucket.
     # The R2 endpoint URL is stored within our k8s secret and will be set
     # when it's imported into the env.
     base_path: str = "s3://upstream-gridded-zarrs"
     k8s_secret_names: Sequence[str] = ["upstream-gridded-zarrs-key"]
+    format: DatasetFormat = DatasetFormat.ZARR3


 # Registry of all DynamicalDatasets.
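The new src/reformatters/common/storage.py module referenced above is among the 24 changed files, but its hunk is not included in this excerpt. The sketch below is only an inference from the call sites visible in this commit (the field names, DatasetFormat.ZARR3, and the store() method all appear in the diffs); the base classes, the enum's string value, and the store() body are assumptions, not the actual implementation.

```python
# Hedged sketch of the interface storage.py appears to expose, inferred solely
# from how this commit uses it. Base classes, the enum value string, and the
# store() internals are assumptions.
from collections.abc import Sequence
from enum import Enum

import zarr
from pydantic import BaseModel


class DatasetFormat(str, Enum):
    ZARR3 = "zarr3"  # only member visible in this diff; value string assumed


class StorageConfig(BaseModel):
    # Fields set by SourceCoopDatasetStorageConfig and
    # UpstreamGriddedZarrsDatasetStorageConfig in __main__.py above.
    base_path: str
    k8s_secret_names: Sequence[str] = []
    format: DatasetFormat


class StoreFactory(BaseModel):
    # Built by DynamicalDataset.primary_store_factory (next file below).
    storage_config: StorageConfig
    dataset_id: str
    template_config_version: str

    def store(self) -> zarr.abc.store.Store:
        """Open the primary store, presumably rooted at
        {base_path}/{dataset_id}/{template_config_version}."""
        raise NotImplementedError("sketch only; see storage.py in the commit")
```

The rest of the commit threads a StoreFactory through in place of an opened zarr store; the store itself is only opened where it is needed, via factory.store().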

src/reformatters/common/dynamical_dataset.py

Lines changed: 27 additions & 37 deletions
@@ -11,7 +11,6 @@
 import sentry_sdk
 import typer
 import xarray as xr
-import zarr
 from pydantic import computed_field

 from reformatters.common import docker, template_utils, validation
@@ -27,13 +26,12 @@
 from reformatters.common.logging import get_logger
 from reformatters.common.pydantic import FrozenBaseModel
 from reformatters.common.region_job import RegionJob, SourceFileCoord
+from reformatters.common.storage import StorageConfig, StoreFactory
 from reformatters.common.template_config import TemplateConfig
 from reformatters.common.types import DatetimeLike
 from reformatters.common.zarr import (
     copy_zarr_metadata,
     get_local_tmp_store,
-    get_mode,
-    get_zarr_store,
 )

 DATA_VAR = TypeVar("DATA_VAR", bound=DataVar[Any])
@@ -42,20 +40,22 @@
 logger = get_logger(__name__)


-class DynamicalDatasetStorageConfig(FrozenBaseModel):
-    """Configuration for the storage of a dataset in production."""
-
-    base_path: str
-    k8s_secret_names: Sequence[str] = []
-
-
 class DynamicalDataset(FrozenBaseModel, Generic[DATA_VAR, SOURCE_FILE_COORD]):
     """Top level class managing a dataset configuration and processing."""

     template_config: TemplateConfig[DATA_VAR]
     region_job_class: type[RegionJob[DATA_VAR, SOURCE_FILE_COORD]]

-    storage_config: DynamicalDatasetStorageConfig
+    storage_config: StorageConfig
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def primary_store_factory(self) -> StoreFactory:
+        return StoreFactory(
+            storage_config=self.storage_config,
+            dataset_id=self.dataset_id,
+            template_config_version=self.template_config.version,
+        )

     def operational_kubernetes_resources(self, image_tag: str) -> Iterable[CronJob]:
         """
@@ -127,27 +127,29 @@ def update(
     ) -> None:
         """Update an existing dataset with the latest data."""
         with self._monitor(ReformatCronJob, reformat_job_name):
-            final_store = self._final_store()
             tmp_store = self._tmp_store()

             jobs, template_ds = self.region_job_class.operational_update_jobs(
-                final_store=final_store,
+                primary_store_factory=self.primary_store_factory,
                 tmp_store=tmp_store,
                 get_template_fn=self._get_template,
                 append_dim=self.template_config.append_dim,
                 all_data_vars=self.template_config.data_vars,
                 reformat_job_name=reformat_job_name,
             )
-            template_utils.write_metadata(template_ds, tmp_store, get_mode(tmp_store))
+            template_utils.write_metadata(template_ds, tmp_store)
+
             for job in jobs:
                 process_results = job.process()
                 updated_template = job.update_template_with_results(process_results)
-                template_utils.write_metadata(
-                    updated_template, tmp_store, get_mode(tmp_store)
-                )
-                copy_zarr_metadata(updated_template, tmp_store, final_store)
+                # overwrite the tmp store metadata with updated template
+                template_utils.write_metadata(updated_template, tmp_store)
+                primary_store = self.primary_store_factory.store()
+                copy_zarr_metadata(updated_template, tmp_store, primary_store)

-            logger.info(f"Operational update complete. Wrote to store: {final_store}")
+            logger.info(
+                f"Operational update complete. Wrote to store: {self.primary_store_factory.store()}"
+            )

     def backfill_kubernetes(
         self,
@@ -164,15 +166,12 @@ def backfill_kubernetes(
         image_tag = docker_image or docker.build_and_push_image()

         template_ds = self._get_template(append_dim_end)
-        final_store = self._final_store()
-        logger.info(f"Writing zarr metadata to {final_store}")
-
-        template_utils.write_metadata(template_ds, final_store, get_mode(final_store))
+        template_utils.write_metadata(template_ds, self.primary_store_factory)

         num_jobs = len(
             self.region_job_class.get_jobs(
                 kind="backfill",
-                final_store=final_store,
+                primary_store_factory=self.primary_store_factory,
                 tmp_store=self._tmp_store(),
                 template_ds=template_ds,
                 append_dim=self.template_config.append_dim,
@@ -259,9 +258,7 @@ def backfill_local(
     ) -> None:
         """Run dataset reformatting locally in this process."""
         template_ds = self._get_template(append_dim_end)
-        final_store = self._final_store()
-
-        template_utils.write_metadata(template_ds, final_store, get_mode(final_store))
+        template_utils.write_metadata(template_ds, self.primary_store_factory)

         self.process_backfill_region_jobs(
             append_dim_end,
@@ -273,7 +270,7 @@ def backfill_local(
             filter_contains=filter_contains,
             filter_variable_names=filter_variable_names,
         )
-        logger.info(f"Done writing to {final_store}")
+        logger.info(f"Done writing to {self.primary_store_factory.store()}")

     def process_backfill_region_jobs(
         self,
@@ -291,7 +288,7 @@ def process_backfill_region_jobs(

         region_jobs = self.region_job_class.get_jobs(
             kind="backfill",
-            final_store=self._final_store(),
+            primary_store_factory=self.primary_store_factory,
             tmp_store=self._tmp_store(),
             template_ds=self._get_template(append_dim_end),
             append_dim=self.template_config.append_dim,
@@ -320,7 +317,7 @@ def validate_dataset(
     ) -> None:
         """Validate the dataset, raising an exception if it is invalid."""
         with self._monitor(ValidationCronJob, reformat_job_name):
-            store = self._final_store()
+            store = self.primary_store_factory.store()
             validation.validate_dataset(store, validators=self.validators())

             logger.info(f"Done validating {store}")
@@ -339,13 +336,6 @@ def get_cli(
         app.command("validate")(self.validate_dataset)
         return app

-    def _final_store(self) -> zarr.abc.store.Store:
-        return get_zarr_store(
-            self.storage_config.base_path,
-            self.template_config.dataset_id,
-            self.template_config.version,
-        )
-
     def _tmp_store(self) -> Path:
         return get_local_tmp_store()
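To make the change above concrete, here is a hedged usage sketch of the new flow in DynamicalDataset. Only names visible in this diff are used; the config values, dataset id, and version strings are illustrative placeholders, and in practice the StorageConfig would typically be one of the concrete subclasses from __main__.py rather than the base class.

```python
# Illustrative values only; real configs look like SourceCoopDatasetStorageConfig.
from reformatters.common.storage import DatasetFormat, StorageConfig, StoreFactory

storage_config = StorageConfig(
    base_path="s3://example-bucket/dynamical",  # placeholder path
    k8s_secret_names=["example-secret"],        # placeholder secret name
    format=DatasetFormat.ZARR3,
)

# DynamicalDataset.primary_store_factory builds exactly this object from its
# own storage_config, dataset_id, and template_config.version:
factory = StoreFactory(
    storage_config=storage_config,
    dataset_id="example-dataset",       # placeholder id
    template_config_version="v0.1.0",   # placeholder version
)

# Call sites that previously received the opened zarr store (final_store) now
# receive the factory and open the store only where needed:
primary_store = factory.store()

# template_utils.write_metadata now accepts either the local tmp Path or the
# StoreFactory itself, replacing the old (store, get_mode(store)) pair:
#   template_utils.write_metadata(template_ds, tmp_store)
#   template_utils.write_metadata(template_ds, self.primary_store_factory)
```

A plausible motivation, not stated in the commit message, is that a small pydantic StoreFactory is cheaper to carry through job construction and serialization than an opened store object, with each caller opening the store lazily via .store().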

src/reformatters/common/region_job.py

Lines changed: 20 additions & 20 deletions
@@ -15,7 +15,6 @@
 import pandas as pd
 import pydantic
 import xarray as xr
-import zarr
 from pydantic import AfterValidator, Field, computed_field

 from reformatters.common import template_utils
@@ -28,6 +27,7 @@
     create_data_array_and_template,
 )
 from reformatters.common.shared_memory_utils import make_shared_buffer, write_shards
+from reformatters.common.storage import StoreFactory
 from reformatters.common.types import (
     AppendDim,
     ArrayND,
@@ -36,7 +36,7 @@
     Timestamp,
 )
 from reformatters.common.update_progress_tracker import UpdateProgressTracker
-from reformatters.common.zarr import copy_data_var, get_mode
+from reformatters.common.zarr import copy_data_var

 log = get_logger(__name__)

@@ -101,7 +101,7 @@ def region_slice(s: slice) -> slice:


 class RegionJob(pydantic.BaseModel, Generic[DATA_VAR, SOURCE_FILE_COORD]):
-    final_store: zarr.abc.store.Store
+    primary_store_factory: StoreFactory
     tmp_store: Path
     template_ds: xr.Dataset
     data_vars: Sequence[DATA_VAR]
@@ -228,7 +228,7 @@ def update_template_with_results(
         Subclasses should implement this method to apply dataset-specific adjustments
         based on the processing results. Examples include:
         - Trimming dataset along append_dim to only include successfully processed data
-        - Loading existing coordinate values from final_store and updating them based on results
+        - Loading existing coordinate values from the primary store and updating them based on results
         - Updating metadata based on what was actually processed vs what was planned

         The default implementation here trims along append_dim to end at the most recent
@@ -266,7 +266,7 @@ def update_template_with_results(
     @classmethod
     def operational_update_jobs(
         cls,
-        final_store: zarr.abc.store.Store,
+        primary_store_factory: StoreFactory,
         tmp_store: Path,
         get_template_fn: Callable[[DatetimeLike], xr.Dataset],
         append_dim: AppendDim,
@@ -284,16 +284,16 @@

         The exact logic is dataset-specific, but it generally follows this pattern:
         1. Figure out the range of time to process: append_dim_start (inclusive) and append_dim_end (exclusive)
-           a. Read existing data from final_store to determine what's already processed
+           a. Read existing data from the primary store to determine what's already processed
            b. Optionally identify recent incomplete/non-final data for reprocessing
         2. Call get_template_fn(append_dim_end) to get the template_ds
         3. Create RegionJob instances by calling cls.get_jobs(..., filter_start=append_dim_start)

         Parameters
         ----------
-        final_store : zarr.abc.store.Store
-            The destination Zarr store to read existing data from and write updates to.
-        tmp_store : zarr.abc.store.Store | Path
+        primary_store_factory : StoreFactory
+            The factory to get the primary store to read existing data from and write updates to.
+        tmp_store : Path
             The temporary Zarr store to write into while processing.
         get_template_fn : Callable[[DatetimeLike], xr.Dataset]
             Function to get the template_ds for the operational update.
@@ -331,7 +331,7 @@ def dataset_id(self) -> str:
     def get_jobs(
         cls,
         kind: Literal["backfill", "operational-update"],
-        final_store: zarr.abc.store.Store,
+        primary_store_factory: StoreFactory,
         tmp_store: Path,
         template_ds: xr.Dataset,
         append_dim: AppendDim,
@@ -357,9 +357,9 @@

         Parameters
         ----------
-        final_store : zarr.abc.store.Store
-            The destination Zarr store to write into.
-        tmp_store : zarr.abc.store.Store | Path
+        primary_store_factory : StoreFactory
+            The factory to get the primary store to read existing data from and write updates to.
+        tmp_store : Path
             The temporary Zarr store to write into while processing.
         template_ds : xr.Dataset
             Dataset template defining structure and metadata.
@@ -442,7 +442,7 @@ def get_jobs(

         all_jobs = [
             cls(
-                final_store=final_store,
+                primary_store_factory=primary_store_factory,
                 tmp_store=tmp_store,
                 template_ds=template_ds,
                 data_vars=data_var_group,
@@ -468,7 +468,7 @@ def process(self) -> Mapping[str, Sequence[SOURCE_FILE_COORD]]:
              i. Read data from source files into the shared array
              ii. Apply any required data transformations (e.g., rounding, deaccumulation)
              iii. Write output shards to the tmp_store
-             iv. Upload chunk data from tmp_store to final_store
+             iv. Upload chunk data from tmp_store to the primary store

         Returns
         -------
@@ -477,8 +477,10 @@ def process(self) -> Mapping[str, Sequence[SOURCE_FILE_COORD]]:
         """
         processing_region_ds, output_region_ds = self._get_region_datasets()

+        primary_store = self.primary_store_factory.store()
+
         progress_tracker = UpdateProgressTracker(
-            self.final_store, self.reformat_job_name, self.region.start
+            primary_store, self.reformat_job_name, self.region.start
         )
         data_vars_to_process: Sequence[DATA_VAR] = progress_tracker.get_unprocessed(
             self.data_vars
@@ -489,9 +491,7 @@ def process(self) -> Mapping[str, Sequence[SOURCE_FILE_COORD]]:
             data_var_groups, self.max_vars_per_download_group
         )

-        template_utils.write_metadata(
-            self.template_ds, self.tmp_store, get_mode(self.tmp_store)
-        )
+        template_utils.write_metadata(self.template_ds, self.tmp_store)

         results: dict[str, Sequence[SOURCE_FILE_COORD]] = {}
         upload_futures: list[Any] = []
@@ -555,7 +555,7 @@ def process(self) -> Mapping[str, Sequence[SOURCE_FILE_COORD]]:
                     self.template_ds,
                     self.append_dim,
                     self.tmp_store,
-                    self.final_store,
+                    primary_store,
                     partial(progress_tracker.record_completion, data_var.name),
                 )
             )
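Putting the two files together, the call-site pattern for a RegionJob subclass now looks roughly like the sketch below. MyRegionJob and dataset are stand-ins for a concrete RegionJob subclass and its DynamicalDataset (neither appears in this commit), and the job name is a placeholder; the keyword arguments mirror the operational_update_jobs call in DynamicalDataset.update above.

```python
# Stand-ins: MyRegionJob is an assumed RegionJob subclass, dataset an assumed
# DynamicalDataset instance; the reformat_job_name value is a placeholder.
jobs, template_ds = MyRegionJob.operational_update_jobs(
    primary_store_factory=dataset.primary_store_factory,  # was final_store=...
    tmp_store=dataset._tmp_store(),
    get_template_fn=dataset._get_template,
    append_dim=dataset.template_config.append_dim,
    all_data_vars=dataset.template_config.data_vars,
    reformat_job_name="operational-update-example",
)

for job in jobs:
    # Each job opens the primary store itself inside process(), via
    # self.primary_store_factory.store(); no opened store object is passed around.
    results = job.process()
```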
