|
| 1 | +import xarray as xr |
| 2 | +import numpy as np |
| 3 | + |
| 4 | + |
| 5 | +class ds_accessor: |
| 6 | + def __init__(self, xarray_obj: xr.Dataset): |
| 7 | + self._obj: xr.Dataset = xarray_obj |
| 8 | + |
| 9 | + def phaseplot( |
| 10 | + self, |
| 11 | + x: str = "x", |
| 12 | + y: str = "ux", |
| 13 | + xbins: None | np.ndarray = None, |
| 14 | + ybins: None | np.ndarray = None, |
| 15 | + xlims: None | tuple[float] = None, |
| 16 | + ylims: None | tuple[float] = None, |
| 17 | + xnbins: int = 100, |
| 18 | + ynbins: int = 100, |
| 19 | + **kwargs, |
| 20 | + ): |
| 21 | + """ |
| 22 | + Create a 2D histogram (phase plot) of two variables in the dataset. |
| 23 | +
|
| 24 | + Parameters |
| 25 | + ---------- |
| 26 | + x : str |
| 27 | + The variable name for the x-axis (default: "x"). |
| 28 | + y : str |
| 29 | + The variable name for the y-axis (default: "ux"). |
| 30 | + xbins : np.ndarray, optional |
| 31 | + The bin edges for the x-axis. If None, 100 bins between min and max of x are used. |
| 32 | + ybins : np.ndarray, optional |
| 33 | + The bin edges for the y-axis. If None, 100 bins between min and max of y are used. |
| 34 | + xlims : tuple[float], optional |
| 35 | + The limits for the x-axis. If None, the limits are determined from the data. |
| 36 | + ylims : tuple[float], optional |
| 37 | + The limits for the y-axis. If None, the limits are determined from the data. |
| 38 | + xnbins : int, optional |
| 39 | + The number of bins for the x-axis if xbins is None (default: 100). |
| 40 | + ynbins : int, optional |
| 41 | + The number of bins for the y-axis if ybins is None (default: 100). |
| 42 | + **kwargs |
| 43 | + Additional keyword arguments passed to matplotlib's pcolormesh. |
| 44 | +
|
| 45 | + Raises |
| 46 | + ------ |
| 47 | + AssertionError |
| 48 | + If x or y are not valid variable names in the dataset, or if the dataset has a time dimension. |
| 49 | +
|
| 50 | + Returns |
| 51 | + ------- |
| 52 | + None |
| 53 | +
|
| 54 | + Examples |
| 55 | + -------- |
| 56 | + >>> ds.phaseplot(x='x', y='ux', xbins=np.linspace(0, 1000, 100), ybins=np.linspace(-5, 5, 50)) |
| 57 | + """ |
| 58 | + assert x in list(self._obj.keys()) and y in list( |
| 59 | + self._obj.keys() |
| 60 | + ), "x and y must be valid variable names in the dataset" |
| 61 | + assert ( |
| 62 | + len(self._obj[x].dims) == 1 and len(self._obj[y].dims) == 1 |
| 63 | + ), "x and y must be 1D variables" |
| 64 | + assert "t" not in self._obj.dims, "Dataset must not have time dimension" |
| 65 | + |
| 66 | + import matplotlib.pyplot as plt |
| 67 | + |
| 68 | + if xbins is None: |
| 69 | + if xlims is not None: |
| 70 | + xbins_ = np.linspace(xlims[0], xlims[1], xnbins) |
| 71 | + else: |
| 72 | + xbins_ = np.linspace( |
| 73 | + self._obj[x].values.min(), self._obj[x].values.max(), xnbins |
| 74 | + ) |
| 75 | + else: |
| 76 | + xbins_ = xbins |
| 77 | + if ybins is None: |
| 78 | + if ylims is not None: |
| 79 | + ybins_ = np.linspace(ylims[0], ylims[1], ynbins) |
| 80 | + else: |
| 81 | + ybins_ = np.linspace( |
| 82 | + self._obj[y].values.min(), self._obj[y].values.max(), ynbins |
| 83 | + ) |
| 84 | + else: |
| 85 | + ybins_ = ybins |
| 86 | + |
| 87 | + cnt, _, _ = np.histogram2d( |
| 88 | + self._obj[x].values, self._obj[y].values, bins=[xbins_, ybins_] |
| 89 | + ) |
| 90 | + xbins_ = 0.5 * (xbins_[1:] + xbins_[:-1]) |
| 91 | + ybins_ = 0.5 * (ybins_[1:] + ybins_[:-1]) |
| 92 | + |
| 93 | + ax = kwargs.pop("ax", plt.gca()) |
| 94 | + ax.pcolormesh( |
| 95 | + xbins_, |
| 96 | + ybins_, |
| 97 | + cnt.T, |
| 98 | + rasterized=True, |
| 99 | + **kwargs, |
| 100 | + ) |
| 101 | + ax.set_xlabel(x) |
| 102 | + ax.set_ylabel(y) |
0 commit comments