diff --git a/src/esmf_regrid/esmf_regridder.py b/src/esmf_regrid/esmf_regridder.py index 974a8de6..be18e20a 100644 --- a/src/esmf_regrid/esmf_regridder.py +++ b/src/esmf_regrid/esmf_regridder.py @@ -158,6 +158,35 @@ def _out_dtype(self, in_dtype): out_dtype = (np.ones(1, dtype=in_dtype) * np.ones(1, dtype=weight_dtype)).dtype return out_dtype + def _gen_weights_and_data(self, src_array): + extra_shape = src_array.shape[: -self.src.dims] + + flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0)) + flat_tgt = self.weight_matrix @ flat_src + + src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array)) + weight_sums = self.weight_matrix @ src_inverted_mask + + tgt_data = self.tgt._matrix_to_array(flat_tgt, extra_shape) + tgt_weights = self.tgt._matrix_to_array(weight_sums, extra_shape) + return tgt_weights, tgt_data + + def _regrid_from_weights_and_data(self, tgt_weights, tgt_data, norm_type=Constants.NormType.FRACAREA, mdtol=1): + # Set the minimum mdtol to be slightly higher than 0 to account for rounding + # errors. + mdtol = max(mdtol, 1e-8) + tgt_mask = tgt_weights > 1 - mdtol + masked_weight_sums = tgt_weights * tgt_mask + normalisations = np.ones_like(tgt_data) + if norm_type == Constants.NormType.FRACAREA: + normalisations[tgt_mask] /= masked_weight_sums[tgt_mask] + elif norm_type == Constants.NormType.DSTAREA: + pass + normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask)) + + tgt_array = tgt_data * normalisations + return tgt_array + def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1): """Perform regridding on an array of data. @@ -195,25 +224,7 @@ def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1): f"got an array with shape ending in {main_shape}." 
) raise ValueError(e_msg) - extra_shape = array_shape[: -self.src.dims] - extra_size = max(1, np.prod(extra_shape)) - src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array)) - weight_sums = self.weight_matrix @ src_inverted_mask - out_dtype = self._out_dtype(src_array.dtype) - # Set the minimum mdtol to be slightly higher than 0 to account for rounding - # errors. - mdtol = max(mdtol, 1e-8) - tgt_mask = weight_sums > 1 - mdtol - masked_weight_sums = weight_sums * tgt_mask - normalisations = np.ones([self.tgt.size, extra_size], dtype=out_dtype) - if norm_type == Constants.NormType.FRACAREA: - normalisations[tgt_mask] /= masked_weight_sums[tgt_mask] - elif norm_type == Constants.NormType.DSTAREA: - pass - normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask)) - flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0)) - flat_tgt = self.weight_matrix @ flat_src - flat_tgt = flat_tgt * normalisations - tgt_array = self.tgt._matrix_to_array(flat_tgt, extra_shape) + tgt_weights, tgt_data = self._gen_weights_and_data(src_array) + tgt_array = self._regrid_from_weights_and_data(tgt_weights, tgt_data, norm_type=norm_type, mdtol=mdtol) return tgt_array diff --git a/src/esmf_regrid/experimental/_partial.py b/src/esmf_regrid/experimental/_partial.py new file mode 100644 index 00000000..c8067158 --- /dev/null +++ b/src/esmf_regrid/experimental/_partial.py @@ -0,0 +1,41 @@ +"""Provides a regridder class compatible with Partition.""" + +from esmf_regrid.schemes import ( + _ESMFRegridder, + _create_cube, +) + +class PartialRegridder(_ESMFRegridder): + def __init__(self, src, tgt, src_slice, tgt_slice, weights, scheme, **kwargs): + self.src_slice = src_slice # this will be tuple-like + self.tgt_slice = tgt_slice + self.scheme = scheme + # TODO: consider disallowing ESMFNearest (unless out of bounds can be made masked) + + # Pop duplicate kwargs. 
+ for arg in set(kwargs.keys()).intersection(vars(self.scheme)): + kwargs.pop(arg) + + self._regridder = scheme.regridder( + src, + tgt, + precomputed_weights=weights, + **kwargs, + ) + self.__dict__.update(self._regridder.__dict__) + + def partial_regrid(self, src): + return self.regridder._gen_weights_and_data(src.data) + + def finish_regridding(self, src_cube, weights, data): + dims = self._get_cube_dims(src_cube) + + result_data = self.regridder._regrid_from_weights_and_data(weights, data) + result_cube = _create_cube( + result_data, + src_cube, + dims, + self._tgt, + len(self._tgt) + ) + return result_cube diff --git a/src/esmf_regrid/experimental/io.py b/src/esmf_regrid/experimental/io.py index e71d7365..9672746d 100644 --- a/src/esmf_regrid/experimental/io.py +++ b/src/esmf_regrid/experimental/io.py @@ -14,9 +14,13 @@ GridToMeshESMFRegridder, MeshToGridESMFRegridder, ) +from esmf_regrid.experimental._partial import PartialRegridder from esmf_regrid.schemes import ( + ESMFAreaWeighted, ESMFAreaWeightedRegridder, + ESMFBilinear, ESMFBilinearRegridder, + ESMFNearest, ESMFNearestRegridder, GridRecord, MeshRecord, @@ -28,6 +32,7 @@ ESMFNearestRegridder, GridToMeshESMFRegridder, MeshToGridESMFRegridder, + PartialRegridder, ] _REGRIDDER_NAME_MAP = {rg_class.__name__: rg_class for rg_class in SUPPORTED_REGRIDDERS} _SOURCE_NAME = "regridder_source_field" @@ -47,6 +52,8 @@ _SOURCE_RESOLUTION = "src_resolution" _TARGET_RESOLUTION = "tgt_resolution" _ESMF_ARGS = "esmf_args" +_SRC_SLICE_NAME = "src_slice" +_TGT_SLICE_NAME = "tgt_slice" _VALID_ESMF_KWARGS = [ "pole_method", "regrid_pole_npoints", @@ -118,54 +125,39 @@ def _clean_var_names(cube): con.var_name = None -def save_regridder(rg, filename): - """Save a regridder scheme instance. - - Saves any of the regridder classes, i.e. 
- :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`, - :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`, - :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`, - :class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or - :class:`~esmf_regrid.schemes.ESMFNearestRegridder`. - . - - Parameters - ---------- - rg : :class:`~esmf_regrid.schemes._ESMFRegridder` - The regridder instance to save. - filename : str - The file name to save to. - """ - regridder_type = rg.__class__.__name__ - - def _standard_grid_cube(grid, name): - if grid[0].ndim == 1: - shape = [coord.points.size for coord in grid] - else: - shape = grid[0].shape - data = np.zeros(shape) - cube = Cube(data, var_name=name, long_name=name) - if grid[0].ndim == 1: - cube.add_dim_coord(grid[0], 0) - cube.add_dim_coord(grid[1], 1) - else: - cube.add_aux_coord(grid[0], [0, 1]) - cube.add_aux_coord(grid[1], [0, 1]) - return cube - - def _standard_mesh_cube(mesh, location, name): - mesh_coords = mesh.to_MeshCoords(location) - data = np.zeros(mesh_coords[0].points.shape[0]) - cube = Cube(data, var_name=name, long_name=name) - for coord in mesh_coords: - cube.add_aux_coord(coord, 0) - return cube - +def _standard_grid_cube(grid, name): + if grid[0].ndim == 1: + shape = [coord.points.size for coord in grid] + else: + shape = grid[0].shape + data = np.zeros(shape) + cube = Cube(data, var_name=name, long_name=name) + if grid[0].ndim == 1: + cube.add_dim_coord(grid[0], 0) + cube.add_dim_coord(grid[1], 1) + else: + cube.add_aux_coord(grid[0], [0, 1]) + cube.add_aux_coord(grid[1], [0, 1]) + return cube + +def _standard_mesh_cube(mesh, location, name): + mesh_coords = mesh.to_MeshCoords(location) + data = np.zeros(mesh_coords[0].points.shape[0]) + cube = Cube(data, var_name=name, long_name=name) + for coord in mesh_coords: + cube.add_aux_coord(coord, 0) + return cube + +def _generate_src_tgt(regridder_type, rg, allow_partial): if regridder_type in [ 
"ESMFAreaWeightedRegridder", "ESMFBilinearRegridder", "ESMFNearestRegridder", + "PartialRegridder", ]: + if regridder_type == "PartialRegridder" and not allow_partial: + e_msg = "To save a PartialRegridder, `allow_partial=True` must be set." + raise ValueError(e_msg) src_grid = rg._src if isinstance(src_grid, GridRecord): src_cube = _standard_grid_cube( @@ -210,12 +202,38 @@ def _standard_mesh_cube(mesh, location, name): tgt_grid = (rg.grid_y, rg.grid_x) tgt_cube = _standard_grid_cube(tgt_grid, _TARGET_NAME) _add_mask_to_cube(rg.tgt_mask, tgt_cube, _TARGET_MASK_NAME) + else: e_msg = ( - f"Expected a regridder of type `GridToMeshESMFRegridder` or " - f"`MeshToGridESMFRegridder`, got type {regridder_type}." + f"Unexpected regridder type {regridder_type}." ) raise TypeError(e_msg) + return src_cube, tgt_cube + + + + +def save_regridder(rg, filename, allow_partial=False): + """Save a regridder scheme instance. + + Saves any of the regridder classes, i.e. + :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`, + :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`, + :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`, + :class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or + :class:`~esmf_regrid.schemes.ESMFNearestRegridder`. + . + + Parameters + ---------- + rg : :class:`~esmf_regrid.schemes._ESMFRegridder` + The regridder instance to save. + filename : str + The file name to save to. 
+ """ + regridder_type = rg.__class__.__name__ + + src_cube, tgt_cube = _generate_src_tgt(regridder_type, rg, allow_partial) method = str(check_method(rg.method).name) @@ -223,7 +241,7 @@ def _standard_mesh_cube(mesh, location, name): resolution = rg.resolution src_resolution = None tgt_resolution = None - elif regridder_type == "ESMFAreaWeightedRegridder": + elif method == "CONSERVATIVE": resolution = None src_resolution = rg.src_resolution tgt_resolution = rg.tgt_resolution @@ -264,6 +282,19 @@ def _standard_mesh_cube(mesh, location, name): if tgt_resolution is not None: attributes[_TARGET_RESOLUTION] = tgt_resolution + extra_cubes = [] + if regridder_type == "PartialRegridder": + src_slice = rg.src_slice # this slice is described by a tuple + if src_slice is None: + src_slice = [] + src_slice_cube = Cube(src_slice, long_name=_SRC_SLICE_NAME, var_name=_SRC_SLICE_NAME) + tgt_slice = rg.tgt_slice # this slice is described by a tuple + if tgt_slice is None: + tgt_slice = [] + tgt_slice_cube = Cube(tgt_slice, long_name=_TGT_SLICE_NAME, var_name=_TGT_SLICE_NAME) + extra_cubes = [src_slice_cube, tgt_slice_cube] + + weights_cube = Cube(weight_data, var_name=_WEIGHTS_NAME, long_name=_WEIGHTS_NAME) row_coord = AuxCoord( weight_rows, var_name=_WEIGHTS_ROW_NAME, long_name=_WEIGHTS_ROW_NAME @@ -298,7 +329,7 @@ def _standard_mesh_cube(mesh, location, name): # Save cubes while ensuring var_names do not conflict for the sake of consistency. with _managed_var_name(src_cube, tgt_cube): - cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube]) + cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube, *extra_cubes]) for cube in cube_list: cube.attributes = attributes @@ -306,7 +337,7 @@ def _standard_mesh_cube(mesh, location, name): iris.fileformats.netcdf.save(cube_list, filename) -def load_regridder(filename): +def load_regridder(filename, allow_partial=False): """Load a regridder scheme instance. Loads any of the regridder classes, i.e. 
@@ -343,6 +374,10 @@ def load_regridder(filename): raise TypeError(e_msg) scheme = _REGRIDDER_NAME_MAP[regridder_type] + if regridder_type == "PartialRegridder" and not allow_partial: + e_msg = "PartialRegridder cannot be loaded without setting `allow_partial=True`." + raise ValueError(e_msg) + # Determine the regridding method, allowing for files created when # conservative regridding was the only method. method_string = weights_cube.attributes.get(_METHOD, "CONSERVATIVE") @@ -396,26 +431,50 @@ def load_regridder(filename): elif scheme is MeshToGridESMFRegridder: resolution_keyword = _TARGET_RESOLUTION kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol} - elif scheme is ESMFAreaWeightedRegridder: + # elif scheme is ESMFAreaWeightedRegridder: + elif method is Constants.Method.CONSERVATIVE: kwargs = { _SOURCE_RESOLUTION: src_resolution, _TARGET_RESOLUTION: tgt_resolution, "mdtol": mdtol, } - elif scheme is ESMFBilinearRegridder: + # elif scheme is ESMFBilinearRegridder: + elif method is Constants.Method.BILINEAR: kwargs = {"mdtol": mdtol} else: kwargs = {} - regridder = scheme( - src_cube, - tgt_cube, - precomputed_weights=weight_matrix, - use_src_mask=use_src_mask, - use_tgt_mask=use_tgt_mask, - esmf_args=esmf_args, - **kwargs, - ) + if scheme is PartialRegridder: + src_slice = cubes.extract_cube(_SRC_SLICE_NAME).data.tolist() + if src_slice == []: + src_slice = None + tgt_slice = cubes.extract_cube(_TGT_SLICE_NAME).data.tolist() + if tgt_slice == []: + tgt_slice = None + sub_scheme = { + Constants.Method.CONSERVATIVE: ESMFAreaWeighted, + Constants.Method.BILINEAR: ESMFBilinear, + Constants.Method.NEAREST: ESMFNearest, + }[method] + regridder = scheme( + src_cube, + tgt_cube, + src_slice, + tgt_slice, + weight_matrix, + sub_scheme(), + **kwargs, + ) + else: + regridder = scheme( + src_cube, + tgt_cube, + precomputed_weights=weight_matrix, + use_src_mask=use_src_mask, + use_tgt_mask=use_tgt_mask, + esmf_args=esmf_args, + **kwargs, + ) 
esmf_version = weights_cube.attributes[_VERSION_ESMF] regridder.regridder.esmf_version = esmf_version diff --git a/src/esmf_regrid/experimental/partition.py b/src/esmf_regrid/experimental/partition.py new file mode 100644 index 00000000..8dd6ba3e --- /dev/null +++ b/src/esmf_regrid/experimental/partition.py @@ -0,0 +1,308 @@ +"""Provides an interface for splitting up a large regridding task.""" + +import numpy as np +from scipy import sparse + +from esmf_regrid.experimental.io import load_regridder, save_regridder +from esmf_regrid.experimental._partial import PartialRegridder + + +def _interpret_slice(sl): + # return [slice(s) for s in sl] + return np.s_[sl[0][0]:sl[0][1], sl[1][0]:sl[1][1]] + +# TODO: consider if nearest is appropriate with this way of doing things + + +class Partition: + + # TODO: add a way to save the Partition object. + # TODO: make a way to find out which files have been saved from the last session. + + ## hold a list of files + ## hold a collection of source indices + ## alternately hold a collection of chunk indices + ## note which indices are fully loaded + def __init__(self, src, tgt, scheme, file_names, src_chunks, tgt_chunks=None, auto_generate=False, saved_files=None, partially_saved=None): + self.src = src + self.tgt = tgt + self.scheme = scheme + # TODO: consider abstracting away the idea of files + self.file_names = file_names + # TODO: consider deriving this from self.src.lazy_data() + self.src_chunks = src_chunks + assert len(src_chunks) == len(file_names) + self.tgt_chunks = tgt_chunks + assert tgt_chunks is None # We don't handle big targets currently + + # Note: this may need to become more sophisticated when both src and tgt are large + self.file_chunk_dict = {file: chunk for file, chunk in zip(self.file_names, self.src_chunks)} + + self.neighbouring_files = self._find_neighbours() + if saved_files is None: + self.saved_files = [] + else: + self.saved_files = saved_files + if partially_saved is None: + 
self.partially_saved_files = [] + else: + self.partially_saved_files = partially_saved + if auto_generate: + self.generate_files() + + @property + def unsaved_files(self): + files = list(set(self.file_names) - set(self.saved_files) - set(self.partially_saved_files)) + return [file for file in self.file_names if file in files] + + def generate_files(self, files_to_generate=None): + if files_to_generate is None: + # TODO: consider adding logic to order the files more efficiently. + files = self.partially_saved_files + self.unsaved_files + else: + assert isinstance(files_to_generate, int) + files = (self.partially_saved_files + self.unsaved_files)[:files_to_generate] + + # sort files + files = [file for file in self.file_names if file in files] + + # Do this to ensure the last regridder is saved + files.append(None) + + previous_regridders = [] + previous_files = [] + for file in files: + if file in self.partially_saved_files: + next_regridder = load_regridder(file, allow_partial=True) + elif file is None: + next_regridder = None + else: + src_chunk = self.file_chunk_dict[file] + src = self.src[*_interpret_slice(src_chunk)] + tgt = self.tgt + next_regridder = self.scheme.regridder(src, tgt) + previous_regridders, next_regridder = self._combine_regridders(previous_regridders, next_regridder, previous_files, file) + if previous_regridders: + for regridder, pre_file in zip(previous_regridders, previous_files): + neighbours = self.neighbouring_files[pre_file] + file_complete = True + # TODO: consider any + for neighbour in neighbours: + if neighbour in self.unsaved_files and neighbour != file: + file_complete = False + save_regridder(regridder, pre_file, allow_partial=True) + if file_complete: + self.saved_files.append(pre_file) + else: + self.partially_saved_files.append(pre_file) + + # This will need to be more sophisticated for more complex cases + previous_regridders = [next_regridder] + previous_files = [file] + + + def apply_regridders(self, cube, 
allow_incomplete=False): + # for each target chunk, iterate through each associated regridder + # for now, assume one target chunk + # TODO: figure out how to mask parts of the target not covered by any source (e.g. start out with full mask) + if not allow_incomplete: + assert len(self.unsaved_files) == 0 + # TODO: this may work better as a cube of the correct shape for more complex cases + current_result = None + files = self.saved_files + + for file in files: + # TODO: make sure this works well with dask + next_regridder = load_regridder(file, allow_partial=True) + cube_slice = next_regridder.src_slice + next_result = next_regridder(cube[*_interpret_slice(cube_slice)]) + current_result = self._combine_results(current_result, next_result) + return current_result + + + def _find_neighbours(self): + # for the simplest case, neighbours will be next to each other in the list + files = self.file_names + neighbours = {file_1: (file_0, file_2) for file_0, file_1, file_2 in zip(files[:-2], files[1:-1], files[2:])} + neighbours.update({files[0]: (files[1],), files[-1]: (files[-2],)}) + return neighbours + + def _combine_regridders(self, existing, current_regridder, pre_files, file): + # For now, combine 2, in future, more regridders may be combined + + if len(existing) == 0: + if not isinstance(current_regridder, PartialRegridder): + src_slice = self.file_chunk_dict[file] + src_cube = self.src[*_interpret_slice(src_slice)] + weights = current_regridder.regridder.weight_matrix + current_regridder = PartialRegridder(src_cube, self.tgt, src_slice, None, weights, self.scheme) + elif current_regridder is None: + previous_regridder, = existing + if not isinstance(previous_regridder, PartialRegridder): + src_slice = self.file_chunk_dict[pre_files[0]] + src_cube = self.src[*_interpret_slice(src_slice)] + weights = previous_regridder.regridder.weight_matrix + previous_regridder = PartialRegridder(src_cube, self.tgt, src_slice, None, weights, self.scheme) + existing = 
[previous_regridder] + else: + (previous_regridder,) = existing + current_range = current_regridder.regridder.weight_matrix.max(axis=1).nonzero()[0] + previous_range = previous_regridder.regridder.weight_matrix.max(axis=1).nonzero()[0] + mutual_overlaps = np.intersect1d(current_range, previous_range) + + if mutual_overlaps.shape != (0,): + + # Calculate a slice of the current chunk which contains all the overlapping source cells. + overlaps_next = current_regridder.regridder.weight_matrix[mutual_overlaps].nonzero()[1] + h_len = current_regridder._src[0].shape[0] + v_len = current_regridder._src[1].shape[-1] + # TODO: make this more rigorous + tgt_size = np.prod(self.tgt.shape) + buffer = (overlaps_next % v_len).max() + 1 + + # Add this slice to the previous chunk. + pre_file, = pre_files + src_slice = self.file_chunk_dict[pre_file] + # assumes slice has form [[x_start, x_stop], [y_start, y_stop]] + # TODO: consider how this affects file_chunk_dict + src_slice[0][1] += buffer + new_src_cube = self.src[*_interpret_slice(src_slice)] + tgt_slice = None # should describe all valid target indices + + # Create weights for new regridder + previous_wm = previous_regridder.regridder.weight_matrix + current_wm = current_regridder.regridder.weight_matrix + right_wm = sparse.csr_array((tgt_size, buffer * h_len)) + buffer_inds = sum(np.meshgrid(np.arange(buffer), np.arange(h_len) * v_len)).flatten() + right_wm[mutual_overlaps] = current_wm[mutual_overlaps][:, buffer_inds] + new_weight_matrix = _combine_sparse(previous_wm, right_wm, h_len, v_len, buffer, tgt_size) + new_weight_matrix = sparse.csr_matrix( + new_weight_matrix) # must be matrix for ier (likely to change in future) + + # Remove weights from current regridder which have been added to previous regridder. + current_regridder.regridder.weight_matrix[mutual_overlaps] = 0 + + # Construct replacement for previous regridder with new weights and source cube. 
+ previous_regridder = PartialRegridder(new_src_cube, self.tgt, src_slice, tgt_slice, new_weight_matrix, self.scheme) + existing = [previous_regridder] + else: + if not isinstance(current_regridder, PartialRegridder): + src_slice = self.file_chunk_dict[file] + src_cube = self.src[*_interpret_slice(src_slice)] + weights = current_regridder.regridder.weight_matrix + current_regridder = PartialRegridder(src_cube, self.tgt, src_slice, None, weights, self.scheme) + previous_regridder, = existing + if not isinstance(previous_regridder, PartialRegridder): + src_slice = self.file_chunk_dict[pre_files[0]] + src_cube = self.src[*_interpret_slice(src_slice)] + weights = previous_regridder.regridder.weight_matrix + previous_regridder = PartialRegridder(src_cube, self.tgt, src_slice, None, weights, self.scheme) + existing = [previous_regridder] + + return existing, current_regridder + + def _combine_results(self, existing_results, next_result): + # iterate through for each target chunk + # for now, assume one target chunk + if existing_results is None: + combined_result = next_result + else: + # combined_result = existing_results + next_result + combined_data = np.ma.filled(existing_results.data, 0) + np.ma.filled(next_result.data, 0) + combined_mask = np.ma.getmaskarray(existing_results.data) & np.ma.getmaskarray(next_result.data) + combined_result = existing_results.copy() + combined_result.data = np.ma.array(combined_data, mask=combined_mask) + return combined_result + + +def _combine_sparse(left, right, w, a, b, t): + # TODO: make this more clear + result = sparse.csr_array((t, w * (a + b))) + src_indices_left = (np.arange(a)[np.newaxis, :] + ((a + b) * np.arange(w)[:, np.newaxis])).flatten() + left_im = sparse.csr_array((np.ones(a * w), (np.arange(a * w), src_indices_left)), shape=(a * w, w * (a + b))) + src_indices_right = (np.arange(b)[np.newaxis, :] + a + ((a + b) * np.arange(w)[:, np.newaxis])).flatten() + right_im = sparse.csr_array((np.ones(b * w), (np.arange(b * w), 
src_indices_right)), shape=(b * w, w * (a + b))) + + result_add_left = left @ left_im + result += result_add_left + + result_add_right = right @ right_im + result += result_add_right + return result + +class Partition2: + def __init__(self, src, tgt, scheme, file_names, src_chunks, tgt_chunks=None, auto_generate=False, saved_files=None, + partially_saved=None): + self.src = src + self.tgt = tgt + self.scheme = scheme + # TODO: consider abstracting away the idea of files + self.file_names = file_names + # TODO: consider deriving this from self.src.lazy_data() + self.src_chunks = src_chunks + assert len(src_chunks) == len(file_names) + self.tgt_chunks = tgt_chunks + assert tgt_chunks is None # We don't handle big targets currently + + # Note: this may need to become more sophisticated when both src and tgt are large + self.file_chunk_dict = {file: chunk for file, chunk in zip(self.file_names, self.src_chunks)} + + if saved_files is None: + self.saved_files = [] + else: + self.saved_files = saved_files + if auto_generate: + self.generate_files() + + @property + def unsaved_files(self): + files = set(self.file_names) - set(self.saved_files) + return [file for file in self.file_names if file in files] + + def generate_files(self, files_to_generate=None): + if files_to_generate is None: + # TODO: consider adding logic to order the files more efficiently. + files = self.unsaved_files + else: + assert isinstance(files_to_generate, int) + files = self.unsaved_files[:files_to_generate] + + for file in files: + src_chunk = self.file_chunk_dict[file] + src = self.src[*_interpret_slice(src_chunk)] + tgt = self.tgt + regridder = self.scheme.regridder(src, tgt) + src_slice = self.file_chunk_dict[file] + src_cube = self.src[*_interpret_slice(src_slice)] + weights = regridder.regridder.weight_matrix + regridder = PartialRegridder(src_cube, self.tgt, src_slice, None, weights, self.scheme) + # TODO: make partial? 
+ save_regridder(regridder, file, allow_partial=True) + self.saved_files.append(file) + + def apply_regridders(self, cube, allow_incomplete=False): + # for each target chunk, iterate through each associated regridder + # for now, assume one target chunk + # TODO: figure out how to mask parts of the target not covered by any source (e.g. start out with full mask) + if not allow_incomplete: + assert len(self.unsaved_files) == 0 + # TODO: this may work better as a cube of the correct shape for more complex cases + current_result = None + current_weights = None + files = self.saved_files + + for file, chunk in zip(self.file_names, self.src_chunks): + if file in files: + next_regridder = load_regridder(file, allow_partial=True) + cube_chunk = cube[*_interpret_slice(chunk)] + next_weights, next_result = next_regridder.partial_regrid(cube_chunk) + if current_weights is None: + current_weights = next_weights + else: + current_weights += next_weights + if current_result is None: + current_result = next_result + else: + current_result += next_result + + return next_regridder.finish_regridding(cube_chunk, current_weights, current_result) diff --git a/src/esmf_regrid/schemes.py b/src/esmf_regrid/schemes.py index adfe87b9..4ee27136 100644 --- a/src/esmf_regrid/schemes.py +++ b/src/esmf_regrid/schemes.py @@ -1087,6 +1087,7 @@ def __init__( the regridder is saved . """ + self._method = Constants.Method.CONSERVATIVE if not (0 <= mdtol <= 1): msg = "Value for mdtol must be in range 0 - 1, got {}." raise ValueError(msg.format(mdtol)) @@ -1123,6 +1124,7 @@ def regridder( use_tgt_mask=None, tgt_location="face", esmf_args=None, + precomputed_weights=None, ): """Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1191,6 +1193,7 @@ def regridder( use_tgt_mask=use_tgt_mask, tgt_location="face", esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) @@ -1240,6 +1243,7 @@ def __init__( the regridder is saved . 
""" + self._method = Constants.Method.BILINEAR if not (0 <= mdtol <= 1): msg = "Value for mdtol must be in range 0 - 1, got {}." raise ValueError(msg.format(mdtol)) @@ -1274,6 +1278,7 @@ def regridder( tgt_location=None, extrapolate_gaps=False, esmf_args=None, + precomputed_weights=None, ): """Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1336,6 +1341,7 @@ def regridder( use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) @@ -1389,6 +1395,7 @@ def __init__( arguments are recorded as a property of this regridder and are stored when the regridder is saved . """ + self._method = Constants.Method.NEAREST self.use_src_mask = use_src_mask self.use_tgt_mask = use_tgt_mask self.tgt_location = tgt_location @@ -1415,6 +1422,7 @@ def regridder( use_tgt_mask=None, tgt_location=None, esmf_args=None, + precomputed_weights=None, ): """Create regridder to perform regridding from ``src_grid`` to ``tgt_grid``. @@ -1468,6 +1476,7 @@ def regridder( use_tgt_mask=use_tgt_mask, tgt_location=tgt_location, esmf_args=esmf_args, + precomputed_weights=precomputed_weights, ) @@ -1566,26 +1575,7 @@ def __init__( else: self._src = GridRecord(_get_coord(src, "x"), _get_coord(src, "y")) - def __call__(self, cube): - """Regrid this :class:`~iris.cube.Cube` onto the target grid of this regridder instance. - - The given :class:`~iris.cube.Cube` must be defined with the same grid as the source - :class:`~iris.cube.Cube` used to create this :class:`_ESMFRegridder` instance. - - Parameters - ---------- - cube : :class:`iris.cube.Cube` - A :class:`~iris.cube.Cube` instance to be regridded. - - Returns - ------- - :class:`iris.cube.Cube` - A :class:`~iris.cube.Cube` defined with the horizontal dimensions of the target - and the other dimensions from this :class:`~iris.cube.Cube`. 
The data values of - this :class:`~iris.cube.Cube` will be converted to values on the new grid using - regridding via :mod:`esmpy` generated weights. - - """ + def _get_cube_dims(self, cube): if cube.mesh is not None: # TODO: replace temporary hack when iris issues are sorted. # Ignore differences in var_name that might be caused by saving. @@ -1629,6 +1619,29 @@ def __call__(self, cube): else: # Due to structural reasons, the order here must be reversed. dims = cube.coord_dims(new_src_x)[::-1] + return dims + + def __call__(self, cube): + """Regrid this :class:`~iris.cube.Cube` onto the target grid of this regridder instance. + + The given :class:`~iris.cube.Cube` must be defined with the same grid as the source + :class:`~iris.cube.Cube` used to create this :class:`_ESMFRegridder` instance. + + Parameters + ---------- + cube : :class:`iris.cube.Cube` + A :class:`~iris.cube.Cube` instance to be regridded. + + Returns + ------- + :class:`iris.cube.Cube` + A :class:`~iris.cube.Cube` defined with the horizontal dimensions of the target + and the other dimensions from this :class:`~iris.cube.Cube`. The data values of + this :class:`~iris.cube.Cube` will be converted to values on the new grid using + regridding via :mod:`esmpy` generated weights. 
+ + """ + dims = self._get_cube_dims(cube) regrid_info = RegridInfo( dims=dims, diff --git a/src/esmf_regrid/tests/unit/experimental/io/partition/__init__.py b/src/esmf_regrid/tests/unit/experimental/io/partition/__init__.py new file mode 100644 index 00000000..656fc3a9 --- /dev/null +++ b/src/esmf_regrid/tests/unit/experimental/io/partition/__init__.py @@ -0,0 +1 @@ +"""Unit tests for :mod:`esmf_regrid.experimental.partition`.""" diff --git a/src/esmf_regrid/tests/unit/experimental/io/partition/test_PartialRegridder.py b/src/esmf_regrid/tests/unit/experimental/io/partition/test_PartialRegridder.py new file mode 100644 index 00000000..581e26ee --- /dev/null +++ b/src/esmf_regrid/tests/unit/experimental/io/partition/test_PartialRegridder.py @@ -0,0 +1,7 @@ +"""Unit tests for :mod:`esmf_regrid.experimental._partial`.""" + +from esmf_regrid.experimental._partial import PartialRegridder + +def test_PartialRegridder(): + #TODO: write unit tests + pass diff --git a/src/esmf_regrid/tests/unit/experimental/io/partition/test_Partition.py b/src/esmf_regrid/tests/unit/experimental/io/partition/test_Partition.py new file mode 100644 index 00000000..09454d7d --- /dev/null +++ b/src/esmf_regrid/tests/unit/experimental/io/partition/test_Partition.py @@ -0,0 +1,60 @@ +"""Unit tests for :mod:`esmf_regrid.experimental.partition`.""" +import numpy as np +from esmf_regrid import ESMFAreaWeighted +from esmf_regrid.experimental.partition import Partition, Partition2 + +from esmf_regrid.tests.unit.schemes.test__cube_to_GridInfo import ( + _curvilinear_cube, + _grid_cube, +) +from esmf_regrid.tests.unit.schemes.test__mesh_to_MeshInfo import ( + _gridlike_mesh_cube, +) + +def test_Partition(tmp_path): + src = _grid_cube(150, 500, (-180, 180), (-90, 90), circular=True) + src.data = np.arange(150*500).reshape([500, 150]) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + chunks 
= [[[0, 100], [0, 150]], [[100, 200], [0, 150]], [[200, 300], [0, 150]], [[300, 400], [0, 150]], [[400, 500], [0, 150]]] + # TODO: consider different ways we could specify this in the API: + # chunks = 5 + # chunks = (5, None) + # chunks = None # (use dask chunks) + + partition = Partition(src, tgt, scheme, files, chunks) + + partition.generate_files() + + result = partition.apply_regridders(src) + + expected = src.regrid(tgt, scheme) + + assert result == expected + assert np.array_equal(result.data, expected.data) + +def test_Partition2(tmp_path): + src = _grid_cube(150, 500, (-180, 180), (-90, 90), circular=True) + src.data = np.arange(150*500).reshape([500, 150]) + tgt = _grid_cube(16, 36, (-180, 180), (-90, 90), circular=True) + + files = [tmp_path / f"partial_{x}.nc" for x in range(5)] + scheme = ESMFAreaWeighted(mdtol=1) + chunks = [[[0, 100], [0, 150]], [[100, 200], [0, 150]], [[200, 300], [0, 150]], [[300, 400], [0, 150]], [[400, 500], [0, 150]]] + # TODO: consider different ways we could specify this in the API: + # chunks = 5 + # chunks = (5, None) + # chunks = None # (use dask chunks) + + partition = Partition2(src, tgt, scheme, files, chunks) + + partition.generate_files() + + result = partition.apply_regridders(src) + + + expected = src.regrid(tgt, scheme) + assert np.allclose(result.data, expected.data) + assert result == expected diff --git a/src/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py b/src/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py index 6ebb1b21..c8fa526a 100644 --- a/src/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py +++ b/src/esmf_regrid/tests/unit/experimental/io/test_round_tripping.py @@ -13,6 +13,7 @@ esmpy, ) from esmf_regrid.esmf_regridder import ESMF_NO_VERSION +from esmf_regrid.experimental._partial import PartialRegridder from esmf_regrid.experimental.io import load_regridder, save_regridder from esmf_regrid.experimental.unstructured_scheme import ( GridToMeshESMFRegridder, @@ -77,7 
+78,13 @@ def _make_grid_to_mesh_regridder( } if regridder == GridToMeshESMFRegridder: kwargs["method"] = method - rg = regridder(src, tgt, **kwargs) + if regridder == PartialRegridder: + pre_rg = ESMFAreaWeightedRegridder(src, tgt, **kwargs) + weights = pre_rg.regridder.weight_matrix + scheme = ESMFAreaWeighted() + rg = PartialRegridder(src, tgt, None, None, weights, scheme) + else: + rg = regridder(src, tgt, **kwargs) return rg, src @@ -153,6 +160,7 @@ def _compare_ignoring_var_names(x, y): (Constants.Method.BILINEAR, GridToMeshESMFRegridder), (Constants.Method.NEAREST, GridToMeshESMFRegridder), (None, ESMFAreaWeightedRegridder), + (None, PartialRegridder), ], ) def test_grid_to_mesh_round_trip(tmp_path, method, regridder): @@ -161,8 +169,14 @@ def test_grid_to_mesh_round_trip(tmp_path, method, regridder): method=method, regridder=regridder, circular=True ) filename = tmp_path / "regridder.nc" - save_regridder(original_rg, filename) - loaded_rg = load_regridder(str(filename)) + + if regridder == PartialRegridder: + allow_partial = True + else: + allow_partial = False + + save_regridder(original_rg, filename, allow_partial=allow_partial) + loaded_rg = load_regridder(str(filename), allow_partial=allow_partial) if regridder == GridToMeshESMFRegridder: assert original_rg.location == loaded_rg.location @@ -206,8 +220,8 @@ def test_grid_to_mesh_round_trip(tmp_path, method, regridder): assert original_rg.resolution == loaded_rg.resolution original_res_rg, _ = _make_grid_to_mesh_regridder(method=method, resolution=8) res_filename = tmp_path / "regridder_res.nc" - save_regridder(original_res_rg, res_filename) - loaded_res_rg = load_regridder(str(res_filename)) + save_regridder(original_res_rg, res_filename, allow_partial=allow_partial) + loaded_res_rg = load_regridder(str(res_filename), allow_partial=allow_partial) assert original_res_rg.resolution == loaded_res_rg.resolution assert ( original_res_rg.regridder.src.resolution @@ -232,8 +246,8 @@ def 
test_grid_to_mesh_round_trip(tmp_path, method, regridder): method=method, regridder=regridder, circular=True ) nc_filename = tmp_path / "non_circular_regridder.nc" - save_regridder(original_nc_rg, nc_filename) - loaded_nc_rg = load_regridder(str(nc_filename)) + save_regridder(original_nc_rg, nc_filename, allow_partial=allow_partial) + loaded_nc_rg = load_regridder(str(nc_filename), allow_partial=allow_partial) if regridder == GridToMeshESMFRegridder: _compare_ignoring_var_names(original_nc_rg.grid_x, loaded_nc_rg.grid_x) _compare_ignoring_var_names(original_nc_rg.grid_y, loaded_nc_rg.grid_y)