Skip to content

Commit 7bfe840

Browse files
authored
Merge pull request #7 from entity-toolkit/1.1.0rc
1.1.0 Release Candidate
2 parents fbfe8ad + f45c7fa commit 7bfe840

File tree

9 files changed

+123
-8
lines changed

9 files changed

+123
-8
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ You can also create a movie of a single field quantity (can be custom):
7979
(data.fields.Ex * data.fields.Bx).sel(x=slice(None, 0.2)).movie.plot(name="ExBx", vmin=-0.01, vmax=0.01, cmap="BrBG")
8080
```
8181

82+
For particles, one can also make 2D phase-space plots:
83+
84+
```python
85+
data.particles[1].sel(t=1.0, method="nearest").particles.phaseplot(x="x", y="uy", xnbins=100, ynbins=200, xlims=(0, 100), cmap="inferno")
86+
```
87+
8288
You may also combine different quantities and plots (e.g., fields & particles) to produce a more customized movie:
8389

8490
```python

dist/nt2py-1.1.0-py3-none-any.whl

35.5 KB
Binary file not shown.

dist/nt2py-1.1.0.tar.gz

30.2 KB
Binary file not shown.

nt2/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.0.1"
1+
__version__ = "1.1.0"
22

33
import nt2.containers.data as nt2_data
44

nt2/containers/data.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from typing import Callable, Any
22

33
import sys
4+
45
if sys.version_info >= (3, 12):
56
from typing import override
67
else:
8+
79
def override(method):
810
return method
911

12+
1013
from collections.abc import KeysView
1114
from nt2.utils import ToHumanReadable
1215

@@ -25,7 +28,7 @@ def override(method):
2528
from nt2.containers.particles import Particles
2629

2730
import nt2.plotters.polar as acc_polar
28-
31+
import nt2.plotters.particles as acc_particles
2932
import nt2.plotters.inspect as acc_inspect
3033
import nt2.plotters.movie as acc_movie
3134
from nt2.plotters.export import makeFramesAndMovie
@@ -37,6 +40,12 @@ class DatasetPolarPlotAccessor(acc_polar.ds_accessor):
3740
pass
3841

3942

43+
@xr.register_dataset_accessor("particles")
44+
@InheritClassDocstring
45+
class DatasetParticlesPlotAccessor(acc_particles.ds_accessor):
46+
pass
47+
48+
4049
@xr.register_dataarray_accessor("polar")
4150
@InheritClassDocstring
4251
class PolarPlotAccessor(acc_polar.accessor):

nt2/plotters/inspect.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def plot(
180180
)
181181

182182
def plot_func(ti: int, _):
183-
if len(self._obj.dims) == 1:
183+
if len(self._obj.isel(t=ti).dims) == 1:
184184
_ = self.plot_frame_1d(
185185
self._obj.isel(t=ti),
186186
None,
@@ -189,7 +189,7 @@ def plot_func(ti: int, _):
189189
fig_kwargs,
190190
plot_kwargs,
191191
)
192-
elif len(self._obj.dims) == 2:
192+
elif len(self._obj.isel(t=ti).dims) == 2:
193193
_ = self.plot_frame_2d(
194194
self._obj.isel(t=ti),
195195
None,

nt2/plotters/particles.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import xarray as xr
2+
import numpy as np
3+
4+
5+
class ds_accessor:
6+
def __init__(self, xarray_obj: xr.Dataset):
7+
self._obj: xr.Dataset = xarray_obj
8+
9+
def phaseplot(
10+
self,
11+
x: str = "x",
12+
y: str = "ux",
13+
xbins: None | np.ndarray = None,
14+
ybins: None | np.ndarray = None,
15+
xlims: None | tuple[float] = None,
16+
ylims: None | tuple[float] = None,
17+
xnbins: int = 100,
18+
ynbins: int = 100,
19+
**kwargs,
20+
):
21+
"""
22+
Create a 2D histogram (phase plot) of two variables in the dataset.
23+
24+
Parameters
25+
----------
26+
x : str
27+
The variable name for the x-axis (default: "x").
28+
y : str
29+
The variable name for the y-axis (default: "ux").
30+
xbins : np.ndarray, optional
31+
The bin edges for the x-axis. If None, 100 bins between min and max of x are used.
32+
ybins : np.ndarray, optional
33+
The bin edges for the y-axis. If None, 100 bins between min and max of y are used.
34+
xlims : tuple[float], optional
35+
The limits for the x-axis. If None, the limits are determined from the data.
36+
ylims : tuple[float], optional
37+
The limits for the y-axis. If None, the limits are determined from the data.
38+
xnbins : int, optional
39+
The number of bins for the x-axis if xbins is None (default: 100).
40+
ynbins : int, optional
41+
The number of bins for the y-axis if ybins is None (default: 100).
42+
**kwargs
43+
Additional keyword arguments passed to matplotlib's pcolormesh.
44+
45+
Raises
46+
------
47+
AssertionError
48+
If x or y are not valid variable names in the dataset, or if the dataset has a time dimension.
49+
50+
Returns
51+
-------
52+
None
53+
54+
Examples
55+
--------
56+
>>> ds.phaseplot(x='x', y='ux', xbins=np.linspace(0, 1000, 100), ybins=np.linspace(-5, 5, 50))
57+
"""
58+
assert x in list(self._obj.keys()) and y in list(
59+
self._obj.keys()
60+
), "x and y must be valid variable names in the dataset"
61+
assert (
62+
len(self._obj[x].dims) == 1 and len(self._obj[y].dims) == 1
63+
), "x and y must be 1D variables"
64+
assert "t" not in self._obj.dims, "Dataset must not have time dimension"
65+
66+
import matplotlib.pyplot as plt
67+
68+
if xbins is None:
69+
if xlims is not None:
70+
xbins_ = np.linspace(xlims[0], xlims[1], xnbins)
71+
else:
72+
xbins_ = np.linspace(
73+
self._obj[x].values.min(), self._obj[x].values.max(), xnbins
74+
)
75+
else:
76+
xbins_ = xbins
77+
if ybins is None:
78+
if ylims is not None:
79+
ybins_ = np.linspace(ylims[0], ylims[1], ynbins)
80+
else:
81+
ybins_ = np.linspace(
82+
self._obj[y].values.min(), self._obj[y].values.max(), ynbins
83+
)
84+
else:
85+
ybins_ = ybins
86+
87+
cnt, _, _ = np.histogram2d(
88+
self._obj[x].values, self._obj[y].values, bins=[xbins_, ybins_]
89+
)
90+
xbins_ = 0.5 * (xbins_[1:] + xbins_[:-1])
91+
ybins_ = 0.5 * (ybins_[1:] + ybins_[:-1])
92+
93+
ax = kwargs.pop("ax", plt.gca())
94+
ax.pcolormesh(
95+
xbins_,
96+
ybins_,
97+
cnt.T,
98+
rasterized=True,
99+
**kwargs,
100+
)
101+
ax.set_xlabel(x)
102+
ax.set_ylabel(y)

nt2/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,7 @@ def DataIs2DPolar(ds: xr.Dataset) -> bool:
9393
bool
9494
True if the dataset is 2D polar, False otherwise.
9595
"""
96-
return ("r" in ds.dims and ("θ" in ds.dims or "th" in ds.dims)) and len(
97-
ds.dims
98-
) == 2
96+
return ("r" in ds.dims and "th" in ds.dims) and len(ds.dims) == 2
9997

10098

10199
def InheritClassDocstring(cls: type) -> type:

shell.nix

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ pkgs.mkShell {
99
pkgs."python${py}"
1010
pkgs."python${py}Packages".pip
1111
black
12-
basedpyright
12+
pyright
1313
taplo
1414
vscode-langservers-extracted
1515
zlib

0 commit comments

Comments
 (0)