diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py
index 296a2947..5ef76d90 100644
--- a/crystal_toolkit/components/phonon.py
+++ b/crystal_toolkit/components/phonon.py
@@ -1,39 +1,41 @@
from __future__ import annotations
import itertools
+from copy import deepcopy
from typing import TYPE_CHECKING, Any
import numpy as np
import plotly.graph_objects as go
from dash import dcc, html
-from dash.dependencies import Component, Input, Output
+from dash.dependencies import Component, Input, Output, State
from dash.exceptions import PreventUpdate
-from dash_mp_components import CrystalToolkitScene
+from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene
+
+# crystal animation algo
+from pymatgen.analysis.graphs import StructureGraph
+from pymatgen.analysis.local_env import CrystalNN
from pymatgen.ext.matproj import MPRester
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
from pymatgen.phonon.dos import CompletePhononDos
from pymatgen.phonon.plotter import PhononBSPlotter
+from pymatgen.transformations.standard_transformations import SupercellTransformation
from crystal_toolkit.core.mpcomponent import MPComponent
from crystal_toolkit.core.panelcomponent import PanelComponent
from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres
-from crystal_toolkit.helpers.layouts import (
- Column,
- Columns,
- Label,
- MessageBody,
- MessageContainer,
- get_data_list,
-)
+from crystal_toolkit.helpers.layouts import Column, Columns, Label, get_data_list
from crystal_toolkit.helpers.pretty_labels import pretty_labels
if TYPE_CHECKING:
from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine
from pymatgen.electronic_structure.dos import CompleteDos
-# Author: Jason Munro, Janosh Riebesell
-# Contact: jmunro@lbl.gov, janosh@lbl.gov
-
+DISPLACE_COEF = [0, 1, 0, -1, 0]
+MARKER_COLOR = "red"
+MARKER_SIZE = 12
+MARKER_SHAPE = "x"
+MAX_MAGNITUDE = 300
+MIN_MAGNITUDE = 0
# TODOs:
# - look for additional projection methods in phonon DOS (currently only atom
@@ -64,26 +66,32 @@ def __init__(
**kwargs,
)
+ bs, _ = PhononBandstructureAndDosComponent._get_ph_bs_dos(
+ self.initial_data["default"]
+ )
+ self.create_store("bs-store", bs)
+ self.create_store("bs", None)
+ self.create_store("dos", None)
+
@property
def _sub_layouts(self) -> dict[str, Component]:
# defaults
state = {"label-select": "sc", "dos-select": "ap"}
- bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos(
- self.initial_data["default"]
- )
- fig = PhononBandstructureAndDosComponent.get_figure(bs, dos)
+ fig = PhononBandstructureAndDosComponent.get_figure(None, None)
# Main plot
graph = dcc.Graph(
figure=fig,
config={"displayModeBar": False},
- responsive=True,
+ responsive=False,
id=self.id("ph-bsdos-graph"),
)
# Brillouin zone
- zone_scene = self.get_brillouin_zone_scene(bs)
- zone = CrystalToolkitScene(data=zone_scene.to_json(), sceneSize="500px")
+ zone_scene = self.get_brillouin_zone_scene(None)
+ zone = CrystalToolkitScene(
+ data=zone_scene.to_json(), sceneSize="500px", id=self.id("zone")
+ )
# Hide by default if not loaded by mpid, switching between k-paths
# on-the-fly only supported for bandstructures retrieved from MP
@@ -105,9 +113,11 @@ def _sub_layouts(self) -> dict[str, Component]:
options=options,
)
],
- style={"width": "200px"}
- if show_path_options
- else {"maxWidth": "200", "display": "none"},
+ style=(
+ {"width": "200px"}
+ if show_path_options
+ else {"maxWidth": "200", "display": "none"}
+ ),
id=self.id("path-container"),
)
@@ -122,9 +132,11 @@ def _sub_layouts(self) -> dict[str, Component]:
options=options,
)
],
- style={"width": "200px"}
- if show_path_options
- else {"width": "200px", "display": "none"},
+ style=(
+ {"width": "200px"}
+ if show_path_options
+ else {"width": "200px", "display": "none"}
+ ),
id=self.id("label-container"),
)
@@ -138,9 +150,82 @@ def _sub_layouts(self) -> dict[str, Component]:
style={"width": "200px"},
)
- summary_dict = self._get_data_list_dict(bs, dos)
+ summary_dict = self._get_data_list_dict(None, None)
summary_table = get_data_list(summary_dict)
+ # crystal visualization
+
+ tip = html.H5(
+ "💡 Tips: Click different q-points and bands in the dispersion diagram to see the crystal vibration!",
+ )
+
+ crystal_animation = html.Div(
+ CrystalToolkitAnimationScene(
+ data={},
+ sceneSize="500px",
+ id=self.id("crystal-animation"),
+ settings={"defaultZoom": 1.2},
+ axisView="SW",
+ showControls=False, # disable download for now
+ ),
+ style={"width": "60%"},
+ )
+
+ crystal_animation_controls = html.Div(
+ [
+ html.Br(),
+ html.Div(tip, style={"textAlign": "center"}),
+ html.Br(),
+ html.H5("Control Panel", style={"textAlign": "center"}),
+ html.H6("Supercell modification"),
+ html.Br(),
+ html.Div(
+ [
+ self.get_numerical_input(
+ kwarg_label="scale-x",
+ default=1,
+ is_int=True,
+ label="x",
+ min=1,
+ style={"width": "5rem"},
+ ),
+ self.get_numerical_input(
+ kwarg_label="scale-y",
+ default=1,
+ is_int=True,
+ label="y",
+ min=1,
+ style={"width": "5rem"},
+ ),
+ self.get_numerical_input(
+ kwarg_label="scale-z",
+ default=1,
+ is_int=True,
+ label="z",
+ min=1,
+ style={"width": "5rem"},
+ ),
+ html.Button(
+ "Update",
+ id=self.id("supercell-controls-btn"),
+ style={"height": "40px"},
+ ),
+ ],
+ style={"display": "flex"},
+ ),
+ html.Br(),
+ html.Div(
+ self.get_slider_input(
+ kwarg_label="magnitude",
+ default=0.5,
+ step=0.01,
+ domain=[0, 1],
+ label="Vibration magnitude",
+ )
+ ),
+ ],
+ )
+
return {
"graph": graph,
"convention": convention,
@@ -148,10 +233,31 @@ def _sub_layouts(self) -> dict[str, Component]:
"label-select": label_select,
"zone": zone,
"table": summary_table,
+ "crystal-animation": crystal_animation,
+ "tip": tip,
+ "crystal-animation-controls": crystal_animation_controls,
}
+ def _get_animation_panel(self):
+ sub_layouts = self._sub_layouts
+ return Columns(
+ [
+ Column(
+ [
+ Columns(
+ [
+ sub_layouts["crystal-animation"],
+ sub_layouts["crystal-animation-controls"],
+ ]
+ )
+ ]
+ ),
+ ]
+ )
+
def layout(self) -> html.Div:
sub_layouts = self._sub_layouts
+ crystal_animation = self._get_animation_panel()
graph = Columns([Column([sub_layouts["graph"]])])
controls = Columns(
[
@@ -166,11 +272,143 @@ def layout(self) -> html.Div:
)
brillouin_zone = Columns(
[
- Column([Label("Summary"), sub_layouts["table"]]),
+ Column([Label("Summary"), sub_layouts["table"]], id=self.id("table")),
Column([Label("Brillouin Zone"), sub_layouts["zone"]]),
]
)
- return html.Div([graph, controls, brillouin_zone])
+
+ return html.Div([graph, crystal_animation, controls, brillouin_zone])
+
+ @staticmethod
+ def _get_eigendisplacement(
+ ph_bs: BandStructureSymmLine,
+ json_data: dict,
+ band: int = 0,
+ qpoint: int = 0,
+ precision: int = 15,
+ magnitude: int = MAX_MAGNITUDE / 2,
+ total_repeat_cell_cnt: int = 1,
+ ) -> dict:
+ if not ph_bs or not json_data:
+ return {}
+
+ assert json_data["contents"][0]["name"] == "atoms"
+ assert json_data["contents"][1]["name"] == "bonds"
+ rdata = deepcopy(json_data)
+
+ def calc_max_displacement(idx: int) -> list:
+ """
+ Retrieve the eigendisplacement for a given atom index from `ph_bs` and compute its maximum displacement.
+
+ Parameters:
+ idx (int): The atom index.
+
+ Returns:
+ list: The maximum displacement vector in the form [x_max_displacement, y_max_displacement, z_max_displacement]
+
+ This function extracts the real component of the atom's eigendisplacement,
+ scales it by the specified magnitude, and returns the resulting vector.
+ """
+
+ # get the atom index
+ assert total_repeat_cell_cnt != 0
+
+ modified_idx = (
+ int(idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx
+ )
+
+ return [
+ round(complex(vec).real * magnitude, precision)
+ for vec in ph_bs.eigendisplacements[band][qpoint][modified_idx]
+ ]
+
+ def calc_animation_step(max_displacement: list, coef: int) -> list:
+ """
+ Calculate the displacement for an animation frame based on the given coefficient.
+
+ Parameters:
+ max_displacement (list): A list of maximum displacements along each axis,
+ formatted as [x_max_displacement, y_max_displacement, z_max_displacement].
+ coef (int): A coefficient indicating the motion direction.
+ - 0: no movement
+ - 1: forward movement
+ - -1: backward movement
+
+ Returns:
+ list: The displacement vector [x_displacement, y_displacement, z_displacement].
+
+ This function generates oscillatory motion by scaling the maximum displacement
+ with the provided coefficient.
+ """
+ return [round(coef * md, precision) for md in max_displacement]
+
+ # Compute per-frame atomic motion.
+ # `rcontent["animate"]` stores the displacement (distance difference) from the previous coordinates.
+ contents0 = json_data["contents"][0]["contents"]
+ for cidx, content in enumerate(contents0):
+ max_displacement = calc_max_displacement(content["_meta"][0])
+ rcontent = rdata["contents"][0]["contents"][cidx]
+ # put animation frame to the given atom index
+ rcontent["animate"] = [
+ calc_animation_step(max_displacement, coef) for coef in DISPLACE_COEF
+ ]
+ rcontent["keyframes"] = list(range(len(DISPLACE_COEF)))
+ rcontent["animateType"] = "displacement"
+ # Compute per-frame bonding motion.
+ # Explanation:
+ # Each bond connects two atoms, `u` and `v`, represented as (u)----(v)
+ # To model the bond motion, it is divided into two segments:
+ # from `u` to the midpoint and from the midpoint to `v`, i.e., (u)--(mid)--(v)
+ # Thus, two cylinders are created: one for (u)--(mid) and another for (v)--(mid).
+ # For each cylinder, displacements are assigned to the endpoints — for example,
+ # the (u)--(mid) cylinder uses:
+ # [
+ # [u_x_displacement, u_y_displacement, u_z_displacement],
+ # [mid_x_displacement, mid_y_displacement, mid_z_displacement]
+ # ].
+ contents1 = json_data["contents"][1]["contents"]
+
+ for cidx, content in enumerate(contents1):
+ bond_animation = []
+ assert len(content["_meta"]) == len(content["positionPairs"])
+
+ for atom_idx_pair in content["_meta"]:
+ max_displacements = list(
+ map(calc_max_displacement, atom_idx_pair)
+ ) # max displacement for u and v
+
+ u_to_middle_bond_animation = []
+
+ for coef in DISPLACE_COEF:
+ # Calculate the midpoint displacement between atom u and v for each animation frame.
+ u_displacement, v_displacement = [
+ np.array(calc_animation_step(max_displacement, coef))
+ for max_displacement in max_displacements
+ ]
+ middle_end_displacement = np.add(u_displacement, v_displacement) / 2
+
+ u_to_middle_bond_animation.append(
+ [
+ u_displacement, # u atom displacement
+ [
+ round(dis, precision) for dis in middle_end_displacement
+ ], # middle point displacement
+ ]
+ )
+
+ bond_animation.append(u_to_middle_bond_animation)
+
+ rdata["contents"][1]["contents"][cidx]["animate"] = bond_animation
+ rdata["contents"][1]["contents"][cidx]["keyframes"] = list(
+ range(len(DISPLACE_COEF))
+ )
+ rdata["contents"][1]["contents"][cidx]["animateType"] = "displacement"
+
+ # remove unused sense
+ for i in range(2, 4):
+ rdata["contents"][i]["visible"] = False
+
+ return rdata
@staticmethod
def _get_ph_bs_dos(
@@ -303,6 +541,7 @@ def get_ph_bandstructure_traces(bs, freq_range):
"line": {"color": "#1f77b4"},
"hoverinfo": "skip",
"name": "Total",
+ "customdata": [[di, band_num] for di in range(len(x_dat))],
"hovertemplate": "%{y:.2f} THz",
"showlegend": False,
"xaxis": "x",
@@ -348,6 +587,9 @@ def get_ph_bandstructure_traces(bs, freq_range):
def _get_data_list_dict(
bs: PhononBandStructureSymmLine, dos: CompletePhononDos
) -> dict[str, str | bool | int]:
+ if (not bs) and (not dos):
+ return {}
+
bs_minpoint, bs_min_freq = bs.min_freq()
min_freq_report = (
f"{bs_min_freq:.2f} THz at frac. coords. {bs_minpoint.frac_coords}"
@@ -373,7 +615,7 @@ def _get_data_list_dict(
target="blank",
),
]
- ): "Yes" if bs.has_nac else "No",
+ ): ("Yes" if bs.has_nac else "No"),
"Has imaginary frequencies": "Yes" if bs.has_imaginary_freq() else "No",
"Has eigen-displacements": "Yes" if bs.has_eigendisplacements else "No",
"Min frequency": min_freq_report,
@@ -443,14 +685,9 @@ def get_figure(
ph_dos: CompletePhononDos | None = None,
freq_range: tuple[float | None, float | None] = (None, None),
) -> go.Figure:
- if freq_range[0] is None:
- freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1])
-
- if freq_range[1] is None:
- freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05)
-
if (not ph_dos) and (not ph_bs):
empty_plot_style = {
+ "height": 500,
"xaxis": {"visible": False},
"yaxis": {"visible": False},
"paper_bgcolor": "rgba(0,0,0,0)",
@@ -459,6 +696,12 @@ def get_figure(
return go.Figure(layout=empty_plot_style)
+ if freq_range[0] is None:
+ freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1])
+
+ if freq_range[1] is None:
+ freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05)
+
if ph_bs:
(
bs_traces,
@@ -555,7 +798,7 @@ def get_figure(
paper_bgcolor="rgba(0,0,0,0)",
plot_bgcolor="rgba(230,230,230,230)",
margin=dict(l=60, b=50, t=50, pad=0, r=30),
- # clickmode="event+select"
+ clickmode="event+select",
)
figure = {"data": bs_traces + dos_traces, "layout": layout}
@@ -580,124 +823,57 @@ def get_figure(
def generate_callbacks(self, app, cache) -> None:
@app.callback(
Output(self.id("ph-bsdos-graph"), "figure"),
- Input(self.id("traces"), "data"),
+ Output(self.id("zone"), "data"),
+ Output(self.id("table"), "children"),
+ Input(self.id("ph_bs"), "data"),
+ Input(self.id("ph_dos"), "data"),
+ Input(self.id("ph-bsdos-graph"), "clickData"),
)
- def update_graph(traces):
- if traces == "error":
- msg_body = MessageBody(
- dcc.Markdown(
- "Band structure and density of states not available for this selection."
- )
- )
- return (MessageContainer([msg_body], kind="warning"),)
-
- if traces is None:
- raise PreventUpdate
-
- bs, dos = self._get_ph_bs_dos(self.initial_data["default"])
+ def update_graph(bs, dos, nclick):
+ if isinstance(bs, dict):
+ bs = PhononBandStructureSymmLine.from_dict(bs)
+ if isinstance(dos, dict):
+ dos = CompletePhononDos.from_dict(dos)
figure = self.get_figure(bs, dos)
- return dcc.Graph(
- figure=figure, config={"displayModeBar": False}, responsive=True
- )
-
- @app.callback(
- Output(self.id("label-select"), "value"),
- Output(self.id("label-container"), "style"),
- Input(self.id("mpid"), "data"),
- Input(self.id("path-convention"), "value"),
- )
- def update_label_select(mpid, path_convention):
- if not mpid:
- raise PreventUpdate
- label_value = path_convention
- label_style = {"maxWidth": "200"}
-
- return label_value, label_style
-
- @app.callback(
- Output(self.id("dos-select"), "options"),
- Output(self.id("path-convention"), "options"),
- Output(self.id("path-container"), "style"),
- Input(self.id("elements"), "data"),
- Input(self.id("mpid"), "data"),
- )
- def update_select(elements, mpid):
- if elements is None:
- raise PreventUpdate
- if not mpid:
- dos_options = (
- [{"label": "Element Projected", "value": "ap"}]
- + [{"label": "Orbital Projected - Total", "value": "op"}]
- + [
- {
- "label": "Orbital Projected - " + str(ele_label),
- "value": "orb" + str(ele_label),
- }
- for ele_label in elements
- ]
- )
-
- path_options = [{"label": "N/A", "value": "sc"}]
- path_style = {"maxWidth": "200", "display": "none"}
-
- return dos_options, path_options, path_style
- dos_options = (
- [{"label": "Element Projected", "value": "ap"}]
- + [{"label": "Orbital Projected - Total", "value": "op"}]
- + [
- {
- "label": "Orbital Projected - " + str(ele_label),
- "value": "orb" + str(ele_label),
- }
- for ele_label in elements
- ]
- )
- path_options = [
- {"label": "Setyawan-Curtarolo", "value": "sc"},
- {"label": "Latimer-Munro", "value": "lm"},
- {"label": "Hinuma et al.", "value": "hin"},
+ # remove marker if there is one
+ figure["data"] = [
+ t for t in figure["data"] if t.get("name") != "click-marker"
]
- path_style = {"maxWidth": "200"}
-
- return dos_options, path_options, path_style
-
- @app.callback(
- Output(self.id("traces"), "data"),
- Output(self.id("elements"), "data"),
- Input(self.id(), "data"),
- Input(self.id("path-convention"), "value"),
- Input(self.id("dos-select"), "value"),
- Input(self.id("label-select"), "value"),
- )
- def bs_dos_data(data, dos_select, label_select):
- # Obtain bands to plot over and generate traces for bs data:
- energy_window = (-6.0, 10.0)
-
- traces = []
+ x_click = nclick["points"][0]["x"] if nclick else 0
+ y_click = nclick["points"][0]["y"] if nclick else 0
+ pt = nclick["points"][0] if nclick else {}
- bsml, density_of_states = self._get_ph_bs_dos(data)
-
- if self.bandstructure_symm_line:
- bs_traces = self.get_ph_bandstructure_traces(
- bsml, freq_range=energy_window
- )
- traces.append(bs_traces)
+ qpoint, band_num = pt.get("customdata", [0, 0])
- if self.density_of_states:
- dos_traces = self.get_ph_dos_traces(
- density_of_states, freq_range=energy_window
- )
- traces.append(dos_traces)
+ figure["data"].append(
+ {
+ "type": "scatter",
+ "mode": "markers",
+ "x": [x_click],
+ "y": [y_click],
+ "marker": {
+ "color": MARKER_COLOR,
+ "size": MARKER_SIZE,
+ "symbol": MARKER_SHAPE,
+ },
+ "name": "click-marker",
+ "showlegend": False,
+ "customdata": [[qpoint, band_num]],
+ "hovertemplate": (
+ "band: %{customdata[1]}
q-point: %{customdata[0]}
"
+ ),
+ }
+ )
- # traces = [bs_traces, dos_traces, bs_data]
+ zone_scene = self.get_brillouin_zone_scene(bs)
- # TODO: not tested if this is correct way to get element list
- elements = list(map(str, density_of_states.get_element_dos()))
+ summary_dict = self._get_data_list_dict(bs, dos)
+ summary_table = get_data_list(summary_dict)
- return traces, elements
+ return figure, zone_scene.to_json(), summary_table
@app.callback(
Output(self.id("brillouin-zone"), "data"),
@@ -711,8 +887,78 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select):
# TODO: figure out what to return (CSS?) to highlight BZ edge/point
return
- # TODO: figure out what to return (CSS?) to highlight BZ edge/point
- return
+ @app.callback(
+ Output(self.id("crystal-animation"), "data"),
+ Input(self.id("ph-bsdos-graph"), "clickData"),
+ Input(self.id("ph_bs"), "data"),
+ Input(self.id("supercell-controls-btn"), "n_clicks"),
+ Input(self.get_kwarg_id("magnitude"), "value"),
+ State(self.get_kwarg_id("scale-x"), "value"),
+ State(self.get_kwarg_id("scale-y"), "value"),
+ State(self.get_kwarg_id("scale-z"), "value"),
+ # prevent_initial_call=True
+ )
+ def update_crystal_animation(
+ cd, bs, sueprcell_update, magnitude_fraction, scale_x, scale_y, scale_z
+ ):
+ # Avoids using `get_all_kwargs_id` for all `Input`; instead, uses `State` to prevent flickering when users modify `scale_x`, `scale_y`, or `scale_z` fields,
+ # ensuring updates occur only after the `supercell-controls-btn`` is clicked.
+
+ if not bs:
+ raise PreventUpdate
+
+ # Since `self.get_kwarg_id()` uses dash.dependencies.ALL, it returns a list of values.
+ # Although we could use `magnitude_fraction = magnitude_fraction[0]` to get the first value,
+ # this approach provides better clarity and readability.
+ kwargs = self.reconstruct_kwargs_from_state()
+ magnitude_fraction = kwargs.get("magnitude")
+ scale_x = kwargs.get("scale-x")
+ scale_y = kwargs.get("scale-y")
+ scale_z = kwargs.get("scale-z")
+
+ if isinstance(bs, dict):
+ bs = PhononBandStructureSymmLine.from_dict(bs)
+
+ struct = bs.structure
+ total_repeat_cell_cnt = 1
+ # update structure if the controls got triggered
+ if sueprcell_update:
+ total_repeat_cell_cnt = scale_x * scale_y * scale_z
+
+ # create supercell
+ trans = SupercellTransformation(
+ ((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z))
+ )
+ struct = trans.apply_transformation(struct)
+
+ struc_graph = StructureGraph.from_local_env_strategy(struct, CrystalNN())
+ scene = struc_graph.get_scene(
+ draw_image_atoms=False,
+ bonded_sites_outside_unit_cell=False,
+ site_get_scene_kwargs={"retain_atom_idx": True},
+ )
+ json_data = scene.to_json()
+
+ qpoint = 0
+ band_num = 0
+
+ if cd and cd.get("points"):
+ pt = cd["points"][0]
+ qpoint, band_num = pt.get("customdata", [0, 0])
+
+ # magnitude
+ magnitude = (
+ MAX_MAGNITUDE - MIN_MAGNITUDE
+ ) * magnitude_fraction + MIN_MAGNITUDE
+
+ return PhononBandstructureAndDosComponent._get_eigendisplacement(
+ ph_bs=bs,
+ json_data=json_data,
+ band=band_num,
+ qpoint=qpoint,
+ total_repeat_cell_cnt=total_repeat_cell_cnt,
+ magnitude=magnitude,
+ )
class PhononBandstructureAndDosPanelComponent(PanelComponent):