diff --git a/src/titiler/xarray/titiler/xarray/extensions.py b/src/titiler/xarray/titiler/xarray/extensions.py
index 0f3515279..cea0884b6 100644
--- a/src/titiler/xarray/titiler/xarray/extensions.py
+++ b/src/titiler/xarray/titiler/xarray/extensions.py
@@ -1,12 +1,14 @@
 """titiler.xarray Extensions."""
 
 import warnings
-from typing import Callable, List, Type
+from typing import Callable, List, Optional, Type
 
 import xarray
 from attrs import define
-from fastapi import Depends
+from fastapi import Depends, Query
+from rio_tiler.constants import WGS84_CRS
 from starlette.responses import HTMLResponse
+from typing_extensions import Annotated
 
 from titiler.core.dependencies import DefaultDependency
 from titiler.core.factory import FactoryExtension
@@ -56,7 +58,7 @@ class DatasetMetadataExtension(FactoryExtension):
     io_dependency: Type[DefaultDependency] = XarrayIOParams
     dataset_opener: Callable[..., xarray.Dataset] = xarray_open_dataset
 
-    def register(self, factory: TilerFactory):
+    def register(self, factory: TilerFactory):  # noqa: C901
         """Register endpoint to the tiler factory."""
 
         @factory.router.get(
@@ -109,3 +111,95 @@ def dataset_variables(
             """Returns the list of keys/variables in the Dataset."""
             with self.dataset_opener(src_path, **io_params.as_dict()) as ds:
                 return list(ds.data_vars)
+
+        @factory.router.get(
+            "/validate",
+            responses={
+                200: {
+                    "content": {
+                        "application/json": {},
+                    },
+                },
+            },
+        )
+        def validate_dataset(  # noqa: C901
+            src_path=Depends(factory.path_dependency),
+            io_params=Depends(self.io_dependency),
+            variable: Annotated[
+                Optional[str], Query(description="Xarray Variable name.")
+            ] = None,
+        ):
+            """Validate that the Dataset is compatible with titiler."""
+            errors = []
+
+            with self.dataset_opener(src_path, **io_params.as_dict()) as dst:
+                variables = list(dst.data_vars)
+
+                if variable:
+                    dst = dst[variable]
+
+                if "x" not in dst.dims and "y" not in dst.dims:
+                    try:
+                        latitude_var_name = next(
+                            name
+                            for name in ["lat", "latitude", "LAT", "LATITUDE", "Lat"]
+                            if name in dst.dims
+                        )
+                        longitude_var_name = next(
+                            name
+                            for name in ["lon", "longitude", "LON", "LONGITUDE", "Lon"]
+                            if name in dst.dims
+                        )
+
+                        dst = dst.rename(
+                            {latitude_var_name: "y", longitude_var_name: "x"}
+                        )
+
+                        if extra_dims := [d for d in dst.dims if d not in ["x", "y"]]:
+                            dst = dst.transpose(*extra_dims, "y", "x")
+                        else:
+                            dst = dst.transpose("y", "x")
+
+                    except StopIteration:
+                        errors.append(
+                            "Dataset does not have compatible spatial coordinates"
+                        )
+
+                bounds = dst.rio.bounds()
+                if not bounds:
+                    errors.append("Dataset does not have rioxarray bounds")
+
+                res = dst.rio.resolution()
+                if not res:
+                    errors.append("Dataset does not have rioxarray resolution")
+
+                if res and bounds:
+                    crs = dst.rio.crs or "epsg:4326"
+                    xres, yres = map(abs, res)
+
+                    # Adjust the longitude coordinates to the -180 to 180 range
+                    if crs == "epsg:4326" and (dst.x > 180 + xres / 2).any():
+                        dst = dst.assign_coords(x=(dst.x + 180) % 360 - 180)
+
+                        # Sort the dataset by the updated longitude coordinates
+                        dst = dst.sortby(dst.x)
+
+                    bounds = tuple(dst.rio.bounds())
+                    if crs == WGS84_CRS and (
+                        bounds[0] + xres / 2 < -180
+                        or bounds[1] + yres / 2 < -90
+                        or bounds[2] - xres / 2 > 180
+                        or bounds[3] - yres / 2 > 90
+                    ):
+                        errors.append(
+                            "Dataset bounds are not valid, must be in [-180, 180] and [-90, 90]"
+                        )
+
+                if not dst.rio.transform():
+                    errors.append("Dataset does not have rioxarray transform")
+
+            return {
+                "compatible_with_titiler": True if not errors else False,
+                "errors": errors,
+                "dataset_vars": variables,
+            }