From 2819be9d6233817e133f7cb70746e4e15a6d45c4 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 18 Mar 2023 08:46:19 -0700 Subject: [PATCH 1/3] add get_breadcrumb() doc str --- crystal_toolkit/helpers/layouts.py | 30 ++++++++++++---------- crystal_toolkit/helpers/povray_renderer.py | 4 +-- crystal_toolkit/renderables/volumetric.py | 16 +++++++++--- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/crystal_toolkit/helpers/layouts.py b/crystal_toolkit/helpers/layouts.py index 8e9376e9..82490a72 100644 --- a/crystal_toolkit/helpers/layouts.py +++ b/crystal_toolkit/helpers/layouts.py @@ -439,20 +439,24 @@ def __init__(self, *args, **kwargs) -> None: def get_breadcrumb(parts): + """Create a breadcrumb navigation bar. + + Args: + parts (dict): Dictionary of name, link pairs. + + Returns: + html.Nav: Breadcrumb navigation bar. + """ if not parts: - return html.Div() + return html.Nav() - breadcrumbs = html.Nav( - html.Ul( - [ - html.Li( - dcc.Link(name, href=link), - className=(None if idx != len(parts) - 1 else "is-active"), - ) - for idx, (name, link) in enumerate(parts.items()) - ] - ), - className="breadcrumb", - ) + links = [ + html.Li( + dcc.Link(name, href=link), + className="is-active" if idx == len(parts) - 1 else None, + ) + for idx, (name, link) in enumerate(parts.items()) + ] + breadcrumbs = html.Nav(html.Ul(links), className="breadcrumb") return breadcrumbs diff --git a/crystal_toolkit/helpers/povray_renderer.py b/crystal_toolkit/helpers/povray_renderer.py index 83c863e6..4ac67aa3 100644 --- a/crystal_toolkit/helpers/povray_renderer.py +++ b/crystal_toolkit/helpers/povray_renderer.py @@ -135,8 +135,8 @@ def pov_write_data(input_scene_comp, fstream): - """parse a primitive display object in crystaltoolkit and print it to POV-Ray input_scene_comp - fstream. + """Parse a primitive display object in crystaltoolkit and print it to POV-Ray + input_scene_comp fstream. """ vect = "{:.4f},{:.4f},{:.4f}" diff --git a/crystal_toolkit/renderables/volumetric.py b/crystal_toolkit/renderables/volumetric.py index c7b59cff..f43351db 100644 --- a/crystal_toolkit/renderables/volumetric.py +++ b/crystal_toolkit/renderables/volumetric.py @@ -1,6 +1,9 @@ from __future__ import annotations +from typing import Any + import numpy as np +from numpy.typing import ArrayLike from pymatgen.io.vasp import VolumetricData from crystal_toolkit.core.scene import Scene, Surface @@ -9,15 +12,22 @@ def get_isosurface_scene( - self, data_key="total", isolvl=0.05, step_size=4, origin=None, **kwargs -): + self, + data_key: str = "total", + isolvl: float = 0.05, + step_size: int = 4, + origin: ArrayLike = None, + **kwargs: Any, +) -> Scene: """Get the isosurface from a VolumetricData object. Args: data_key (str, optional): Use the volumetric data from self.data[data_key]. Defaults to 'total'. isolvl (float, optional): The cutoff for the isosurface to using the same units as VESTA so - e/bohr and kept grid size independent + e/bohr and kept grid size independent step_size (int, optional): step_size parameter for marching_cubes_lewiner. Defaults to 3. + origin (ArrayLike, optional): The origin of the isosurface. Defaults to None. + **kwargs: Passed to the Surface object. Returns: Scene: object containing the isosurface component From 0ed2233bc11fe7eb88bbc80281646bdb00a6c455 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 18 Mar 2023 17:27:26 -0700 Subject: [PATCH 2/3] add crystal_toolkit/apps/examples/relaxation_trajectory.py --- .../apps/examples/relaxation_trajectory.py | 85 +++++++++++++++++++ crystal_toolkit/components/structure.py | 5 +- 2 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 crystal_toolkit/apps/examples/relaxation_trajectory.py diff --git a/crystal_toolkit/apps/examples/relaxation_trajectory.py b/crystal_toolkit/apps/examples/relaxation_trajectory.py new file mode 100644 index 00000000..3a03d0f9 --- /dev/null +++ b/crystal_toolkit/apps/examples/relaxation_trajectory.py @@ -0,0 +1,85 @@ +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +from dash import Dash, dcc, html +from dash.dependencies import Input, Output +from pymatgen.core import Structure +from pymatgen.ext.matproj import MPRester + +import crystal_toolkit.components as ctc +from crystal_toolkit.settings import SETTINGS + +mp_id = "mp-1033715" +with MPRester(monty_decode=False) as mpr: + [task_doc] = mpr.tasks.search(task_ids=[mp_id]) + + +steps = [ + (Structure.from_dict(step["structure"]), step["e_fr_energy"]) + for calc in reversed(task_doc.calcs_reversed) + for step in calc.output["ionic_steps"] +] +struct_traj, energies = zip(*steps) +assert len(steps) == 99 + +e_col = "energy (eV/atom)" +spg_col = "spacegroup" +df_traj = pd.DataFrame( + {e_col: energies, spg_col: [s.get_space_group_info() for s in struct_traj]} +) + + +def plot_energy(df: pd.DataFrame, step: int) -> go.Figure: + """Plot energy as a function of relaxation step.""" + href = f"https://materialsproject.org/materials/{mp_id}" + title = f"{mp_id} - {spg_col} = {df[spg_col][step]}" + fig = px.line(df, y=e_col, template="plotly_white", title=title) + fig.add_vline(x=step, line=dict(dash="dash", width=1)) + return fig + + +struct_comp = ctc.StructureMoleculeComponent( + id="structure", struct_or_mol=struct_traj[0] +) + +step_size = max(1, len(struct_traj) // 20) # ensure slider has max 20 steps +slider = dcc.Slider( + id="slider", + min=0, + max=len(struct_traj) - 1, + value=0, + step=step_size, + updatemode="drag", +) + +graph = dcc.Graph(id="fig", figure=plot_energy(df_traj, 0), style={"maxWidth": "50%"}) + +app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH) +app.layout = html.Div( + [ + html.H1( + "Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em") + ), + html.P("Drag slider to see structure at different relaxation steps."), + slider, + html.Div([struct_comp.layout(), graph], style=dict(display="flex", gap="2em")), + ], + style=dict( + margin="2em auto", placeItems="center", textAlign="center", maxWidth="1000px" + ), +) + +ctc.register_crystal_toolkit(app=app, layout=app.layout) + + +@app.callback( + Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value") +) +def update_structure(step: int) -> tuple[Structure, go.Figure]: + """Update the structure displayed in the StructureMoleculeComponent and the + dashed vertical line in the figure when the slider is moved. + """ + return struct_traj[step], plot_energy(df_traj, step) + + +app.run_server(port=8050, debug=True) diff --git a/crystal_toolkit/components/structure.py b/crystal_toolkit/components/structure.py index 78812e81..119911cd 100644 --- a/crystal_toolkit/components/structure.py +++ b/crystal_toolkit/components/structure.py @@ -84,8 +84,9 @@ class StructureMoleculeComponent(MPComponent): def __init__( self, - struct_or_mol: None - | (Structure | StructureGraph | Molecule | MoleculeGraph) = None, + struct_or_mol: ( + None | Structure | StructureGraph | Molecule | MoleculeGraph + ) = None, id: str = None, className: str = "box", scene_additions: Scene | None = None, From 24c8c1b3ce8fa2d706649c14b857ce9d339d1fb6 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 22 Mar 2023 19:33:17 -0700 Subject: [PATCH 3/3] add normed force mean over atoms line to relaxation plot clean up code and styles --- .../apps/examples/relaxation_trajectory.py | 110 ++++++++++++------ 1 file changed, 77 insertions(+), 33 deletions(-) diff --git a/crystal_toolkit/apps/examples/relaxation_trajectory.py b/crystal_toolkit/apps/examples/relaxation_trajectory.py index 3a03d0f9..fd62536b 100644 --- a/crystal_toolkit/apps/examples/relaxation_trajectory.py +++ b/crystal_toolkit/apps/examples/relaxation_trajectory.py @@ -1,5 +1,7 @@ +import sys + +import numpy as np import pandas as pd -import plotly.express as px import plotly.graph_objects as go from dash import Dash, dcc, html from dash.dependencies import Input, Output @@ -13,46 +15,81 @@ with MPRester(monty_decode=False) as mpr: [task_doc] = mpr.tasks.search(task_ids=[mp_id]) - steps = [ - (Structure.from_dict(step["structure"]), step["e_fr_energy"]) + ( + Structure.from_dict(step["structure"]), + step["e_fr_energy"], + np.linalg.norm(step["forces"], axis=1).mean(), + ) for calc in reversed(task_doc.calcs_reversed) for step in calc.output["ionic_steps"] ] -struct_traj, energies = zip(*steps) assert len(steps) == 99 -e_col = "energy (eV/atom)" -spg_col = "spacegroup" -df_traj = pd.DataFrame( - {e_col: energies, spg_col: [s.get_space_group_info() for s in struct_traj]} -) - +e_col = "Energy (eV)" +force_col = "Force (eV/Å)" +spg_col = "Spacegroup" +struct_col = "Structure" + +df_traj = pd.DataFrame(steps, columns=[struct_col, e_col, force_col]) +df_traj[spg_col] = df_traj[struct_col].map(Structure.get_space_group_info) + + +def plot_energy_and_forces( + df: pd.DataFrame, + step: int, + e_col: str, + force_col: str, + title: str, +) -> go.Figure: + """Plot energy and forces as a function of relaxation step.""" + fig = go.Figure() + # energy trace = primary y-axis + fig.add_trace(go.Scatter(x=df.index, y=df[e_col], mode="lines", name="Energy")) + + # forces trace = secondary y-axis + fig.add_trace( + go.Scatter(x=df.index, y=df[force_col], mode="lines", name="Forces", yaxis="y2") + ) + + fig.update_layout( + template="plotly_white", + title=title, + xaxis={"title": "Relaxation Step"}, + yaxis={"title": e_col}, + yaxis2={"title": force_col, "overlaying": "y", "side": "right"}, + legend=dict(yanchor="top", y=1, xanchor="right", x=1), + ) + + # vertical line at the specified step + fig.add_vline(x=step, line={"dash": "dash", "width": 1}) -def plot_energy(df: pd.DataFrame, step: int) -> go.Figure: - """Plot energy as a function of relaxation step.""" - href = f"https://materialsproject.org/materials/{mp_id}" - title = f"{mp_id} - {spg_col} = {df[spg_col][step]}" - fig = px.line(df, y=e_col, template="plotly_white", title=title) - fig.add_vline(x=step, line=dict(dash="dash", width=1)) return fig -struct_comp = ctc.StructureMoleculeComponent( - id="structure", struct_or_mol=struct_traj[0] -) +if "struct_comp" not in locals(): + struct_comp = ctc.StructureMoleculeComponent( + id="structure", struct_or_mol=df_traj[struct_col][0] + ) -step_size = max(1, len(struct_traj) // 20) # ensure slider has max 20 steps +step_size = max(1, len(steps) // 20) # ensure slider has max 20 steps slider = dcc.Slider( - id="slider", - min=0, - max=len(struct_traj) - 1, - value=0, - step=step_size, - updatemode="drag", + id="slider", min=0, max=len(steps) - 1, value=0, step=step_size, updatemode="drag" ) -graph = dcc.Graph(id="fig", figure=plot_energy(df_traj, 0), style={"maxWidth": "50%"}) + +def make_title(spg: tuple[str, int]) -> str: + """Return a title for the figure.""" + href = f"https://materialsproject.org/materials/{mp_id}/" + return f"{mp_id} - {spg[0]} ({spg[1]})" + + +title = make_title(df_traj[spg_col][0]) +graph = dcc.Graph( + id="fig", + figure=plot_energy_and_forces(df_traj, 0, e_col, force_col, title), + style={"maxWidth": "50%"}, +) app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH) app.layout = html.Div( @@ -62,11 +99,12 @@ def plot_energy(df: pd.DataFrame, step: int) -> go.Figure: ), html.P("Drag slider to see structure at different relaxation steps."), slider, - html.Div([struct_comp.layout(), graph], style=dict(display="flex", gap="2em")), + html.Div( + [struct_comp.layout(), graph], + style=dict(display="flex", gap="2em", placeContent="center"), + ), ], - style=dict( - margin="2em auto", placeItems="center", textAlign="center", maxWidth="1000px" - ), + style=dict(margin="auto", textAlign="center", maxWidth="1000px", padding="2em"), ) ctc.register_crystal_toolkit(app=app, layout=app.layout) @@ -79,7 +117,13 @@ def update_structure(step: int) -> tuple[Structure, go.Figure]: """Update the structure displayed in the StructureMoleculeComponent and the dashed vertical line in the figure when the slider is moved. """ - return struct_traj[step], plot_energy(df_traj, step) + title = make_title(df_traj[spg_col][step]) + fig = plot_energy_and_forces(df_traj, step, e_col, force_col, title) + + return df_traj[struct_col][step], fig + +# https://stackoverflow.com/a/74918941 +is_jupyter = "ipykernel" in sys.modules -app.run_server(port=8050, debug=True) +app.run(port=8050, debug=True, use_reloader=not is_jupyter)