|
6 | 6 | from pathlib import Path |
7 | 7 |
|
8 | 8 | import numpy as np |
| 9 | +import xarray as xr |
| 10 | +from xarray.ufuncs import isinf, isnan |
9 | 11 |
|
10 | 12 | from nowcasting_dataset.dataset.xr_utils import PydanticXArrayDataSet |
11 | 13 | from nowcasting_dataset.filesystem.utils import makedirs |
@@ -50,6 +52,54 @@ def save_netcdf(self, batch_i: int, path: Path): |
50 | 52 | encoding = {name: {"compression": "lzf"} for name in self.data_vars} |
51 | 53 | self.to_netcdf(local_filename, engine="h5netcdf", mode="w", encoding=encoding) |
52 | 54 |
|
def check_nan_and_inf(self, data: xr.Dataset, variable_name: str = None):
    """Check that no value in ``data`` is NaN or infinite.

    Args:
        data: Array-like values to validate.
        variable_name: Optional label appended to the error message, in
            parentheses, to identify which variable failed.

    Raises:
        Exception: If any value is NaN or, checked second, infinite.  The
            same message is logged at error level before raising.
    """
    # NOTE(review): the truthiness of ``.any()`` on a multi-variable
    # xr.Dataset may not reflect the reduced values -- presumably callers
    # pass a DataArray (or a plain array); confirm against the call sites.

    # Build the optional suffix up front.  It must be a string in both cases:
    # appending ``None`` to the message raised TypeError whenever
    # ``variable_name`` was omitted.
    suffix = f" ({variable_name})" if variable_name is not None else ""

    # NumPy ufuncs are applied directly (they also accept xarray objects).
    # NaN is tested before infinity, so data containing both reports NaNs.
    for has_problem, description in ((np.isnan, "NaNs"), (np.isinf, "Infinite")):
        if has_problem(data).any():
            message = f"Some {self.__class__.__name__} data values are {description}"
            message += suffix
            logger.error(message)
            raise Exception(message)
| 69 | + |
def check_dataset_greater_than_or_equal_to(
    self, data: xr.Dataset, min_value: int, variable_name: str = None
):
    """Check that every value in ``data`` is greater than or equal to ``min_value``.

    Args:
        data: Array-like values to validate.
        min_value: Smallest value that is still accepted.
        variable_name: Optional label appended to the error message, in
            parentheses, to identify which variable failed.

    Raises:
        Exception: If any value is strictly less than ``min_value``.  The
            same message is logged at error level before raising.
    """
    # NaN compares False against any number, so NaNs are not flagged here.
    if not (data < min_value).any():
        return

    message = f"Some {self.__class__.__name__} data values are less than {min_value}"
    # Only append when a name was given: the previous conditional expression
    # appended ``None`` otherwise, which raised TypeError.
    if variable_name is not None:
        message += f" ({variable_name})"
    logger.error(message)
    raise Exception(message)
| 79 | + |
def check_dataset_less_than_or_equal_to(
    self, data: xr.Dataset, max_value: int, variable_name: str = None
):
    """Check that every value in ``data`` is less than or equal to ``max_value``.

    Args:
        data: Array-like values to validate.
        max_value: Largest value that is still accepted.
        variable_name: Optional label appended to the error message, in
            parentheses, to identify which variable failed.

    Raises:
        Exception: If any value is strictly greater than ``max_value``.  The
            same message is logged at error level before raising.
    """
    # NaN compares False against any number, so NaNs are not flagged here.
    if not (data > max_value).any():
        return

    # The failing condition is ``data > max_value``, so the message reports
    # values that are *greater* than the limit (it previously said "less").
    message = f"Some {self.__class__.__name__} data values are greater than {max_value}"
    # Only append when a name was given: the previous conditional expression
    # appended ``None`` otherwise, which raised TypeError.
    if variable_name is not None:
        message += f" ({variable_name})"
    logger.error(message)
    raise Exception(message)
| 89 | + |
def check_dataset_not_equal(
    self, data: xr.Dataset, value: int, raise_error: bool = True, variable_name: str = None
):
    """Check that no value in ``data`` is equal to ``value``.

    Equality is tested with ``np.isclose`` using its default tolerances, so
    values that are merely very close to ``value`` also count as equal.

    Args:
        data: Array-like values to validate.
        value: The value that must not appear in ``data``.
        raise_error: If True (the default) a match is logged at error level
            and raised; if False it is only logged as a warning.
        variable_name: Optional label appended to the message, in
            parentheses, to identify which variable failed.

    Raises:
        Exception: If any value matches ``value`` and ``raise_error`` is True.
    """
    if not np.isclose(data, value).any():
        return

    message = f"Some {self.__class__.__name__} data values are equal to {value}"
    # Only append when a name was given: the previous conditional expression
    # appended ``None`` otherwise, which raised TypeError on both the error
    # and the warning path.
    if variable_name is not None:
        message += f" ({variable_name})"

    if not raise_error:
        logger.warning(message)
        return

    logger.error(message)
    raise Exception(message)
| 102 | + |
53 | 103 |
|
54 | 104 | def pad_nans(array, pad_width) -> np.ndarray: |
55 | 105 | """Pad nans with nans""" |
|
0 commit comments