from collections.abc import Callable, Mapping, Sequence
from pathlib import Path

import xarray as xr

from reformatters.common.logging import get_logger
from reformatters.common.region_job import (
    CoordinateValueOrRange,
    RegionJob,
    SourceFileCoord,
)
from reformatters.common.storage import StoreFactory
from reformatters.common.types import (
    AppendDim,
    ArrayFloat32,
    DatetimeLike,
    Dim,
)

from .template_config import DwdIconEuDataVar

log = get_logger(__name__)


class DwdIconEuForecastSourceFileCoord(SourceFileCoord):
    """Coordinates of a single source file to process."""
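
    # A minimal sketch of the coordinate attributes this dataset likely needs,
    # based on the init_time/lead_time usage in the examples further down. The
    # exact types and how SourceFileCoord declares fields are assumptions to verify:
    #
    #     init_time: pd.Timestamp
    #     lead_time: pd.Timedelta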

    def get_url(self) -> str:
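        # A hedged sketch only: DWD publishes ICON-EU GRIB files on
        # https://opendata.dwd.de, but the directory layout and file naming below
        # are assumptions to verify, and `init_hour`, `variable_name`, and
        # `file_name` are hypothetical values derived from this coord:
        #
        #     return (
        #         "https://opendata.dwd.de/weather/nwp/icon-eu/grib/"
        #         f"{init_hour}/{variable_name}/{file_name}"
        #     )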
        raise NotImplementedError("Return the URL of the source file.")

    def out_loc(
        self,
    ) -> Mapping[Dim, CoordinateValueOrRange]:
        """
        Return a data array indexer identifying the region of the output dataset
        into which data from this source file is written. The indexer is a dict
        from dimension names to coordinate values or slices.
        """
        # If the names of the coordinate attributes of your SourceFileCoord subclass are also all
        # dimension names in the output dataset (e.g. init_time and lead_time),
        # delete this implementation and use the default implementation of this method.
        #
        # Examples where you would override this method:
        # - An analysis dataset created from forecast data:
        #     return {"time": self.init_time + self.lead_time}
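        # - (For contrast, for a forecast coord whose attributes are init_time and
        #   lead_time, the default implementation effectively returns
        #   {"init_time": self.init_time, "lead_time": self.lead_time}.)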
        return super().out_loc()


class DwdIconEuForecastRegionJob(
    RegionJob[DwdIconEuDataVar, DwdIconEuForecastSourceFileCoord]
):
    # Optionally, limit the number of variables downloaded together.
    # If set to a value less than len(data_vars), the downloading, reading/recompressing,
    # and uploading steps will be pipelined within a region job.
    # 5 is a reasonable default if it is possible to download fewer than all
    # variables from a single file (e.g. you have a GRIB index).
    # Leave unset if a whole file must be downloaded to extract even one variable,
    # so the same file is not downloaded multiple times.
    # Uncommenting the line below also requires `from typing import ClassVar`.
    #
    # max_vars_per_download_group: ClassVar[int | None] = None

    # Implement this method only if different variables must be retrieved from different URLs.
    #
    # @classmethod
    # def source_groups(
    #     cls,
    #     data_vars: Sequence[DwdIconEuDataVar],
    # ) -> Sequence[Sequence[DwdIconEuDataVar]]:
    #     """
    #     Return groups of variables, where all variables in a group can be retrieved from the same source file.
    #     """
    #     # requires `from collections import defaultdict` at the top of the module
    #     grouped = defaultdict(list)
    #     for data_var in data_vars:
    #         grouped[data_var.internal_attrs.file_type].append(data_var)
    #     return list(grouped.values())

    # Implement this method only if specific post-processing in this dataset
    # requires data from outside the region defined by self.region,
    # e.g. for deaccumulation or interpolation along append_dim in an analysis dataset.
    #
    # def get_processing_region(self) -> slice:
    #     """
    #     Return a slice of integer offsets into self.template_ds along self.append_dim that identifies
    #     the region to process. In most cases this is exactly self.region, but if additional data outside
    #     the region is required, for example for correct interpolation or deaccumulation, this method can
    #     return a modified slice (e.g. `slice(self.region.start - 1, self.region.stop + 1)`).
    #     """
    #     return self.region

    def generate_source_file_coords(
        self,
        processing_region_ds: xr.Dataset,
        data_var_group: Sequence[DwdIconEuDataVar],
    ) -> Sequence[DwdIconEuForecastSourceFileCoord]:
        """Return a sequence of coords, one for each source file required to process the data covered by processing_region_ds."""
        # For example (requires `import itertools` at the top of the module):
        #
        # return [
        #     DwdIconEuForecastSourceFileCoord(
        #         init_time=init_time,
        #         lead_time=lead_time,
        #     )
        #     for init_time, lead_time in itertools.product(
        #         processing_region_ds["init_time"].values,
        #         processing_region_ds["lead_time"].values,
        #     )
        # ]
        raise NotImplementedError(
            "Return a sequence of SourceFileCoord objects, one for each source file required to process the data covered by processing_region_ds."
        )

    def download_file(self, coord: DwdIconEuForecastSourceFileCoord) -> Path:
        """Download the file for the given coordinate and return the local path."""
        # return http_download_to_disk(coord.get_url(), self.dataset_id)
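        #
        # A hedged sketch only, in case the source files turn out to be bz2-compressed
        # GRIB (common for DWD open data, but verify for this dataset). It assumes
        # `import bz2` and `import shutil` at the top of the module and that
        # `http_download_to_disk` returns a Path:
        #
        #     compressed_path = http_download_to_disk(coord.get_url(), self.dataset_id)
        #     decompressed_path = compressed_path.with_suffix("")
        #     with bz2.open(compressed_path, "rb") as src, open(decompressed_path, "wb") as dst:
        #         shutil.copyfileobj(src, dst)
        #     return decompressed_path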
        raise NotImplementedError(
            "Download the file for the given coordinate and return the local path."
        )

    def read_data(
        self,
        coord: DwdIconEuForecastSourceFileCoord,
        data_var: DwdIconEuDataVar,
    ) -> ArrayFloat32:
        """Read and return an array of data for the given variable and source file coordinate."""
        # Example using rasterio (requires `import numpy as np` and `import rasterio`
        # at the top of the module):
        #
        # TODO: make a utility function that selects a band index by matching tags
        # with rasterio.open(coord.downloaded_file_path) as reader:
        #     matching_indexes = [
        #         i
        #         for i in range(reader.count)
        #         if (tags := reader.tags(i))["GRIB_ELEMENT"]
        #         == data_var.internal_attrs.grib_element
        #         and tags["GRIB_COMMENT"] == data_var.internal_attrs.grib_comment
        #     ]
        #     assert len(matching_indexes) == 1, (
        #         f"Expected exactly 1 matching band, found {matching_indexes}. "
        #         f"{data_var.internal_attrs.grib_element=}, "
        #         f"{data_var.internal_attrs.grib_comment=}, "
        #         f"{coord.downloaded_file_path=}"
        #     )
        #     rasterio_band_index = 1 + matching_indexes[0]  # rasterio is 1-indexed
        #     return reader.read(rasterio_band_index, dtype=np.float32)
        raise NotImplementedError(
            "Read and return data for the given variable and source file coordinate."
        )

    # Implement this method to apply transformations to the array (e.g. deaccumulation).
    #
    # def apply_data_transformations(
    #     self, data_array: xr.DataArray, data_var: DwdIconEuDataVar
    # ) -> None:
    #     """
    #     Apply in-place data transformations to the output data array for a given data variable.
    #
    #     This method is called after reading all data for a variable into the shared-memory array,
    #     and before writing shards to the output store. The default implementation applies binary
    #     rounding to float32 arrays if `data_var.internal_attrs.keep_mantissa_bits` is set.
    #
    #     Subclasses may override this method to implement additional transformations such as
    #     deaccumulation, interpolation or other custom logic. All transformations should be
    #     performed in-place (don't copy `data_array`, it's large).
    #
    #     Parameters
    #     ----------
    #     data_array : xr.DataArray
    #         The output data array to be transformed in-place.
    #     data_var : DwdIconEuDataVar
    #         The data variable metadata object, which may contain transformation parameters.
    #     """
    #     super().apply_data_transformations(data_array, data_var)
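    #
    # A hypothetical body for the override sketched above, deaccumulating along
    # lead_time in-place. It assumes `import numpy as np` at the top of the module
    # and a dataset-specific `deaccumulate` flag on internal_attrs that is not
    # defined in this template:
    #
    #     if getattr(data_var.internal_attrs, "deaccumulate", False):
    #         axis = data_array.get_axis_num("lead_time")
    #         data_array.values[...] = np.diff(data_array.values, axis=axis, prepend=0.0)
    #     super().apply_data_transformations(data_array, data_var)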

    def update_template_with_results(
        self, process_results: Mapping[str, Sequence[DwdIconEuForecastSourceFileCoord]]
    ) -> xr.Dataset:
        """
        Update the template dataset based on processing results. This method is
        called during operational updates.

        Subclasses should implement this method to apply dataset-specific adjustments
        based on the processing results. Examples include:
        - Trimming the dataset along append_dim to include only successfully processed data
        - Loading existing coordinate values from the primary store and updating them based on results
        - Updating metadata based on what was actually processed vs. what was planned

        The default implementation trims along append_dim to end at the most recent
        successfully processed coordinate (timestamp).

        Parameters
        ----------
        process_results : Mapping[str, Sequence[DwdIconEuForecastSourceFileCoord]]
            Mapping from variable names to their source file coordinates with final processing status.

        Returns
        -------
        xr.Dataset
            Updated template dataset reflecting the actual processing results.
        """
        # The super() implementation (which uses `chain` from itertools and the
        # library's SourceFileStatus) looks like this:
        #
        # max_append_dim_processed = max(
        #     (
        #         c.out_loc()[self.append_dim]  # type: ignore[type-var]
        #         for c in chain.from_iterable(process_results.values())
        #         if c.status == SourceFileStatus.Succeeded
        #     ),
        #     default=None,
        # )
        # if max_append_dim_processed is None:
        #     # No data was processed, trim the template to stop before this job's region.
        #     # This relies on isel's exclusive slice end behavior.
        #     return self.template_ds.isel(
        #         {self.append_dim: slice(None, self.region.start)}
        #     )
        # else:
        #     return self.template_ds.sel(
        #         {self.append_dim: slice(None, max_append_dim_processed)}
        #     )
        #
        # If that behavior is what you want, skip implementing this method.
        # If you need to customize it, implement this method.

        raise NotImplementedError(
            "Subclasses implement update_template_with_results() with dataset-specific logic"
        )

    @classmethod
    def operational_update_jobs(
        cls,
        primary_store_factory: StoreFactory,
        tmp_store: Path,
        get_template_fn: Callable[[DatetimeLike], xr.Dataset],
        append_dim: AppendDim,
        all_data_vars: Sequence[DwdIconEuDataVar],
        reformat_job_name: str,
    ) -> tuple[
        Sequence["RegionJob[DwdIconEuDataVar, DwdIconEuForecastSourceFileCoord]"],
        xr.Dataset,
    ]:
        """
        Return the sequence of RegionJob instances necessary to update the dataset
        from its current state to include the latest available data.

        Also return the template_ds, expanded along append_dim through the end of
        the data to process. The dataset returned here may extend beyond the
        available data at the source, in which case `update_template_with_results`
        will trim the dataset to the data actually processed.

        The exact logic is dataset-specific, but it generally follows this pattern:
        1. Figure out the range of time to process: append_dim_start (inclusive) and append_dim_end (exclusive)
           a. Read existing data from the primary store to determine what's already processed
           b. Optionally identify recent incomplete/non-final data for reprocessing
        2. Call get_template_fn(append_dim_end) to get the template_ds
        3. Create RegionJob instances by calling cls.get_jobs(..., filter_start=append_dim_start)

        Parameters
        ----------
        primary_store_factory : StoreFactory
            The factory for the primary store to read existing data from and write updates to.
        tmp_store : Path
            The temporary Zarr store to write into while processing.
        get_template_fn : Callable[[DatetimeLike], xr.Dataset]
            Function to get the template_ds for the operational update.
        append_dim : AppendDim
            The dimension along which data is appended (e.g., "time").
        all_data_vars : Sequence[DwdIconEuDataVar]
            Sequence of all data variable configs for this dataset.
        reformat_job_name : str
            The name of the reformatting job, used for progress tracking.
            This is often the name of the Kubernetes job, or "local".

        Returns
        -------
        Sequence[RegionJob[DwdIconEuDataVar, DwdIconEuForecastSourceFileCoord]]
            RegionJob instances that need processing for operational updates.
        xr.Dataset
            The template_ds for the operational update.
        """
        # For example (requires `import pandas as pd` at the top of the module):
        #
        # existing_ds = xr.open_zarr(primary_store_factory.store())
        # append_dim_start = existing_ds[append_dim].max()
        # append_dim_end = pd.Timestamp.now()
        # template_ds = get_template_fn(append_dim_end)
        #
        # jobs = cls.get_jobs(
        #     kind="operational-update",
        #     primary_store_factory=primary_store_factory,
        #     tmp_store=tmp_store,
        #     template_ds=template_ds,
        #     append_dim=append_dim,
        #     all_data_vars=all_data_vars,
        #     reformat_job_name=reformat_job_name,
        #     filter_start=append_dim_start,
        # )
        # return jobs, template_ds

        raise NotImplementedError(
            "Subclasses implement operational_update_jobs() with dataset-specific logic"
        )