Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions src/esmf_regrid/esmf_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,35 @@ def _out_dtype(self, in_dtype):
out_dtype = (np.ones(1, dtype=in_dtype) * np.ones(1, dtype=weight_dtype)).dtype
return out_dtype

def _gen_weights_and_data(self, src_array):
extra_shape = src_array.shape[: -self.src.dims]

flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0))
flat_tgt = self.weight_matrix @ flat_src

src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array))
weight_sums = self.weight_matrix @ src_inverted_mask

tgt_data = self.tgt._matrix_to_array(flat_tgt, extra_shape)
tgt_weights = self.tgt._matrix_to_array(weight_sums, extra_shape)
return tgt_weights, tgt_data

def _regrid_from_weights_and_data(self, tgt_weights, tgt_data, norm_type=Constants.NormType.FRACAREA, mdtol=1):
# Set the minimum mdtol to be slightly higher than 0 to account for rounding
# errors.
mdtol = max(mdtol, 1e-8)
tgt_mask = tgt_weights > 1 - mdtol
masked_weight_sums = tgt_weights * tgt_mask
normalisations = np.ones_like(tgt_data)
if norm_type == Constants.NormType.FRACAREA:
normalisations[tgt_mask] /= masked_weight_sums[tgt_mask]
elif norm_type == Constants.NormType.DSTAREA:
pass
normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask))

tgt_array = tgt_data * normalisations
return tgt_array

def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1):
"""Perform regridding on an array of data.

Expand Down Expand Up @@ -195,25 +224,7 @@ def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1):
f"got an array with shape ending in {main_shape}."
)
raise ValueError(e_msg)
extra_shape = array_shape[: -self.src.dims]
extra_size = max(1, np.prod(extra_shape))
src_inverted_mask = self.src._array_to_matrix(~ma.getmaskarray(src_array))
weight_sums = self.weight_matrix @ src_inverted_mask
out_dtype = self._out_dtype(src_array.dtype)
# Set the minimum mdtol to be slightly higher than 0 to account for rounding
# errors.
mdtol = max(mdtol, 1e-8)
tgt_mask = weight_sums > 1 - mdtol
masked_weight_sums = weight_sums * tgt_mask
normalisations = np.ones([self.tgt.size, extra_size], dtype=out_dtype)
if norm_type == Constants.NormType.FRACAREA:
normalisations[tgt_mask] /= masked_weight_sums[tgt_mask]
elif norm_type == Constants.NormType.DSTAREA:
pass
normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask))

flat_src = self.src._array_to_matrix(ma.filled(src_array, 0.0))
flat_tgt = self.weight_matrix @ flat_src
flat_tgt = flat_tgt * normalisations
tgt_array = self.tgt._matrix_to_array(flat_tgt, extra_shape)
tgt_weights, tgt_data = self._gen_weights_and_data(src_array)
tgt_array = self._regrid_from_weights_and_data(tgt_weights, tgt_data, norm_type=norm_type, mdtol=mdtol)
return tgt_array
41 changes: 41 additions & 0 deletions src/esmf_regrid/experimental/_partial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Provides a regridder class compatible with Partition."""

from esmf_regrid.schemes import (
_ESMFRegridder,
_create_cube,
)

class PartialRegridder(_ESMFRegridder):
def __init__(self, src, tgt, src_slice, tgt_slice, weights, scheme, **kwargs):
self.src_slice = src_slice # this will be tuple-like
self.tgt_slice = tgt_slice
self.scheme = scheme
# TODO: consider disallowing ESMFNearest (unless out of bounds can be made masked)

# Pop duplicate kwargs.
for arg in set(kwargs.keys()).intersection(vars(self.scheme)):
kwargs.pop(arg)

self._regridder = scheme.regridder(
src,
tgt,
precomputed_weights=weights,
**kwargs,
)
self.__dict__.update(self._regridder.__dict__)

def partial_regrid(self, src):
return self.regridder._gen_weights_and_data(src.data)

def finish_regridding(self, src_cube, weights, data):
dims = self._get_cube_dims(src_cube)

result_data = self.regridder._regrid_from_weights_and_data(weights, data)
result_cube = _create_cube(
result_data,
src_cube,
dims,
self._tgt,
len(self._tgt)
)
return result_cube
177 changes: 118 additions & 59 deletions src/esmf_regrid/experimental/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
)
from esmf_regrid.experimental._partial import PartialRegridder
from esmf_regrid.schemes import (
ESMFAreaWeighted,
ESMFAreaWeightedRegridder,
ESMFBilinear,
ESMFBilinearRegridder,
ESMFNearest,
ESMFNearestRegridder,
GridRecord,
MeshRecord,
Expand All @@ -28,6 +32,7 @@
ESMFNearestRegridder,
GridToMeshESMFRegridder,
MeshToGridESMFRegridder,
PartialRegridder,
]
_REGRIDDER_NAME_MAP = {rg_class.__name__: rg_class for rg_class in SUPPORTED_REGRIDDERS}
_SOURCE_NAME = "regridder_source_field"
Expand All @@ -47,6 +52,8 @@
_SOURCE_RESOLUTION = "src_resolution"
_TARGET_RESOLUTION = "tgt_resolution"
_ESMF_ARGS = "esmf_args"
_SRC_SLICE_NAME = "src_slice"
_TGT_SLICE_NAME = "tgt_slice"
_VALID_ESMF_KWARGS = [
"pole_method",
"regrid_pole_npoints",
Expand Down Expand Up @@ -118,54 +125,39 @@ def _clean_var_names(cube):
con.var_name = None


def save_regridder(rg, filename):
"""Save a regridder scheme instance.

Saves any of the regridder classes, i.e.
:class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`,
:class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`,
:class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`,
:class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or
:class:`~esmf_regrid.schemes.ESMFNearestRegridder`.
.

Parameters
----------
rg : :class:`~esmf_regrid.schemes._ESMFRegridder`
The regridder instance to save.
filename : str
The file name to save to.
"""
regridder_type = rg.__class__.__name__

def _standard_grid_cube(grid, name):
if grid[0].ndim == 1:
shape = [coord.points.size for coord in grid]
else:
shape = grid[0].shape
data = np.zeros(shape)
cube = Cube(data, var_name=name, long_name=name)
if grid[0].ndim == 1:
cube.add_dim_coord(grid[0], 0)
cube.add_dim_coord(grid[1], 1)
else:
cube.add_aux_coord(grid[0], [0, 1])
cube.add_aux_coord(grid[1], [0, 1])
return cube

def _standard_mesh_cube(mesh, location, name):
mesh_coords = mesh.to_MeshCoords(location)
data = np.zeros(mesh_coords[0].points.shape[0])
cube = Cube(data, var_name=name, long_name=name)
for coord in mesh_coords:
cube.add_aux_coord(coord, 0)
return cube

def _standard_grid_cube(grid, name):
if grid[0].ndim == 1:
shape = [coord.points.size for coord in grid]
else:
shape = grid[0].shape
data = np.zeros(shape)
cube = Cube(data, var_name=name, long_name=name)
if grid[0].ndim == 1:
cube.add_dim_coord(grid[0], 0)
cube.add_dim_coord(grid[1], 1)
else:
cube.add_aux_coord(grid[0], [0, 1])
cube.add_aux_coord(grid[1], [0, 1])
return cube

def _standard_mesh_cube(mesh, location, name):
mesh_coords = mesh.to_MeshCoords(location)
data = np.zeros(mesh_coords[0].points.shape[0])
cube = Cube(data, var_name=name, long_name=name)
for coord in mesh_coords:
cube.add_aux_coord(coord, 0)
return cube

def _generate_src_tgt(regridder_type, rg, allow_partial):
if regridder_type in [
"ESMFAreaWeightedRegridder",
"ESMFBilinearRegridder",
"ESMFNearestRegridder",
"PartialRegridder",
]:
if regridder_type == "PartialRegridder" and not allow_partial:
e_msg = "To save a PartialRegridder, `allow_partial=True` must be set."
raise ValueError(e_msg)
src_grid = rg._src
if isinstance(src_grid, GridRecord):
src_cube = _standard_grid_cube(
Expand Down Expand Up @@ -210,20 +202,46 @@ def _standard_mesh_cube(mesh, location, name):
tgt_grid = (rg.grid_y, rg.grid_x)
tgt_cube = _standard_grid_cube(tgt_grid, _TARGET_NAME)
_add_mask_to_cube(rg.tgt_mask, tgt_cube, _TARGET_MASK_NAME)

else:
e_msg = (
f"Expected a regridder of type `GridToMeshESMFRegridder` or "
f"`MeshToGridESMFRegridder`, got type {regridder_type}."
f"Unexpected regridder type {regridder_type}."
)
raise TypeError(e_msg)
return src_cube, tgt_cube




def save_regridder(rg, filename, allow_partial=False):
"""Save a regridder scheme instance.

Saves any of the regridder classes, i.e.
:class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`,
:class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`,
:class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`,
:class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or
:class:`~esmf_regrid.schemes.ESMFNearestRegridder`.
.

Parameters
----------
rg : :class:`~esmf_regrid.schemes._ESMFRegridder`
The regridder instance to save.
filename : str
The file name to save to.
"""
regridder_type = rg.__class__.__name__

src_cube, tgt_cube = _generate_src_tgt(regridder_type, rg, allow_partial)

method = str(check_method(rg.method).name)

if regridder_type in ["GridToMeshESMFRegridder", "MeshToGridESMFRegridder"]:
resolution = rg.resolution
src_resolution = None
tgt_resolution = None
elif regridder_type == "ESMFAreaWeightedRegridder":
elif method == "CONSERVATIVE":
resolution = None
src_resolution = rg.src_resolution
tgt_resolution = rg.tgt_resolution
Expand Down Expand Up @@ -264,6 +282,19 @@ def _standard_mesh_cube(mesh, location, name):
if tgt_resolution is not None:
attributes[_TARGET_RESOLUTION] = tgt_resolution

extra_cubes = []
if regridder_type == "PartialRegridder":
src_slice = rg.src_slice # this slice is described by a tuple
if src_slice is None:
src_slice = []
src_slice_cube = Cube(src_slice, long_name=_SRC_SLICE_NAME, var_name=_SRC_SLICE_NAME)
tgt_slice = rg.tgt_slice # this slice is described by a tuple
if tgt_slice is None:
tgt_slice = []
tgt_slice_cube = Cube(src_slice, long_name=_TGT_SLICE_NAME, var_name=_TGT_SLICE_NAME)
extra_cubes = [src_slice_cube, tgt_slice_cube]


weights_cube = Cube(weight_data, var_name=_WEIGHTS_NAME, long_name=_WEIGHTS_NAME)
row_coord = AuxCoord(
weight_rows, var_name=_WEIGHTS_ROW_NAME, long_name=_WEIGHTS_ROW_NAME
Expand Down Expand Up @@ -298,15 +329,15 @@ def _standard_mesh_cube(mesh, location, name):

# Save cubes while ensuring var_names do not conflict for the sake of consistency.
with _managed_var_name(src_cube, tgt_cube):
cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube])
cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube, *extra_cubes])

for cube in cube_list:
cube.attributes = attributes

iris.fileformats.netcdf.save(cube_list, filename)


def load_regridder(filename):
def load_regridder(filename, allow_partial=False):
"""Load a regridder scheme instance.

Loads any of the regridder classes, i.e.
Expand Down Expand Up @@ -343,6 +374,10 @@ def load_regridder(filename):
raise TypeError(e_msg)
scheme = _REGRIDDER_NAME_MAP[regridder_type]

if regridder_type == "PartialRegridder" and not allow_partial:
e_msg = "PartialRegridder cannot be loaded without setting `allow_partial=True`."
raise ValueError(e_msg)

# Determine the regridding method, allowing for files created when
# conservative regridding was the only method.
method_string = weights_cube.attributes.get(_METHOD, "CONSERVATIVE")
Expand Down Expand Up @@ -396,26 +431,50 @@ def load_regridder(filename):
elif scheme is MeshToGridESMFRegridder:
resolution_keyword = _TARGET_RESOLUTION
kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol}
elif scheme is ESMFAreaWeightedRegridder:
# elif scheme is ESMFAreaWeightedRegridder:
elif method is Constants.Method.CONSERVATIVE:
kwargs = {
_SOURCE_RESOLUTION: src_resolution,
_TARGET_RESOLUTION: tgt_resolution,
"mdtol": mdtol,
}
elif scheme is ESMFBilinearRegridder:
# elif scheme is ESMFBilinearRegridder:
elif method is Constants.Method.BILINEAR:
kwargs = {"mdtol": mdtol}
else:
kwargs = {}

regridder = scheme(
src_cube,
tgt_cube,
precomputed_weights=weight_matrix,
use_src_mask=use_src_mask,
use_tgt_mask=use_tgt_mask,
esmf_args=esmf_args,
**kwargs,
)
if scheme is PartialRegridder:
src_slice = cubes.extract_cube(_SRC_SLICE_NAME).data.tolist()
if src_slice == []:
src_slice = None
tgt_slice = cubes.extract_cube(_TGT_SLICE_NAME).data.tolist()
if tgt_slice == []:
tgt_slice = None
sub_scheme = {
Constants.Method.CONSERVATIVE: ESMFAreaWeighted,
Constants.Method.BILINEAR: ESMFBilinear,
Constants.Method.NEAREST: ESMFNearest,
}[method]
regridder = scheme(
src_cube,
tgt_cube,
src_slice,
tgt_slice,
weight_matrix,
sub_scheme(),
**kwargs,
)
else:
regridder = scheme(
src_cube,
tgt_cube,
precomputed_weights=weight_matrix,
use_src_mask=use_src_mask,
use_tgt_mask=use_tgt_mask,
esmf_args=esmf_args,
**kwargs,
)

esmf_version = weights_cube.attributes[_VERSION_ESMF]
regridder.regridder.esmf_version = esmf_version
Expand Down
Loading
Loading