diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 296a2947..5ef76d90 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -1,39 +1,41 @@ from __future__ import annotations import itertools +from copy import deepcopy from typing import TYPE_CHECKING, Any import numpy as np import plotly.graph_objects as go from dash import dcc, html -from dash.dependencies import Component, Input, Output +from dash.dependencies import Component, Input, Output, State from dash.exceptions import PreventUpdate -from dash_mp_components import CrystalToolkitScene +from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene + +# crystal animation algo +from pymatgen.analysis.graphs import StructureGraph +from pymatgen.analysis.local_env import CrystalNN from pymatgen.ext.matproj import MPRester from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos from pymatgen.phonon.plotter import PhononBSPlotter +from pymatgen.transformations.standard_transformations import SupercellTransformation from crystal_toolkit.core.mpcomponent import MPComponent from crystal_toolkit.core.panelcomponent import PanelComponent from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres -from crystal_toolkit.helpers.layouts import ( - Column, - Columns, - Label, - MessageBody, - MessageContainer, - get_data_list, -) +from crystal_toolkit.helpers.layouts import Column, Columns, Label, get_data_list from crystal_toolkit.helpers.pretty_labels import pretty_labels if TYPE_CHECKING: from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine from pymatgen.electronic_structure.dos import CompleteDos -# Author: Jason Munro, Janosh Riebesell -# Contact: jmunro@lbl.gov, janosh@lbl.gov - +DISPLACE_COEF = [0, 1, 0, -1, 0] +MARKER_COLOR = "red" +MARKER_SIZE = 12 +MARKER_SHAPE = "x" +MAX_MAGNITUDE = 300 +MIN_MAGNITUDE = 0 # TODOs: # - look for additional projection methods in phonon DOS (currently only atom @@ -64,26 +66,32 @@ def __init__( **kwargs, ) + bs, _ = PhononBandstructureAndDosComponent._get_ph_bs_dos( + self.initial_data["default"] + ) + self.create_store("bs-store", bs) + self.create_store("bs", None) + self.create_store("dos", None) + @property def _sub_layouts(self) -> dict[str, Component]: # defaults state = {"label-select": "sc", "dos-select": "ap"} - bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos( - self.initial_data["default"] - ) - fig = PhononBandstructureAndDosComponent.get_figure(bs, dos) + fig = PhononBandstructureAndDosComponent.get_figure(None, None) # Main plot graph = dcc.Graph( figure=fig, config={"displayModeBar": False}, - responsive=True, + responsive=False, id=self.id("ph-bsdos-graph"), ) # Brillouin zone - zone_scene = self.get_brillouin_zone_scene(bs) - zone = CrystalToolkitScene(data=zone_scene.to_json(), sceneSize="500px") + zone_scene = self.get_brillouin_zone_scene(None) + zone = CrystalToolkitScene( + data=zone_scene.to_json(), sceneSize="500px", id=self.id("zone") + ) # Hide by default if not loaded by mpid, switching between k-paths # on-the-fly only supported for bandstructures retrieved from MP @@ -105,9 +113,11 @@ def _sub_layouts(self) -> dict[str, Component]: options=options, ) ], - style={"width": "200px"} - if show_path_options - else {"maxWidth": "200", "display": "none"}, + style=( + {"width": "200px"} + if show_path_options + else {"maxWidth": "200", "display": "none"} + ), id=self.id("path-container"), ) @@ -122,9 +132,11 @@ def _sub_layouts(self) -> dict[str, Component]: options=options, ) ], - style={"width": "200px"} - if show_path_options - else {"width": "200px", "display": "none"}, + style=( + {"width": "200px"} + if show_path_options + else {"width": "200px", "display": "none"} + ), id=self.id("label-container"), ) @@ -138,9 +150,82 @@ def _sub_layouts(self) -> dict[str, Component]: style={"width": "200px"}, ) - summary_dict = self._get_data_list_dict(bs, dos) + summary_dict = self._get_data_list_dict(None, None) summary_table = get_data_list(summary_dict) + # crystal visualization + + tip = html.H5( + "💡 Tips: Click different q-points and bands in the dispersion diagram to see the crystal vibration!", + ) + + crystal_animation = html.Div( + CrystalToolkitAnimationScene( + data={}, + sceneSize="500px", + id=self.id("crystal-animation"), + settings={"defaultZoom": 1.2}, + axisView="SW", + showControls=False, # disable download for now + ), + style={"width": "60%"}, + ) + + crystal_animation_controls = html.Div( + [ + html.Br(), + html.Div(tip, style={"textAlign": "center"}), + html.Br(), + html.H5("Control Panel", style={"textAlign": "center"}), + html.H6("Supercell modification"), + html.Br(), + html.Div( + [ + self.get_numerical_input( + kwarg_label="scale-x", + default=1, + is_int=True, + label="x", + min=1, + style={"width": "5rem"}, + ), + self.get_numerical_input( + kwarg_label="scale-y", + default=1, + is_int=True, + label="y", + min=1, + style={"width": "5rem"}, + ), + self.get_numerical_input( + kwarg_label="scale-z", + default=1, + is_int=True, + label="z", + min=1, + style={"width": "5rem"}, + ), + html.Button( + "Update", + id=self.id("supercell-controls-btn"), + style={"height": "40px"}, + ), + ], + style={"display": "flex"}, + ), + html.Br(), + html.Div( + self.get_slider_input( + kwarg_label="magnitude", + default=0.5, + step=0.01, + domain=[0, 1], + label="Vibration magnitude", + ) + ), + ], + ) + return { "graph": graph, "convention": convention, @@ -148,10 +233,31 @@ def _sub_layouts(self) -> dict[str, Component]: "label-select": label_select, "zone": zone, "table": summary_table, + "crystal-animation": crystal_animation, + "tip": tip, + "crystal-animation-controls": crystal_animation_controls, } + def _get_animation_panel(self): + sub_layouts = self._sub_layouts + return Columns( + [ + Column( + [ + Columns( + [ + sub_layouts["crystal-animation"], + sub_layouts["crystal-animation-controls"], + ] + ) + ] + ), + ] + ) + def layout(self) -> html.Div: sub_layouts = self._sub_layouts + crystal_animation = self._get_animation_panel() graph = Columns([Column([sub_layouts["graph"]])]) controls = Columns( [ @@ -166,11 +272,143 @@ def layout(self) -> html.Div: ) brillouin_zone = Columns( [ - Column([Label("Summary"), sub_layouts["table"]]), + Column([Label("Summary"), sub_layouts["table"]], id=self.id("table")), Column([Label("Brillouin Zone"), sub_layouts["zone"]]), ] ) - return html.Div([graph, controls, brillouin_zone]) + + return html.Div([graph, crystal_animation, controls, brillouin_zone]) + + @staticmethod + def _get_eigendisplacement( + ph_bs: BandStructureSymmLine, + json_data: dict, + band: int = 0, + qpoint: int = 0, + precision: int = 15, + magnitude: int = MAX_MAGNITUDE / 2, + total_repeat_cell_cnt: int = 1, + ) -> dict: + if not ph_bs or not json_data: + return {} + + assert json_data["contents"][0]["name"] == "atoms" + assert json_data["contents"][1]["name"] == "bonds" + rdata = deepcopy(json_data) + + def calc_max_displacement(idx: int) -> list: + """ + Retrieve the eigendisplacement for a given atom index from `ph_bs` and compute its maximum displacement. + + Parameters: + idx (int): The atom index. + + Returns: + list: The maximum displacement vector in the form [x_max_displacement, y_max_displacement, z_max_displacement] + + This function extracts the real component of the atom's eigendisplacement, + scales it by the specified magnitude, and returns the resulting vector. + """ + + # get the atom index + assert total_repeat_cell_cnt != 0 + + modified_idx = ( + int(idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx + ) + + return [ + round(complex(vec).real * magnitude, precision) + for vec in ph_bs.eigendisplacements[band][qpoint][modified_idx] + ] + + def calc_animation_step(max_displacement: list, coef: int) -> list: + """ + Calculate the displacement for an animation frame based on the given coefficient. + + Parameters: + max_displacement (list): A list of maximum displacements along each axis, + formatted as [x_max_displacement, y_max_displacement, z_max_displacement]. + coef (int): A coefficient indicating the motion direction. + - 0: no movement + - 1: forward movement + - -1: backward movement + + Returns: + list: The displacement vector [x_displacement, y_displacement, z_displacement]. + + This function generates oscillatory motion by scaling the maximum displacement + with the provided coefficient. + """ + return [round(coef * md, precision) for md in max_displacement] + + # Compute per-frame atomic motion. + # `rcontent["animate"]` stores the displacement (distance difference) from the previous coordinates. + contents0 = json_data["contents"][0]["contents"] + for cidx, content in enumerate(contents0): + max_displacement = calc_max_displacement(content["_meta"][0]) + rcontent = rdata["contents"][0]["contents"][cidx] + # put animation frame to the given atom index + rcontent["animate"] = [ + calc_animation_step(max_displacement, coef) for coef in DISPLACE_COEF + ] + rcontent["keyframes"] = list(range(len(DISPLACE_COEF))) + rcontent["animateType"] = "displacement" + # Compute per-frame bonding motion. + # Explanation: + # Each bond connects two atoms, `u` and `v`, represented as (u)----(v) + # To model the bond motion, it is divided into two segments: + # from `u` to the midpoint and from the midpoint to `v`, i.e., (u)--(mid)--(v) + # Thus, two cylinders are created: one for (u)--(mid) and another for (v)--(mid). + # For each cylinder, displacements are assigned to the endpoints — for example, + # the (u)--(mid) cylinder uses: + # [ + # [u_x_displacement, u_y_displacement, u_z_displacement], + # [mid_x_displacement, mid_y_displacement, mid_z_displacement] + # ]. + contents1 = json_data["contents"][1]["contents"] + + for cidx, content in enumerate(contents1): + bond_animation = [] + assert len(content["_meta"]) == len(content["positionPairs"]) + + for atom_idx_pair in content["_meta"]: + max_displacements = list( + map(calc_max_displacement, atom_idx_pair) + ) # max displacement for u and v + + u_to_middle_bond_animation = [] + + for coef in DISPLACE_COEF: + # Calculate the midpoint displacement between atom u and v for each animation frame. + u_displacement, v_displacement = [ + np.array(calc_animation_step(max_displacement, coef)) + for max_displacement in max_displacements + ] + middle_end_displacement = np.add(u_displacement, v_displacement) / 2 + + u_to_middle_bond_animation.append( + [ + u_displacement, # u atom displacement + [ + round(dis, precision) for dis in middle_end_displacement + ], # middle point displacement + ] + ) + + bond_animation.append(u_to_middle_bond_animation) + + rdata["contents"][1]["contents"][cidx]["animate"] = bond_animation + rdata["contents"][1]["contents"][cidx]["keyframes"] = list( + range(len(DISPLACE_COEF)) + ) + rdata["contents"][1]["contents"][cidx]["animateType"] = "displacement" + + # remove unused sense + for i in range(2, 4): + rdata["contents"][i]["visible"] = False + + return rdata @staticmethod def _get_ph_bs_dos( @@ -303,6 +541,7 @@ def get_ph_bandstructure_traces(bs, freq_range): "line": {"color": "#1f77b4"}, "hoverinfo": "skip", "name": "Total", + "customdata": [[di, band_num] for di in range(len(x_dat))], "hovertemplate": "%{y:.2f} THz", "showlegend": False, "xaxis": "x", @@ -348,6 +587,9 @@ def get_ph_bandstructure_traces(bs, freq_range): def _get_data_list_dict( bs: PhononBandStructureSymmLine, dos: CompletePhononDos ) -> dict[str, str | bool | int]: + if (not bs) and (not dos): + return {} + bs_minpoint, bs_min_freq = bs.min_freq() min_freq_report = ( f"{bs_min_freq:.2f} THz at frac. coords. {bs_minpoint.frac_coords}" @@ -373,7 +615,7 @@ def _get_data_list_dict( target="blank", ), ] - ): "Yes" if bs.has_nac else "No", + ): ("Yes" if bs.has_nac else "No"), "Has imaginary frequencies": "Yes" if bs.has_imaginary_freq() else "No", "Has eigen-displacements": "Yes" if bs.has_eigendisplacements else "No", "Min frequency": min_freq_report, @@ -443,14 +685,9 @@ def get_figure( ph_dos: CompletePhononDos | None = None, freq_range: tuple[float | None, float | None] = (None, None), ) -> go.Figure: - if freq_range[0] is None: - freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) - - if freq_range[1] is None: - freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) - if (not ph_dos) and (not ph_bs): empty_plot_style = { + "height": 500, "xaxis": {"visible": False}, "yaxis": {"visible": False}, "paper_bgcolor": "rgba(0,0,0,0)", @@ -459,6 +696,12 @@ def get_figure( return go.Figure(layout=empty_plot_style) + if freq_range[0] is None: + freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) + + if freq_range[1] is None: + freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) + if ph_bs: ( bs_traces, @@ -555,7 +798,7 @@ def get_figure( paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(230,230,230,230)", margin=dict(l=60, b=50, t=50, pad=0, r=30), - # clickmode="event+select" + clickmode="event+select", ) figure = {"data": bs_traces + dos_traces, "layout": layout} @@ -580,124 +823,57 @@ def get_figure( def generate_callbacks(self, app, cache) -> None: @app.callback( Output(self.id("ph-bsdos-graph"), "figure"), - Input(self.id("traces"), "data"), + Output(self.id("zone"), "data"), + Output(self.id("table"), "children"), + Input(self.id("ph_bs"), "data"), + Input(self.id("ph_dos"), "data"), + Input(self.id("ph-bsdos-graph"), "clickData"), ) - def update_graph(traces): - if traces == "error": - msg_body = MessageBody( - dcc.Markdown( - "Band structure and density of states not available for this selection." - ) - ) - return (MessageContainer([msg_body], kind="warning"),) - - if traces is None: - raise PreventUpdate - - bs, dos = self._get_ph_bs_dos(self.initial_data["default"]) + def update_graph(bs, dos, nclick): + if isinstance(bs, dict): + bs = PhononBandStructureSymmLine.from_dict(bs) + if isinstance(dos, dict): + dos = CompletePhononDos.from_dict(dos) figure = self.get_figure(bs, dos) - return dcc.Graph( - figure=figure, config={"displayModeBar": False}, responsive=True - ) - - @app.callback( - Output(self.id("label-select"), "value"), - Output(self.id("label-container"), "style"), - Input(self.id("mpid"), "data"), - Input(self.id("path-convention"), "value"), - ) - def update_label_select(mpid, path_convention): - if not mpid: - raise PreventUpdate - label_value = path_convention - label_style = {"maxWidth": "200"} - - return label_value, label_style - - @app.callback( - Output(self.id("dos-select"), "options"), - Output(self.id("path-convention"), "options"), - Output(self.id("path-container"), "style"), - Input(self.id("elements"), "data"), - Input(self.id("mpid"), "data"), - ) - def update_select(elements, mpid): - if elements is None: - raise PreventUpdate - if not mpid: - dos_options = ( - [{"label": "Element Projected", "value": "ap"}] - + [{"label": "Orbital Projected - Total", "value": "op"}] - + [ - { - "label": "Orbital Projected - " + str(ele_label), - "value": "orb" + str(ele_label), - } - for ele_label in elements - ] - ) - - path_options = [{"label": "N/A", "value": "sc"}] - path_style = {"maxWidth": "200", "display": "none"} - - return dos_options, path_options, path_style - dos_options = ( - [{"label": "Element Projected", "value": "ap"}] - + [{"label": "Orbital Projected - Total", "value": "op"}] - + [ - { - "label": "Orbital Projected - " + str(ele_label), - "value": "orb" + str(ele_label), - } - for ele_label in elements - ] - ) - path_options = [ - {"label": "Setyawan-Curtarolo", "value": "sc"}, - {"label": "Latimer-Munro", "value": "lm"}, - {"label": "Hinuma et al.", "value": "hin"}, + # remove marker if there is one + figure["data"] = [ + t for t in figure["data"] if t.get("name") != "click-marker" ] - path_style = {"maxWidth": "200"} - - return dos_options, path_options, path_style - - @app.callback( - Output(self.id("traces"), "data"), - Output(self.id("elements"), "data"), - Input(self.id(), "data"), - Input(self.id("path-convention"), "value"), - Input(self.id("dos-select"), "value"), - Input(self.id("label-select"), "value"), - ) - def bs_dos_data(data, dos_select, label_select): - # Obtain bands to plot over and generate traces for bs data: - energy_window = (-6.0, 10.0) - - traces = [] + x_click = nclick["points"][0]["x"] if nclick else 0 + y_click = nclick["points"][0]["y"] if nclick else 0 + pt = nclick["points"][0] if nclick else {} - bsml, density_of_states = self._get_ph_bs_dos(data) - - if self.bandstructure_symm_line: - bs_traces = self.get_ph_bandstructure_traces( - bsml, freq_range=energy_window - ) - traces.append(bs_traces) + qpoint, band_num = pt.get("customdata", [0, 0]) - if self.density_of_states: - dos_traces = self.get_ph_dos_traces( - density_of_states, freq_range=energy_window - ) - traces.append(dos_traces) + figure["data"].append( + { + "type": "scatter", + "mode": "markers", + "x": [x_click], + "y": [y_click], + "marker": { + "color": MARKER_COLOR, + "size": MARKER_SIZE, + "symbol": MARKER_SHAPE, + }, + "name": "click-marker", + "showlegend": False, + "customdata": [[qpoint, band_num]], + "hovertemplate": ( + "band: %{customdata[1]}
q-point: %{customdata[0]}
" + ), + } + ) - # traces = [bs_traces, dos_traces, bs_data] + zone_scene = self.get_brillouin_zone_scene(bs) - # TODO: not tested if this is correct way to get element list - elements = list(map(str, density_of_states.get_element_dos())) + summary_dict = self._get_data_list_dict(bs, dos) + summary_table = get_data_list(summary_dict) - return traces, elements + return figure, zone_scene.to_json(), summary_table @app.callback( Output(self.id("brillouin-zone"), "data"), @@ -711,8 +887,78 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select): # TODO: figure out what to return (CSS?) to highlight BZ edge/point return - # TODO: figure out what to return (CSS?) to highlight BZ edge/point - return + @app.callback( + Output(self.id("crystal-animation"), "data"), + Input(self.id("ph-bsdos-graph"), "clickData"), + Input(self.id("ph_bs"), "data"), + Input(self.id("supercell-controls-btn"), "n_clicks"), + Input(self.get_kwarg_id("magnitude"), "value"), + State(self.get_kwarg_id("scale-x"), "value"), + State(self.get_kwarg_id("scale-y"), "value"), + State(self.get_kwarg_id("scale-z"), "value"), + # prevent_initial_call=True + ) + def update_crystal_animation( + cd, bs, sueprcell_update, magnitude_fraction, scale_x, scale_y, scale_z + ): + # Avoids using `get_all_kwargs_id` for all `Input`; instead, uses `State` to prevent flickering when users modify `scale_x`, `scale_y`, or `scale_z` fields, + # ensuring updates occur only after the `supercell-controls-btn`` is clicked. + + if not bs: + raise PreventUpdate + + # Since `self.get_kwarg_id()` uses dash.dependencies.ALL, it returns a list of values. + # Although we could use `magnitude_fraction = magnitude_fraction[0]` to get the first value, + # this approach provides better clarity and readability. + kwargs = self.reconstruct_kwargs_from_state() + magnitude_fraction = kwargs.get("magnitude") + scale_x = kwargs.get("scale-x") + scale_y = kwargs.get("scale-y") + scale_z = kwargs.get("scale-z") + + if isinstance(bs, dict): + bs = PhononBandStructureSymmLine.from_dict(bs) + + struct = bs.structure + total_repeat_cell_cnt = 1 + # update structure if the controls got triggered + if sueprcell_update: + total_repeat_cell_cnt = scale_x * scale_y * scale_z + + # create supercell + trans = SupercellTransformation( + ((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z)) + ) + struct = trans.apply_transformation(struct) + + struc_graph = StructureGraph.from_local_env_strategy(struct, CrystalNN()) + scene = struc_graph.get_scene( + draw_image_atoms=False, + bonded_sites_outside_unit_cell=False, + site_get_scene_kwargs={"retain_atom_idx": True}, + ) + json_data = scene.to_json() + + qpoint = 0 + band_num = 0 + + if cd and cd.get("points"): + pt = cd["points"][0] + qpoint, band_num = pt.get("customdata", [0, 0]) + + # magnitude + magnitude = ( + MAX_MAGNITUDE - MIN_MAGNITUDE + ) * magnitude_fraction + MIN_MAGNITUDE + + return PhononBandstructureAndDosComponent._get_eigendisplacement( + ph_bs=bs, + json_data=json_data, + band=band_num, + qpoint=qpoint, + total_repeat_cell_cnt=total_repeat_cell_cnt, + magnitude=magnitude, + ) class PhononBandstructureAndDosPanelComponent(PanelComponent):