From 56497f21885ce9b8111aa0ad69be841ac6fa878d Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 7 Aug 2025 15:18:31 +0800 Subject: [PATCH 01/24] add atom index to _meta --- crystal_toolkit/core/scene.py | 21 ++++++++++++------- crystal_toolkit/renderables/site.py | 5 +++++ crystal_toolkit/renderables/structuregraph.py | 20 ++++++++++++++++++ 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/crystal_toolkit/core/scene.py b/crystal_toolkit/core/scene.py index 392c4d49..01b6f7f1 100644 --- a/crystal_toolkit/core/scene.py +++ b/crystal_toolkit/core/scene.py @@ -67,13 +67,13 @@ def __add__(self, other): lattice=self.lattice, _meta={self.name: self._meta, other.name: other._meta}, ) - - def _repr_mimebundle_(self, include=None, exclude=None): - """Render Scenes using crystaltoolkit-extension for Jupyter Lab.""" - return { - "application/vnd.mp.ctk+json": self.to_json(), - "text/plain": repr(self), - } + + # def _repr_mimebundle_(self, include=None, exclude=None): + # """Render Scenes using crystaltoolkit-extension for Jupyter Lab.""" + # return { + # "application/vnd.mp.ctk+json": self.to_json(), + # "text/plain": repr(self), + # } def to_json(self): """Convert a Scene into JSON. It will implicitly assume all None values means that attribute @@ -149,7 +149,6 @@ def merge_primitives(primitives): """ mergeable = defaultdict(list) remainder = [] - for primitive in primitives: if isinstance(primitive, Scene): primitive.contents = Scene.merge_primitives(primitive.contents) @@ -214,6 +213,7 @@ def merge(cls, sphere_list): visible=sphere_list[0].visible, clickable=sphere_list[0].clickable, tooltip=sphere_list[0].tooltip, + _meta=sphere_list[0]._meta, ) @@ -320,6 +320,10 @@ def merge(cls, cylinder_list): chain.from_iterable([cylinder.positionPairs for cylinder in cylinder_list]) ) + new_meta_list = list( + chain.from_iterable([[cylinder._meta] for cylinder in cylinder_list]) + ) + return cls( positionPairs=new_positionPairs, color=cylinder_list[0].color, @@ -327,6 +331,7 @@ def merge(cls, cylinder_list): visible=cylinder_list[0].visible, clickable=cylinder_list[0].clickable, tooltip=cylinder_list[0].tooltip, + _meta=new_meta_list ) @property diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index ef43be57..d02aa95e 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -135,6 +135,7 @@ def get_site_scene( phiEnd=phiEnd, clickable=True, tooltip=name, + _meta=[site_idx] ) atoms.append(sphere) @@ -207,6 +208,7 @@ def get_site_scene( radius=bond_radius / 2, clickable=True, tooltip=name_cyl, + _meta=[site_idx, connected_site.index] ) ) trans_vector = trans_vector + 0.25 * max_radius @@ -218,6 +220,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, + _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) @@ -228,6 +231,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, + _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) @@ -251,6 +255,7 @@ def get_site_scene( positionPairs=[[position, bond_midpoint.tolist()]], color=color, radius=bond_radius, + _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) diff --git a/crystal_toolkit/renderables/structuregraph.py b/crystal_toolkit/renderables/structuregraph.py index 6c8077c2..5c7abb25 100644 --- a/crystal_toolkit/renderables/structuregraph.py +++ b/crystal_toolkit/renderables/structuregraph.py @@ -197,6 +197,8 @@ def get_weight_color(weight): explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull, legend=legend, bond_radius=bond_radius, + site_idx=idx, + show_atom_idx=True, **(site_get_scene_kwargs or {}), ) @@ -217,6 +219,23 @@ def get_weight_color(weight): primitives["unit_cell"].append(self.structure.lattice.get_scene()) + """ + ss = Scene( + name="StructureGraph", + origin=origin, + contents=[ + Scene(name=key, contents=val, origin=origin) + for key, val in primitives.items() + ], + ) + print(id(ss)) + print(ss.contents[1]) + print(ss.contents[1].contents[0]._meta) + print(ss) + + return(ss) + """ + # why primitives comprehension? just make explicit! more readable return Scene( name="StructureGraph", @@ -226,6 +245,7 @@ def get_weight_color(weight): for key, val in primitives.items() ], ) + StructureGraph._get_sites_to_draw = _get_sites_to_draw From a0b9888a976070112f363500cec6a119783b293f Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 7 Aug 2025 16:15:26 +0800 Subject: [PATCH 02/24] Revert "add atom index to _meta" This reverts commit 56497f21885ce9b8111aa0ad69be841ac6fa878d. --- crystal_toolkit/core/scene.py | 21 +++++++------------ crystal_toolkit/renderables/site.py | 5 ----- crystal_toolkit/renderables/structuregraph.py | 20 ------------------ 3 files changed, 8 insertions(+), 38 deletions(-) diff --git a/crystal_toolkit/core/scene.py b/crystal_toolkit/core/scene.py index 01b6f7f1..392c4d49 100644 --- a/crystal_toolkit/core/scene.py +++ b/crystal_toolkit/core/scene.py @@ -67,13 +67,13 @@ def __add__(self, other): lattice=self.lattice, _meta={self.name: self._meta, other.name: other._meta}, ) - - # def _repr_mimebundle_(self, include=None, exclude=None): - # """Render Scenes using crystaltoolkit-extension for Jupyter Lab.""" - # return { - # "application/vnd.mp.ctk+json": self.to_json(), - # "text/plain": repr(self), - # } + + def _repr_mimebundle_(self, include=None, exclude=None): + """Render Scenes using crystaltoolkit-extension for Jupyter Lab.""" + return { + "application/vnd.mp.ctk+json": self.to_json(), + "text/plain": repr(self), + } def to_json(self): """Convert a Scene into JSON. It will implicitly assume all None values means that attribute @@ -149,6 +149,7 @@ def merge_primitives(primitives): """ mergeable = defaultdict(list) remainder = [] + for primitive in primitives: if isinstance(primitive, Scene): primitive.contents = Scene.merge_primitives(primitive.contents) @@ -213,7 +214,6 @@ def merge(cls, sphere_list): visible=sphere_list[0].visible, clickable=sphere_list[0].clickable, tooltip=sphere_list[0].tooltip, - _meta=sphere_list[0]._meta, ) @@ -320,10 +320,6 @@ def merge(cls, cylinder_list): chain.from_iterable([cylinder.positionPairs for cylinder in cylinder_list]) ) - new_meta_list = list( - chain.from_iterable([[cylinder._meta] for cylinder in cylinder_list]) - ) - return cls( positionPairs=new_positionPairs, color=cylinder_list[0].color, @@ -331,7 +327,6 @@ def merge(cls, cylinder_list): visible=cylinder_list[0].visible, clickable=cylinder_list[0].clickable, tooltip=cylinder_list[0].tooltip, - _meta=new_meta_list ) @property diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index d02aa95e..ef43be57 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -135,7 +135,6 @@ def get_site_scene( phiEnd=phiEnd, clickable=True, tooltip=name, - _meta=[site_idx] ) atoms.append(sphere) @@ -208,7 +207,6 @@ def get_site_scene( radius=bond_radius / 2, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] ) ) trans_vector = trans_vector + 0.25 * max_radius @@ -220,7 +218,6 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) @@ -231,7 +228,6 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) @@ -255,7 +251,6 @@ def get_site_scene( positionPairs=[[position, bond_midpoint.tolist()]], color=color, radius=bond_radius, - _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) diff --git a/crystal_toolkit/renderables/structuregraph.py b/crystal_toolkit/renderables/structuregraph.py index 5c7abb25..6c8077c2 100644 --- a/crystal_toolkit/renderables/structuregraph.py +++ b/crystal_toolkit/renderables/structuregraph.py @@ -197,8 +197,6 @@ def get_weight_color(weight): explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull, legend=legend, bond_radius=bond_radius, - site_idx=idx, - show_atom_idx=True, **(site_get_scene_kwargs or {}), ) @@ -219,23 +217,6 @@ def get_weight_color(weight): primitives["unit_cell"].append(self.structure.lattice.get_scene()) - """ - ss = Scene( - name="StructureGraph", - origin=origin, - contents=[ - Scene(name=key, contents=val, origin=origin) - for key, val in primitives.items() - ], - ) - print(id(ss)) - print(ss.contents[1]) - print(ss.contents[1].contents[0]._meta) - print(ss) - - return(ss) - """ - # why primitives comprehension? just make explicit! more readable return Scene( name="StructureGraph", @@ -245,7 +226,6 @@ def get_weight_color(weight): for key, val in primitives.items() ], ) - StructureGraph._get_sites_to_draw = _get_sites_to_draw From c024d477905138e19621536996a2350fe1d2d341 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 7 Aug 2025 16:41:24 +0800 Subject: [PATCH 03/24] add index to _meta --- crystal_toolkit/core/scene.py | 7 +++++++ crystal_toolkit/renderables/site.py | 5 +++++ crystal_toolkit/renderables/structuregraph.py | 5 ++++- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/crystal_toolkit/core/scene.py b/crystal_toolkit/core/scene.py index 392c4d49..54ee3982 100644 --- a/crystal_toolkit/core/scene.py +++ b/crystal_toolkit/core/scene.py @@ -214,6 +214,7 @@ def merge(cls, sphere_list): visible=sphere_list[0].visible, clickable=sphere_list[0].clickable, tooltip=sphere_list[0].tooltip, + _meta=sphere_list[0]._meta, ) @@ -272,6 +273,7 @@ def merge(cls, ellipsoid_list): ] ) ) + return cls( positions=new_positions, @@ -320,6 +322,10 @@ def merge(cls, cylinder_list): chain.from_iterable([cylinder.positionPairs for cylinder in cylinder_list]) ) + new_meta_list = list( + chain.from_iterable([[cylinder._meta] for cylinder in cylinder_list]) + ) + return cls( positionPairs=new_positionPairs, color=cylinder_list[0].color, @@ -327,6 +333,7 @@ def merge(cls, cylinder_list): visible=cylinder_list[0].visible, clickable=cylinder_list[0].clickable, tooltip=cylinder_list[0].tooltip, + _meta=new_meta_list ) @property diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index ef43be57..d02aa95e 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -135,6 +135,7 @@ def get_site_scene( phiEnd=phiEnd, clickable=True, tooltip=name, + _meta=[site_idx] ) atoms.append(sphere) @@ -207,6 +208,7 @@ def get_site_scene( radius=bond_radius / 2, clickable=True, tooltip=name_cyl, + _meta=[site_idx, connected_site.index] ) ) trans_vector = trans_vector + 0.25 * max_radius @@ -218,6 +220,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, + _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) @@ -228,6 +231,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, + _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) @@ -251,6 +255,7 @@ def get_site_scene( positionPairs=[[position, bond_midpoint.tolist()]], color=color, radius=bond_radius, + _meta=[site_idx, connected_site.index] ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) diff --git a/crystal_toolkit/renderables/structuregraph.py b/crystal_toolkit/renderables/structuregraph.py index 6c8077c2..b6315bc4 100644 --- a/crystal_toolkit/renderables/structuregraph.py +++ b/crystal_toolkit/renderables/structuregraph.py @@ -197,6 +197,8 @@ def get_weight_color(weight): explicitly_calculate_polyhedra_hull=explicitly_calculate_polyhedra_hull, legend=legend, bond_radius=bond_radius, + site_idx=idx, + show_atom_idx=True, **(site_get_scene_kwargs or {}), ) @@ -216,7 +218,7 @@ def get_weight_color(weight): primitives["atoms"] = atoms_scenes primitives["unit_cell"].append(self.structure.lattice.get_scene()) - + # why primitives comprehension? just make explicit! more readable return Scene( name="StructureGraph", @@ -226,6 +228,7 @@ def get_weight_color(weight): for key, val in primitives.items() ], ) + StructureGraph._get_sites_to_draw = _get_sites_to_draw From 52f15bfcb453e5631a41fb9e0cefe642c6c5aac1 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 7 Aug 2025 17:00:33 +0800 Subject: [PATCH 04/24] remove empty line --- crystal_toolkit/renderables/structuregraph.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crystal_toolkit/renderables/structuregraph.py b/crystal_toolkit/renderables/structuregraph.py index b6315bc4..163c7230 100644 --- a/crystal_toolkit/renderables/structuregraph.py +++ b/crystal_toolkit/renderables/structuregraph.py @@ -218,7 +218,7 @@ def get_weight_color(weight): primitives["atoms"] = atoms_scenes primitives["unit_cell"].append(self.structure.lattice.get_scene()) - + # why primitives comprehension? just make explicit! more readable return Scene( name="StructureGraph", @@ -228,7 +228,6 @@ def get_weight_color(weight): for key, val in primitives.items() ], ) - StructureGraph._get_sites_to_draw = _get_sites_to_draw From 3d6266fecbb6f29f06d5105894bfe7585df2b726 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Aug 2025 09:26:33 +0000 Subject: [PATCH 05/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- crystal_toolkit/core/scene.py | 3 +-- crystal_toolkit/renderables/site.py | 10 +++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/crystal_toolkit/core/scene.py b/crystal_toolkit/core/scene.py index 54ee3982..ad2d0621 100644 --- a/crystal_toolkit/core/scene.py +++ b/crystal_toolkit/core/scene.py @@ -273,7 +273,6 @@ def merge(cls, ellipsoid_list): ] ) ) - return cls( positions=new_positions, @@ -333,7 +332,7 @@ def merge(cls, cylinder_list): visible=cylinder_list[0].visible, clickable=cylinder_list[0].clickable, tooltip=cylinder_list[0].tooltip, - _meta=new_meta_list + _meta=new_meta_list, ) @property diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index d02aa95e..18f42128 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -135,7 +135,7 @@ def get_site_scene( phiEnd=phiEnd, clickable=True, tooltip=name, - _meta=[site_idx] + _meta=[site_idx], ) atoms.append(sphere) @@ -208,7 +208,7 @@ def get_site_scene( radius=bond_radius / 2, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] + _meta=[site_idx, connected_site.index], ) ) trans_vector = trans_vector + 0.25 * max_radius @@ -220,7 +220,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] + _meta=[site_idx, connected_site.index], ) bonds.append(cylinder) @@ -231,7 +231,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] + _meta=[site_idx, connected_site.index], ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) @@ -255,7 +255,7 @@ def get_site_scene( positionPairs=[[position, bond_midpoint.tolist()]], color=color, radius=bond_radius, - _meta=[site_idx, connected_site.index] + _meta=[site_idx, connected_site.index], ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) From d16ac38036e7e60c8df08692d5f9c55d9a3689d3 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 7 Aug 2025 18:09:47 +0800 Subject: [PATCH 06/24] ruff format --- crystal_toolkit/renderables/site.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index 18f42128..5ce92012 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -47,6 +47,7 @@ def get_site_scene( visualize_bond_orders: bool = False, magmom_scale: float = 1.0, legend: Legend | None = None, + retain_atom_idx: bool = False, ) -> Scene: """Get a Scene object for a Site. @@ -70,6 +71,7 @@ def get_site_scene( visualize_bond_orders (bool, optional): Defaults to False. magmom_scale (float, optional): Defaults to 1.0. legend (Legend | None, optional): Defaults to None. + retain_atom_idx (bool, optional): Defaults to False. Returns: Scene: The scene object containing atoms, bonds, polyhedra, magmoms. From 0e1c9249e81c9d51b094866c320e110f5ba11e99 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 7 Aug 2025 18:19:08 +0800 Subject: [PATCH 07/24] add retain_atom_idx --- crystal_toolkit/renderables/site.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index 5ce92012..bab81bfd 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -137,7 +137,7 @@ def get_site_scene( phiEnd=phiEnd, clickable=True, tooltip=name, - _meta=[site_idx], + _meta=[site_idx] if retain_atom_idx else None, ) atoms.append(sphere) @@ -210,7 +210,9 @@ def get_site_scene( radius=bond_radius / 2, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index], + _meta=[site_idx, connected_site.index] + if retain_atom_idx + else None, ) ) trans_vector = trans_vector + 0.25 * max_radius @@ -222,7 +224,9 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index], + _meta=[site_idx, connected_site.index] + if retain_atom_idx + else None, ) bonds.append(cylinder) @@ -233,7 +237,7 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index], + _meta=[site_idx, connected_site.index] if retain_atom_idx else None, ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) @@ -257,7 +261,7 @@ def get_site_scene( positionPairs=[[position, bond_midpoint.tolist()]], color=color, radius=bond_radius, - _meta=[site_idx, connected_site.index], + _meta=[site_idx, connected_site.index] if retain_atom_idx else None, ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) From a16c9c5d2947295425618f1419485cc0271bd0cb Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Tue, 9 Sep 2025 12:46:12 -0700 Subject: [PATCH 08/24] add animation component --- crystal_toolkit/components/phonon.py | 384 +++++++++++++++++---------- 1 file changed, 244 insertions(+), 140 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 296a2947..f0808a0a 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -8,7 +8,11 @@ from dash import dcc, html from dash.dependencies import Component, Input, Output from dash.exceptions import PreventUpdate -from dash_mp_components import CrystalToolkitScene +from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene + +# crystal animation algo +from pymatgen.analysis.graphs import StructureGraph +from pymatgen.analysis.local_env import CrystalNN from pymatgen.ext.matproj import MPRester from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos @@ -17,14 +21,7 @@ from crystal_toolkit.core.mpcomponent import MPComponent from crystal_toolkit.core.panelcomponent import PanelComponent from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres -from crystal_toolkit.helpers.layouts import ( - Column, - Columns, - Label, - MessageBody, - MessageContainer, - get_data_list, -) +from crystal_toolkit.helpers.layouts import Column, Columns, Label, get_data_list from crystal_toolkit.helpers.pretty_labels import pretty_labels if TYPE_CHECKING: @@ -64,26 +61,32 @@ def __init__( **kwargs, ) + bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos( + self.initial_data["default"] + ) + self.create_store("bs-store", bs) + self.create_store("bs", None) + self.create_store("dos", None) + @property def _sub_layouts(self) -> dict[str, Component]: # defaults state = {"label-select": "sc", "dos-select": "ap"} - bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos( - self.initial_data["default"] - ) - fig = PhononBandstructureAndDosComponent.get_figure(bs, dos) + fig = PhononBandstructureAndDosComponent.get_figure(None, None) # Main plot graph = dcc.Graph( figure=fig, config={"displayModeBar": False}, - responsive=True, + responsive=False, id=self.id("ph-bsdos-graph"), ) # Brillouin zone - zone_scene = self.get_brillouin_zone_scene(bs) - zone = CrystalToolkitScene(data=zone_scene.to_json(), sceneSize="500px") + zone_scene = self.get_brillouin_zone_scene(None) + zone = CrystalToolkitScene( + data=zone_scene.to_json(), sceneSize="500px", id=self.id("zone") + ) # Hide by default if not loaded by mpid, switching between k-paths # on-the-fly only supported for bandstructures retrieved from MP @@ -138,9 +141,29 @@ def _sub_layouts(self) -> dict[str, Component]: style={"width": "200px"}, ) - summary_dict = self._get_data_list_dict(bs, dos) + summary_dict = self._get_data_list_dict(None, None) summary_table = get_data_list(summary_dict) + # crystal visualization + + tip = html.P( + "Click different q-points and bands in the dispersion diagram to see the crystal vibration.", + id=self.id("crystal-tip"), + style={ + "margin": "0 0 12px", + "fontSize": "16px", + "color": "#555", + "textAlign": "center", + }, + ) + + crystal_animation = CrystalToolkitAnimationScene( + data={}, + sceneSize="200px", + id=self.id("crystal-animation"), + settings={"defaultZoom": 1.5}, + ) + return { "graph": graph, "convention": convention, @@ -148,10 +171,15 @@ def _sub_layouts(self) -> dict[str, Component]: "label-select": label_select, "zone": zone, "table": summary_table, + "crystal-animation": crystal_animation, + "tip": tip, } def layout(self) -> html.Div: sub_layouts = self._sub_layouts + crystal_animation = Columns( + [Column([sub_layouts["tip"], sub_layouts["crystal-animation"]])] + ) graph = Columns([Column([sub_layouts["graph"]])]) controls = Columns( [ @@ -166,11 +194,147 @@ def layout(self) -> html.Div: ) brillouin_zone = Columns( [ - Column([Label("Summary"), sub_layouts["table"]]), + Column([Label("Summary"), sub_layouts["table"]], id=self.id("table")), Column([Label("Brillouin Zone"), sub_layouts["zone"]]), ] ) - return html.Div([graph, controls, brillouin_zone]) + + return html.Div([graph, crystal_animation, controls, brillouin_zone]) + + @staticmethod + def _get_eigendisplacement( + ph_bs: BandStructureSymmLine, + json_data: dict, + band: int = 0, + qpoint: int = 0, + precision: int = 15, + magnitude: int = 15, + ) -> dict: + if not ph_bs: + return {} + + # get displacement + min_bond_length = float("inf") + for content_idx in range(len(json_data["contents"][1]["contents"])): + for pair_idx in range( + len(json_data["contents"][1]["contents"][content_idx]["_meta"]) + ): + u, v = json_data["contents"][1]["contents"][content_idx][ + "positionPairs" + ][pair_idx] + # Convert to numpy arrays + u = np.array(u) + v = np.array(v) + length = np.linalg.norm(v - u) + min_bond_length = min(min_bond_length, length) + + # atom animate + assert json_data["contents"][0]["name"] == "atoms" + for content_idx in range(len(json_data["contents"][0]["contents"])): + atom_idx = json_data["contents"][0]["contents"][content_idx]["_meta"][0] + + raw_displacement = ph_bs.eigendisplacements[band][qpoint][atom_idx] + + displacement = [complex(vec).real * magnitude for vec in raw_displacement] + + position_animation = [] + for displace_coef in [0, 1, 0, -1, 0]: + displace = [ + round(displace_coef * magnitude * d, precision) + for d in displacement + ] + position_animation.append(displace) + + json_data["contents"][0]["contents"][content_idx]["animate"] = ( + position_animation + ) + json_data["contents"][0]["contents"][content_idx]["keyframes"] = [ + 0, + 1, + 2, + 3, + 4, + ] + json_data["contents"][0]["contents"][content_idx]["animateType"] = ( + "displacement" + ) + + # bond animate + assert json_data["contents"][1]["name"] == "bonds" + for content_idx in range(len(json_data["contents"][1]["contents"])): + bond_animation = [] + + assert len( + json_data["contents"][1]["contents"][content_idx]["_meta"] + ) == len(json_data["contents"][1]["contents"][content_idx]["positionPairs"]) + + for pair_idx in range( + len(json_data["contents"][1]["contents"][content_idx]["_meta"]) + ): + u_idx, v_idx = json_data["contents"][1]["contents"][content_idx][ + "_meta" + ][pair_idx] + + # u + u_raw_displacement = ph_bs.eigendisplacements[band][qpoint][u_idx] + u_displacement = [ + round(complex(vec).real * magnitude, precision) + for vec in u_raw_displacement + ] + + # v + v_raw_displacement = ph_bs.eigendisplacements[band][qpoint][v_idx] + v_displacement = [ + round(complex(vec).real * magnitude, precision) + for vec in v_raw_displacement + ] + + # only draw in unit cell + u_to_middle_bond_animation = [] # u to middle + # v_to_middle_bond_animation = [] # v to middle + for displace_coef in [0, 1, 0, -1, 0]: + u_end_displacement = [ + round(displace_coef * magnitude * d, precision) + for d in u_displacement + ] + v_end_displacement = [ + round(displace_coef * magnitude * d, precision) + for d in v_displacement + ] + middle_end_displacement = ( + (np.array(u_end_displacement) + np.array(v_end_displacement)) + / 2 + ).tolist() + middle_end_displacement = [ + round(dis, precision) for dis in middle_end_displacement + ] + + u2middle_animation = [u_end_displacement, middle_end_displacement] + # v2middle_animation = [v_end_displacement, middle_end_displacement] + + u_to_middle_bond_animation.append(u2middle_animation) + # v_to_middle_bond_animation.append(v2middle_animation) + + bond_animation.append(u_to_middle_bond_animation) + json_data["contents"][1]["contents"][content_idx]["animate"] = ( + bond_animation + ) + json_data["contents"][1]["contents"][content_idx]["keyframes"] = [ + 0, + 1, + 2, + 3, + 4, + ] + json_data["contents"][1]["contents"][content_idx]["animateType"] = ( + "displacement" + ) + + # remove polyhedra manually + json_data["contents"][2]["visible"] = False + json_data["contents"][3]["visible"] = False + + return json_data @staticmethod def _get_ph_bs_dos( @@ -303,6 +467,7 @@ def get_ph_bandstructure_traces(bs, freq_range): "line": {"color": "#1f77b4"}, "hoverinfo": "skip", "name": "Total", + "customdata": [[di, band_num] for di in range(len(x_dat))], "hovertemplate": "%{y:.2f} THz", "showlegend": False, "xaxis": "x", @@ -348,6 +513,9 @@ def get_ph_bandstructure_traces(bs, freq_range): def _get_data_list_dict( bs: PhononBandStructureSymmLine, dos: CompletePhononDos ) -> dict[str, str | bool | int]: + if (not bs) and (not dos): + return {} + bs_minpoint, bs_min_freq = bs.min_freq() min_freq_report = ( f"{bs_min_freq:.2f} THz at frac. coords. {bs_minpoint.frac_coords}" @@ -443,14 +611,9 @@ def get_figure( ph_dos: CompletePhononDos | None = None, freq_range: tuple[float | None, float | None] = (None, None), ) -> go.Figure: - if freq_range[0] is None: - freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) - - if freq_range[1] is None: - freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) - if (not ph_dos) and (not ph_bs): empty_plot_style = { + "height": 500, "xaxis": {"visible": False}, "yaxis": {"visible": False}, "paper_bgcolor": "rgba(0,0,0,0)", @@ -459,6 +622,12 @@ def get_figure( return go.Figure(layout=empty_plot_style) + if freq_range[0] is None: + freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) + + if freq_range[1] is None: + freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) + if ph_bs: ( bs_traces, @@ -555,7 +724,7 @@ def get_figure( paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(230,230,230,230)", margin=dict(l=60, b=50, t=50, pad=0, r=30), - # clickmode="event+select" + clickmode="event+select", ) figure = {"data": bs_traces + dos_traces, "layout": layout} @@ -580,124 +749,25 @@ def get_figure( def generate_callbacks(self, app, cache) -> None: @app.callback( Output(self.id("ph-bsdos-graph"), "figure"), - Input(self.id("traces"), "data"), + Output(self.id("zone"), "data"), + Output(self.id("table"), "children"), + Input(self.id("ph_bs"), "data"), + Input(self.id("ph_dos"), "data"), ) - def update_graph(traces): - if traces == "error": - msg_body = MessageBody( - dcc.Markdown( - "Band structure and density of states not available for this selection." - ) - ) - return (MessageContainer([msg_body], kind="warning"),) - - if traces is None: - raise PreventUpdate - - bs, dos = self._get_ph_bs_dos(self.initial_data["default"]) + def update_graph(bs, dos): + if isinstance(bs, dict): + bs = PhononBandStructureSymmLine.from_dict(bs) + if isinstance(dos, dict): + dos = CompletePhononDos.from_dict(dos) figure = self.get_figure(bs, dos) - return dcc.Graph( - figure=figure, config={"displayModeBar": False}, responsive=True - ) - - @app.callback( - Output(self.id("label-select"), "value"), - Output(self.id("label-container"), "style"), - Input(self.id("mpid"), "data"), - Input(self.id("path-convention"), "value"), - ) - def update_label_select(mpid, path_convention): - if not mpid: - raise PreventUpdate - label_value = path_convention - label_style = {"maxWidth": "200"} - - return label_value, label_style - - @app.callback( - Output(self.id("dos-select"), "options"), - Output(self.id("path-convention"), "options"), - Output(self.id("path-container"), "style"), - Input(self.id("elements"), "data"), - Input(self.id("mpid"), "data"), - ) - def update_select(elements, mpid): - if elements is None: - raise PreventUpdate - if not mpid: - dos_options = ( - [{"label": "Element Projected", "value": "ap"}] - + [{"label": "Orbital Projected - Total", "value": "op"}] - + [ - { - "label": "Orbital Projected - " + str(ele_label), - "value": "orb" + str(ele_label), - } - for ele_label in elements - ] - ) - - path_options = [{"label": "N/A", "value": "sc"}] - path_style = {"maxWidth": "200", "display": "none"} - - return dos_options, path_options, path_style - dos_options = ( - [{"label": "Element Projected", "value": "ap"}] - + [{"label": "Orbital Projected - Total", "value": "op"}] - + [ - { - "label": "Orbital Projected - " + str(ele_label), - "value": "orb" + str(ele_label), - } - for ele_label in elements - ] - ) - - path_options = [ - {"label": "Setyawan-Curtarolo", "value": "sc"}, - {"label": "Latimer-Munro", "value": "lm"}, - {"label": "Hinuma et al.", "value": "hin"}, - ] - path_style = {"maxWidth": "200"} + zone_scene = self.get_brillouin_zone_scene(bs) - return dos_options, path_options, path_style + summary_dict = self._get_data_list_dict(bs, dos) + summary_table = get_data_list(summary_dict) - @app.callback( - Output(self.id("traces"), "data"), - Output(self.id("elements"), "data"), - Input(self.id(), "data"), - Input(self.id("path-convention"), "value"), - Input(self.id("dos-select"), "value"), - Input(self.id("label-select"), "value"), - ) - def bs_dos_data(data, dos_select, label_select): - # Obtain bands to plot over and generate traces for bs data: - energy_window = (-6.0, 10.0) - - traces = [] - - bsml, density_of_states = self._get_ph_bs_dos(data) - - if self.bandstructure_symm_line: - bs_traces = self.get_ph_bandstructure_traces( - bsml, freq_range=energy_window - ) - traces.append(bs_traces) - - if self.density_of_states: - dos_traces = self.get_ph_dos_traces( - density_of_states, freq_range=energy_window - ) - traces.append(dos_traces) - - # traces = [bs_traces, dos_traces, bs_data] - - # TODO: not tested if this is correct way to get element list - elements = list(map(str, density_of_states.get_element_dos())) - - return traces, elements + return figure, zone_scene.to_json(), summary_table @app.callback( Output(self.id("brillouin-zone"), "data"), @@ -711,8 +781,42 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select): # TODO: figure out what to return (CSS?) to highlight BZ edge/point return - # TODO: figure out what to return (CSS?) to highlight BZ edge/point - return + @app.callback( + Output(self.id("crystal-animation"), "data"), + Input(self.id("ph-bsdos-graph"), "clickData"), + Input(self.id("ph_bs"), "data"), + # prevent_initial_call=True + ) + def update_crystal_animation(cd, bs): + if not bs: + raise PreventUpdate + + if isinstance(bs, dict): + bs = PhononBandStructureSymmLine.from_dict(bs) + + struc_graph = StructureGraph.from_local_env_strategy( + bs.structure, CrystalNN() + ) + scene = struc_graph.get_scene( + draw_image_atoms=False, + bonded_sites_outside_unit_cell=False, + site_get_scene_kwargs={"retain_atom_idx": True}, + ) + json_data = scene.to_json() + + qpoint = 0 + band_num = 0 + + if cd and cd.get("points"): + pt = cd["points"][0] + qpoint, band_num = pt.get("customdata", [0, 0]) + + return PhononBandstructureAndDosComponent._get_eigendisplacement( + ph_bs=bs, + json_data=json_data, + band=band_num, + qpoint=qpoint, + ) class PhononBandstructureAndDosPanelComponent(PanelComponent): From a372ea62fabbc0f831b2fa6f0ffdce828de6b41b Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Tue, 21 Oct 2025 11:16:55 -0700 Subject: [PATCH 09/24] add .venv to .gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index c738b02b..f98e7d76 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,6 @@ build/* *.egg-info dist/* .vscode + +.venv/ +_version.py From 6ddc66d138fbff0c132214d2ac13498c9ac4a0f8 Mon Sep 17 00:00:00 2001 From: Patrick Huck Date: Wed, 22 Oct 2025 10:45:27 -0700 Subject: [PATCH 10/24] linting --- crystal_toolkit/components/phonon.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index f0808a0a..c8ce895b 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -61,7 +61,7 @@ def __init__( **kwargs, ) - bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos( + bs, _ = PhononBandstructureAndDosComponent._get_ph_bs_dos( self.initial_data["default"] ) self.create_store("bs-store", bs) @@ -108,9 +108,11 @@ def _sub_layouts(self) -> dict[str, Component]: options=options, ) ], - style={"width": "200px"} - if show_path_options - else {"maxWidth": "200", "display": "none"}, + style=( + {"width": "200px"} + if show_path_options + else {"maxWidth": "200", "display": "none"} + ), id=self.id("path-container"), ) @@ -125,9 +127,11 @@ def _sub_layouts(self) -> dict[str, Component]: options=options, ) ], - style={"width": "200px"} - if show_path_options - else {"width": "200px", "display": "none"}, + style=( + {"width": "200px"} + if show_path_options + else {"width": "200px", "display": "none"} + ), id=self.id("label-container"), ) @@ -541,7 +545,7 @@ def _get_data_list_dict( target="blank", ), ] - ): "Yes" if bs.has_nac else "No", + ): ("Yes" if bs.has_nac else "No"), "Has imaginary frequencies": "Yes" if bs.has_imaginary_freq() else "No", "Has eigen-displacements": "Yes" if bs.has_eigendisplacements else "No", "Min frequency": min_freq_report, From be3e06e19fb4f23bd4326041ed461e13f3aa3c8b Mon Sep 17 00:00:00 2001 From: Patrick Huck Date: Wed, 22 Oct 2025 10:45:27 -0700 Subject: [PATCH 11/24] linting --- crystal_toolkit/components/phonon.py | 163 +++++++++------------------ 1 file changed, 53 insertions(+), 110 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index c8ce895b..004e9335 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools +from copy import deepcopy from typing import TYPE_CHECKING, Any import numpy as np @@ -28,9 +29,7 @@ from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine from pymatgen.electronic_structure.dos import CompleteDos -# Author: Jason Munro, Janosh Riebesell -# Contact: jmunro@lbl.gov, janosh@lbl.gov - +DISPLACE_COEF = [0, 1, 0, -1, 0] # TODOs: # - look for additional projection methods in phonon DOS (currently only atom @@ -214,131 +213,75 @@ def _get_eigendisplacement( precision: int = 15, magnitude: int = 15, ) -> dict: - if not ph_bs: + if not ph_bs or not json_data: return {} - # get displacement - min_bond_length = float("inf") - for content_idx in range(len(json_data["contents"][1]["contents"])): - for pair_idx in range( - len(json_data["contents"][1]["contents"][content_idx]["_meta"]) - ): - u, v = json_data["contents"][1]["contents"][content_idx][ - "positionPairs" - ][pair_idx] - # Convert to numpy arrays - u = np.array(u) - v = np.array(v) - length = np.linalg.norm(v - u) - min_bond_length = min(min_bond_length, length) - - # atom animate assert json_data["contents"][0]["name"] == "atoms" - for content_idx in range(len(json_data["contents"][0]["contents"])): - atom_idx = json_data["contents"][0]["contents"][content_idx]["_meta"][0] - - raw_displacement = ph_bs.eigendisplacements[band][qpoint][atom_idx] + assert json_data["contents"][1]["name"] == "bonds" + rdata = deepcopy(json_data) - displacement = [complex(vec).real * magnitude for vec in raw_displacement] + def calc_displacement(idx: int) -> list: + return [ + round(complex(vec).real * magnitude, precision) + for vec in ph_bs.eigendisplacements[band][qpoint][idx] + ] - position_animation = [] - for displace_coef in [0, 1, 0, -1, 0]: - displace = [ - round(displace_coef * magnitude * d, precision) - for d in displacement - ] - position_animation.append(displace) + def calc_animation_step(displacement: list, coef: int) -> list: + return [round(coef * magnitude * d, precision) for d in displacement] - json_data["contents"][0]["contents"][content_idx]["animate"] = ( - position_animation - ) - json_data["contents"][0]["contents"][content_idx]["keyframes"] = [ - 0, - 1, - 2, - 3, - 4, + # atom animate + contents0 = json_data["contents"][0]["contents"] + for cidx, content in enumerate(contents0): + displacement = calc_displacement(content["_meta"][0]) + rcontent = rdata["contents"][0]["contents"][cidx] + rcontent["animate"] = [ + calc_animation_step(displacement, coef) for coef in DISPLACE_COEF ] - json_data["contents"][0]["contents"][content_idx]["animateType"] = ( - "displacement" - ) + rcontent["keyframes"] = list(range(5)) + rcontent["animateType"] = "displacement" - # bond animate - assert json_data["contents"][1]["name"] == "bonds" - for content_idx in range(len(json_data["contents"][1]["contents"])): + # get displacement and bond animate + min_bond_length = float("inf") + contents1 = json_data["contents"][1]["contents"] + for cidx, content in enumerate(contents1): bond_animation = [] + assert len(content["_meta"]) == len(content["positionPairs"]) - assert len( - json_data["contents"][1]["contents"][content_idx]["_meta"] - ) == len(json_data["contents"][1]["contents"][content_idx]["positionPairs"]) - - for pair_idx in range( - len(json_data["contents"][1]["contents"][content_idx]["_meta"]) - ): - u_idx, v_idx = json_data["contents"][1]["contents"][content_idx][ - "_meta" - ][pair_idx] - - # u - u_raw_displacement = ph_bs.eigendisplacements[band][qpoint][u_idx] - u_displacement = [ - round(complex(vec).real * magnitude, precision) - for vec in u_raw_displacement - ] - - # v - v_raw_displacement = ph_bs.eigendisplacements[band][qpoint][v_idx] - v_displacement = [ - round(complex(vec).real * magnitude, precision) - for vec in v_raw_displacement - ] + for pair in enumerate(content["_meta"]): + u, v = rdata["contents"][1]["contents"][cidx]["positionPairs"] = list( + map(np.array, pair) + ) + length = np.linalg.norm(v - u) + min_bond_length = min(min_bond_length, length) + displacements = list(map(calc_displacement, pair)) + u_to_middle_bond_animation = [] - # only draw in unit cell - u_to_middle_bond_animation = [] # u to middle - # v_to_middle_bond_animation = [] # v to middle - for displace_coef in [0, 1, 0, -1, 0]: - u_end_displacement = [ - round(displace_coef * magnitude * d, precision) - for d in u_displacement - ] - v_end_displacement = [ - round(displace_coef * magnitude * d, precision) - for d in v_displacement - ] + for coef in DISPLACE_COEF: middle_end_displacement = ( - (np.array(u_end_displacement) + np.array(v_end_displacement)) + np.add( + np.array(calc_animation_step(displacement, coef)) + for displacement in displacements + ) / 2 - ).tolist() - middle_end_displacement = [ - round(dis, precision) for dis in middle_end_displacement - ] - - u2middle_animation = [u_end_displacement, middle_end_displacement] - # v2middle_animation = [v_end_displacement, middle_end_displacement] - - u_to_middle_bond_animation.append(u2middle_animation) - # v_to_middle_bond_animation.append(v2middle_animation) + ) + u_to_middle_bond_animation.append( + [ + displacements[0], + [round(dis, precision) for dis in middle_end_displacement], + ] + ) bond_animation.append(u_to_middle_bond_animation) - json_data["contents"][1]["contents"][content_idx]["animate"] = ( - bond_animation - ) - json_data["contents"][1]["contents"][content_idx]["keyframes"] = [ - 0, - 1, - 2, - 3, - 4, - ] - json_data["contents"][1]["contents"][content_idx]["animateType"] = ( - "displacement" - ) + + rdata["contents"][1]["contents"][cidx]["animate"] = bond_animation + rdata["contents"][1]["contents"][cidx]["keyframes"] = list(range(5)) + rdata["contents"][1]["contents"][cidx]["animateType"] = "displacement" # remove polyhedra manually - json_data["contents"][2]["visible"] = False - json_data["contents"][3]["visible"] = False + for i in range(2, 4): + rdata["contents"][i]["visible"] = False - return json_data + return rdata @staticmethod def _get_ph_bs_dos( From 19573d0713cc1a4716620a3a59a5c25022911de6 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Thu, 23 Oct 2025 12:56:13 -0700 Subject: [PATCH 12/24] fix irregular bonding --- crystal_toolkit/components/phonon.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 6c2fcfea..e7ddc5ce 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -283,6 +283,7 @@ def calc_animation_step(max_displacement: list, coef: int) -> list: # [mid_x_displacement, mid_y_displacement, mid_z_displacement] # ]. contents1 = json_data["contents"][1]["contents"] + for cidx, content in enumerate(contents1): bond_animation = [] assert len(content["_meta"]) == len(content["positionPairs"]) @@ -294,26 +295,17 @@ def calc_animation_step(max_displacement: list, coef: int) -> list: u_to_middle_bond_animation = [] - for frame_idx, coef in enumerate(DISPLACE_COEF): + for coef in DISPLACE_COEF: # Calculate the midpoint displacement between atom u and v for each animation frame. - middle_end_displacement = ( - np.add( - *( - [ - np.array( - calc_animation_step(max_displacement, coef) - ) - for max_displacement in max_displacements - ] - ) - ) - / 2 - ) + u_displacement, v_displacement = [ + np.array(calc_animation_step(max_displacement, coef)) + for max_displacement in max_displacements + ] + middle_end_displacement = np.add(u_displacement, v_displacement) / 2 + u_to_middle_bond_animation.append( [ - rdata["contents"][0]["contents"][atom_idx_pair[0]][ - "animate" - ][frame_idx], # u atom displacement + u_displacement, # u atom displacement [ round(dis, precision) for dis in middle_end_displacement ], # middle point displacement From c706707c33b2875d98cac3a093345b7a59526516 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Fri, 24 Oct 2025 13:16:11 -0700 Subject: [PATCH 13/24] add supercell construction --- crystal_toolkit/components/phonon.py | 191 ++++++++++++++++++++++++--- 1 file changed, 170 insertions(+), 21 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index e7ddc5ce..962764e0 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -18,6 +18,7 @@ from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos from pymatgen.phonon.plotter import PhononBSPlotter +from pymatgen.transformations.standard_transformations import SupercellTransformation from crystal_toolkit.core.mpcomponent import MPComponent from crystal_toolkit.core.panelcomponent import PanelComponent @@ -30,6 +31,11 @@ from pymatgen.electronic_structure.dos import CompleteDos DISPLACE_COEF = [0, 1, 0, -1, 0] +MARKER_COLOR = "red" +MARKER_SIZE = 12 +MARKER_SHAPE = "x" +MAX_MAGNITUDE = 400 +MIN_MAGNITUDE = 0 # TODOs: # - look for additional projection methods in phonon DOS (currently only atom @@ -149,22 +155,70 @@ def _sub_layouts(self) -> dict[str, Component]: # crystal visualization - tip = html.P( - "Click different q-points and bands in the dispersion diagram to see the crystal vibration.", - id=self.id("crystal-tip"), - style={ - "margin": "0 0 12px", - "fontSize": "16px", - "color": "#555", - "textAlign": "center", - }, + tip = html.H5( + "💡 Tips: Click different q-points and bands in the dispersion diagram to see the crystal vibration!", + ) + + crystal_animation = html.Div( + CrystalToolkitAnimationScene( + data={}, + sceneSize="500px", + id=self.id("crystal-animation"), + settings={"defaultZoom": 1.2}, + ), + style={"width": "60%"}, ) - crystal_animation = CrystalToolkitAnimationScene( - data={}, - sceneSize="200px", - id=self.id("crystal-animation"), - settings={"defaultZoom": 1.5}, + crystal_animation_controls = html.Div( + [ + html.Br(), + html.Div(tip, style={"textAlign": "center"}), + html.Br(), + html.H5("Control Panel", style={"textAlign": "center"}), + html.H6("Supercell modification"), + html.Br(), + html.Div( + [ + self.get_numerical_input( + kwarg_label="scale-x", + default=1, + is_int=True, + label="x", + style={"width": "5rem"}, + ), + self.get_numerical_input( + kwarg_label="scale-y", + default=1, + is_int=True, + label="y", + style={"width": "5rem"}, + ), + self.get_numerical_input( + kwarg_label="scale-z", + default=1, + is_int=True, + label="z", + style={"width": "5rem"}, + ), + html.Button( + "Update", + id=self.id("controls-btn"), + style={"height": "40px"}, + ), + ], + style={"display": "flex"}, + ), + html.Br(), + html.Div( + self.get_slider_input( + kwarg_label="magnitude", + default=0.5, + step=0.01, + domain=[0, 1], + label="Vibration magnitude", + ) + ), + ], ) return { @@ -176,12 +230,26 @@ def _sub_layouts(self) -> dict[str, Component]: "table": summary_table, "crystal-animation": crystal_animation, "tip": tip, + "crystal-animation-controls": crystal_animation_controls, } def layout(self) -> html.Div: sub_layouts = self._sub_layouts crystal_animation = Columns( - [Column([sub_layouts["tip"], sub_layouts["crystal-animation"]])] + [ + Column( + [ + # sub_layouts["tip"], + Columns( + [ + sub_layouts["crystal-animation"], + sub_layouts["crystal-animation-controls"], + ] + ) + ] + ), + # Column([sub_layouts["crystal-animation-controls"]]) + ] ) graph = Columns([Column([sub_layouts["graph"]])]) controls = Columns( @@ -212,6 +280,7 @@ def _get_eigendisplacement( qpoint: int = 0, precision: int = 15, magnitude: int = 225, + total_repeat_cell_cnt: int | None = None, ) -> dict: if not ph_bs or not json_data: return {} @@ -233,9 +302,16 @@ def calc_max_displacement(idx: int) -> list: This function extracts the real component of the atom's eigendisplacement, scales it by the specified magnitude, and returns the resulting vector. """ + + # get the atom index + assert total_repeat_cell_cnt != 0 + modified_idx = ( + (idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx + ) + return [ round(complex(vec).real * magnitude, precision) - for vec in ph_bs.eigendisplacements[band][qpoint][idx] + for vec in ph_bs.eigendisplacements[band][qpoint][modified_idx] ] def calc_animation_step(max_displacement: list, coef: int) -> list: @@ -717,7 +793,27 @@ def get_figure( clickmode="event+select", ) - figure = {"data": bs_traces + dos_traces, "layout": layout} + default_red_dot = [ + { + "type": "scatter", + "mode": "markers", + "x": [0], + "y": [0], + "marker": { + "color": MARKER_COLOR, + "size": MARKER_SIZE, + "symbol": MARKER_SHAPE, + }, + "name": "click-marker", + "showlegend": False, + "customdata": [[0, 0]], + "hovertemplate": ( + "band: %{customdata[1]}
q-point: %{customdata[0]}
" + ), + } + ] + + figure = {"data": bs_traces + dos_traces + default_red_dot, "layout": layout} legend = dict( x=1.02, @@ -743,14 +839,46 @@ def generate_callbacks(self, app, cache) -> None: Output(self.id("table"), "children"), Input(self.id("ph_bs"), "data"), Input(self.id("ph_dos"), "data"), + Input(self.id("ph-bsdos-graph"), "clickData"), ) - def update_graph(bs, dos): + def update_graph(bs, dos, nclick): if isinstance(bs, dict): bs = PhononBandStructureSymmLine.from_dict(bs) if isinstance(dos, dict): dos = CompletePhononDos.from_dict(dos) figure = self.get_figure(bs, dos) + if nclick and nclick.get("points"): + # remove marker if there is one + figure["data"] = [ + t for t in figure["data"] if t.get("name") != "click-marker" + ] + + x_click = nclick["points"][0]["x"] + y_click = nclick["points"][0]["y"] + + pt = nclick["points"][0] + qpoint, band_num = pt.get("customdata", [0, 0]) + + figure["data"].append( + { + "type": "scatter", + "mode": "markers", + "x": [x_click], + "y": [y_click], + "marker": { + "color": MARKER_COLOR, + "size": MARKER_SIZE, + "symbol": MARKER_SHAPE, + }, + "name": "click-marker", + "showlegend": False, + "customdata": [[qpoint, band_num]], + "hovertemplate": ( + "band: %{customdata[1]}
q-point: %{customdata[0]}
" + ), + } + ) zone_scene = self.get_brillouin_zone_scene(bs) @@ -775,18 +903,37 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select): Output(self.id("crystal-animation"), "data"), Input(self.id("ph-bsdos-graph"), "clickData"), Input(self.id("ph_bs"), "data"), + Input(self.id("controls-btn"), "n_clicks"), + Input(self.get_all_kwargs_id(), "value"), # prevent_initial_call=True ) - def update_crystal_animation(cd, bs): + def update_crystal_animation(cd, bs, update, kwargs): if not bs: raise PreventUpdate if isinstance(bs, dict): bs = PhononBandStructureSymmLine.from_dict(bs) - struc_graph = StructureGraph.from_local_env_strategy( - bs.structure, CrystalNN() + kwargs = self.reconstruct_kwargs_from_state() + + # animation control + scale_x, scale_y, scale_z = ( + int(kwargs["scale-x"]), + int(kwargs["scale-y"]), + int(kwargs["scale-z"]), + ) + magnitude_fraction = kwargs["magnitude"] + magnitude = ( + MAX_MAGNITUDE - MIN_MAGNITUDE + ) * magnitude_fraction + MIN_MAGNITUDE + + # create supercell + trans = SupercellTransformation( + ((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z)) ) + struct = trans.apply_transformation(bs.structure) + + struc_graph = StructureGraph.from_local_env_strategy(struct, CrystalNN()) scene = struc_graph.get_scene( draw_image_atoms=False, bonded_sites_outside_unit_cell=False, @@ -806,6 +953,8 @@ def update_crystal_animation(cd, bs): json_data=json_data, band=band_num, qpoint=qpoint, + total_repeat_cell_cnt=scale_x * scale_y * scale_z, + magnitude=magnitude, ) From 933ddc351985c4c231c28ce534fbef67a1f36c8b Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Fri, 24 Oct 2025 16:47:44 -0700 Subject: [PATCH 14/24] merged with main and add supercell construction --- crystal_toolkit/components/phonon.py | 191 ++++++++++++++++++++++++--- 1 file changed, 170 insertions(+), 21 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index e7ddc5ce..962764e0 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -18,6 +18,7 @@ from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos from pymatgen.phonon.plotter import PhononBSPlotter +from pymatgen.transformations.standard_transformations import SupercellTransformation from crystal_toolkit.core.mpcomponent import MPComponent from crystal_toolkit.core.panelcomponent import PanelComponent @@ -30,6 +31,11 @@ from pymatgen.electronic_structure.dos import CompleteDos DISPLACE_COEF = [0, 1, 0, -1, 0] +MARKER_COLOR = "red" +MARKER_SIZE = 12 +MARKER_SHAPE = "x" +MAX_MAGNITUDE = 400 +MIN_MAGNITUDE = 0 # TODOs: # - look for additional projection methods in phonon DOS (currently only atom @@ -149,22 +155,70 @@ def _sub_layouts(self) -> dict[str, Component]: # crystal visualization - tip = html.P( - "Click different q-points and bands in the dispersion diagram to see the crystal vibration.", - id=self.id("crystal-tip"), - style={ - "margin": "0 0 12px", - "fontSize": "16px", - "color": "#555", - "textAlign": "center", - }, + tip = html.H5( + "💡 Tips: Click different q-points and bands in the dispersion diagram to see the crystal vibration!", + ) + + crystal_animation = html.Div( + CrystalToolkitAnimationScene( + data={}, + sceneSize="500px", + id=self.id("crystal-animation"), + settings={"defaultZoom": 1.2}, + ), + style={"width": "60%"}, ) - crystal_animation = CrystalToolkitAnimationScene( - data={}, - sceneSize="200px", - id=self.id("crystal-animation"), - settings={"defaultZoom": 1.5}, + crystal_animation_controls = html.Div( + [ + html.Br(), + html.Div(tip, style={"textAlign": "center"}), + html.Br(), + html.H5("Control Panel", style={"textAlign": "center"}), + html.H6("Supercell modification"), + html.Br(), + html.Div( + [ + self.get_numerical_input( + kwarg_label="scale-x", + default=1, + is_int=True, + label="x", + style={"width": "5rem"}, + ), + self.get_numerical_input( + kwarg_label="scale-y", + default=1, + is_int=True, + label="y", + style={"width": "5rem"}, + ), + self.get_numerical_input( + kwarg_label="scale-z", + default=1, + is_int=True, + label="z", + style={"width": "5rem"}, + ), + html.Button( + "Update", + id=self.id("controls-btn"), + style={"height": "40px"}, + ), + ], + style={"display": "flex"}, + ), + html.Br(), + html.Div( + self.get_slider_input( + kwarg_label="magnitude", + default=0.5, + step=0.01, + domain=[0, 1], + label="Vibration magnitude", + ) + ), + ], ) return { @@ -176,12 +230,26 @@ def _sub_layouts(self) -> dict[str, Component]: "table": summary_table, "crystal-animation": crystal_animation, "tip": tip, + "crystal-animation-controls": crystal_animation_controls, } def layout(self) -> html.Div: sub_layouts = self._sub_layouts crystal_animation = Columns( - [Column([sub_layouts["tip"], sub_layouts["crystal-animation"]])] + [ + Column( + [ + # sub_layouts["tip"], + Columns( + [ + sub_layouts["crystal-animation"], + sub_layouts["crystal-animation-controls"], + ] + ) + ] + ), + # Column([sub_layouts["crystal-animation-controls"]]) + ] ) graph = Columns([Column([sub_layouts["graph"]])]) controls = Columns( @@ -212,6 +280,7 @@ def _get_eigendisplacement( qpoint: int = 0, precision: int = 15, magnitude: int = 225, + total_repeat_cell_cnt: int | None = None, ) -> dict: if not ph_bs or not json_data: return {} @@ -233,9 +302,16 @@ def calc_max_displacement(idx: int) -> list: This function extracts the real component of the atom's eigendisplacement, scales it by the specified magnitude, and returns the resulting vector. """ + + # get the atom index + assert total_repeat_cell_cnt != 0 + modified_idx = ( + (idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx + ) + return [ round(complex(vec).real * magnitude, precision) - for vec in ph_bs.eigendisplacements[band][qpoint][idx] + for vec in ph_bs.eigendisplacements[band][qpoint][modified_idx] ] def calc_animation_step(max_displacement: list, coef: int) -> list: @@ -717,7 +793,27 @@ def get_figure( clickmode="event+select", ) - figure = {"data": bs_traces + dos_traces, "layout": layout} + default_red_dot = [ + { + "type": "scatter", + "mode": "markers", + "x": [0], + "y": [0], + "marker": { + "color": MARKER_COLOR, + "size": MARKER_SIZE, + "symbol": MARKER_SHAPE, + }, + "name": "click-marker", + "showlegend": False, + "customdata": [[0, 0]], + "hovertemplate": ( + "band: %{customdata[1]}
q-point: %{customdata[0]}
" + ), + } + ] + + figure = {"data": bs_traces + dos_traces + default_red_dot, "layout": layout} legend = dict( x=1.02, @@ -743,14 +839,46 @@ def generate_callbacks(self, app, cache) -> None: Output(self.id("table"), "children"), Input(self.id("ph_bs"), "data"), Input(self.id("ph_dos"), "data"), + Input(self.id("ph-bsdos-graph"), "clickData"), ) - def update_graph(bs, dos): + def update_graph(bs, dos, nclick): if isinstance(bs, dict): bs = PhononBandStructureSymmLine.from_dict(bs) if isinstance(dos, dict): dos = CompletePhononDos.from_dict(dos) figure = self.get_figure(bs, dos) + if nclick and nclick.get("points"): + # remove marker if there is one + figure["data"] = [ + t for t in figure["data"] if t.get("name") != "click-marker" + ] + + x_click = nclick["points"][0]["x"] + y_click = nclick["points"][0]["y"] + + pt = nclick["points"][0] + qpoint, band_num = pt.get("customdata", [0, 0]) + + figure["data"].append( + { + "type": "scatter", + "mode": "markers", + "x": [x_click], + "y": [y_click], + "marker": { + "color": MARKER_COLOR, + "size": MARKER_SIZE, + "symbol": MARKER_SHAPE, + }, + "name": "click-marker", + "showlegend": False, + "customdata": [[qpoint, band_num]], + "hovertemplate": ( + "band: %{customdata[1]}
q-point: %{customdata[0]}
" + ), + } + ) zone_scene = self.get_brillouin_zone_scene(bs) @@ -775,18 +903,37 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select): Output(self.id("crystal-animation"), "data"), Input(self.id("ph-bsdos-graph"), "clickData"), Input(self.id("ph_bs"), "data"), + Input(self.id("controls-btn"), "n_clicks"), + Input(self.get_all_kwargs_id(), "value"), # prevent_initial_call=True ) - def update_crystal_animation(cd, bs): + def update_crystal_animation(cd, bs, update, kwargs): if not bs: raise PreventUpdate if isinstance(bs, dict): bs = PhononBandStructureSymmLine.from_dict(bs) - struc_graph = StructureGraph.from_local_env_strategy( - bs.structure, CrystalNN() + kwargs = self.reconstruct_kwargs_from_state() + + # animation control + scale_x, scale_y, scale_z = ( + int(kwargs["scale-x"]), + int(kwargs["scale-y"]), + int(kwargs["scale-z"]), + ) + magnitude_fraction = kwargs["magnitude"] + magnitude = ( + MAX_MAGNITUDE - MIN_MAGNITUDE + ) * magnitude_fraction + MIN_MAGNITUDE + + # create supercell + trans = SupercellTransformation( + ((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z)) ) + struct = trans.apply_transformation(bs.structure) + + struc_graph = StructureGraph.from_local_env_strategy(struct, CrystalNN()) scene = struc_graph.get_scene( draw_image_atoms=False, bonded_sites_outside_unit_cell=False, @@ -806,6 +953,8 @@ def update_crystal_animation(cd, bs): json_data=json_data, band=band_num, qpoint=qpoint, + total_repeat_cell_cnt=scale_x * scale_y * scale_z, + magnitude=magnitude, ) From 6942117d1d0e16ae68679fd3a92ffbc9d4fa333a Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Fri, 24 Oct 2025 17:01:34 -0700 Subject: [PATCH 15/24] add constraints on positive supercell construction --- crystal_toolkit/components/phonon.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 962764e0..3a594a15 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -184,6 +184,7 @@ def _sub_layouts(self) -> dict[str, Component]: default=1, is_int=True, label="x", + min=1, style={"width": "5rem"}, ), self.get_numerical_input( @@ -191,6 +192,7 @@ def _sub_layouts(self) -> dict[str, Component]: default=1, is_int=True, label="y", + min=1, style={"width": "5rem"}, ), self.get_numerical_input( @@ -198,6 +200,7 @@ def _sub_layouts(self) -> dict[str, Component]: default=1, is_int=True, label="z", + min=1, style={"width": "5rem"}, ), html.Button( From 14ca45487f67fd165ba153d049a2a8def00ebed6 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Fri, 24 Oct 2025 17:46:58 -0700 Subject: [PATCH 16/24] remove default_red_dot with DRY principle --- crystal_toolkit/components/phonon.py | 78 +++++++++++----------------- 1 file changed, 29 insertions(+), 49 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 3a594a15..2e316316 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -796,27 +796,7 @@ def get_figure( clickmode="event+select", ) - default_red_dot = [ - { - "type": "scatter", - "mode": "markers", - "x": [0], - "y": [0], - "marker": { - "color": MARKER_COLOR, - "size": MARKER_SIZE, - "symbol": MARKER_SHAPE, - }, - "name": "click-marker", - "showlegend": False, - "customdata": [[0, 0]], - "hovertemplate": ( - "band: %{customdata[1]}
q-point: %{customdata[0]}
" - ), - } - ] - - figure = {"data": bs_traces + dos_traces + default_red_dot, "layout": layout} + figure = {"data": bs_traces + dos_traces, "layout": layout} legend = dict( x=1.02, @@ -851,37 +831,37 @@ def update_graph(bs, dos, nclick): dos = CompletePhononDos.from_dict(dos) figure = self.get_figure(bs, dos) - if nclick and nclick.get("points"): - # remove marker if there is one - figure["data"] = [ - t for t in figure["data"] if t.get("name") != "click-marker" - ] - x_click = nclick["points"][0]["x"] - y_click = nclick["points"][0]["y"] + # remove marker if there is one + figure["data"] = [ + t for t in figure["data"] if t.get("name") != "click-marker" + ] - pt = nclick["points"][0] - qpoint, band_num = pt.get("customdata", [0, 0]) + x_click = nclick["points"][0]["x"] if nclick else 0 + y_click = nclick["points"][0]["y"] if nclick else 0 + pt = nclick["points"][0] if nclick else {} - figure["data"].append( - { - "type": "scatter", - "mode": "markers", - "x": [x_click], - "y": [y_click], - "marker": { - "color": MARKER_COLOR, - "size": MARKER_SIZE, - "symbol": MARKER_SHAPE, - }, - "name": "click-marker", - "showlegend": False, - "customdata": [[qpoint, band_num]], - "hovertemplate": ( - "band: %{customdata[1]}
q-point: %{customdata[0]}
" - ), - } - ) + qpoint, band_num = pt.get("customdata", [0, 0]) + + figure["data"].append( + { + "type": "scatter", + "mode": "markers", + "x": [x_click], + "y": [y_click], + "marker": { + "color": MARKER_COLOR, + "size": MARKER_SIZE, + "symbol": MARKER_SHAPE, + }, + "name": "click-marker", + "showlegend": False, + "customdata": [[qpoint, band_num]], + "hovertemplate": ( + "band: %{customdata[1]}
q-point: %{customdata[0]}
" + ), + } + ) zone_scene = self.get_brillouin_zone_scene(bs) From ad618d220280efc13e34cdf4942e0c842cc86428 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Mon, 27 Oct 2025 13:59:49 -0700 Subject: [PATCH 17/24] add PhononBandstructureAndDosComponent_v2 --- crystal_toolkit/components/__init__.py | 1 + crystal_toolkit/components/phonon.py | 866 ++++++++++++++++++++++++- 2 files changed, 840 insertions(+), 27 deletions(-) diff --git a/crystal_toolkit/components/__init__.py b/crystal_toolkit/components/__init__.py index 1b2cc6a6..d905e2aa 100644 --- a/crystal_toolkit/components/__init__.py +++ b/crystal_toolkit/components/__init__.py @@ -14,6 +14,7 @@ ) from crystal_toolkit.components.phonon import ( PhononBandstructureAndDosComponent, + PhononBandstructureAndDosComponent_v2, PhononBandstructureAndDosPanelComponent, ) from crystal_toolkit.components.pourbaix import PourbaixDiagramComponent diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 2e316316..6231ceaf 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -7,7 +7,7 @@ import numpy as np import plotly.graph_objects as go from dash import dcc, html -from dash.dependencies import Component, Input, Output +from dash.dependencies import Component, Input, Output, State from dash.exceptions import PreventUpdate from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene @@ -34,7 +34,7 @@ MARKER_COLOR = "red" MARKER_SIZE = 12 MARKER_SHAPE = "x" -MAX_MAGNITUDE = 400 +MAX_MAGNITUDE = 300 MIN_MAGNITUDE = 0 # TODOs: @@ -73,6 +73,778 @@ def __init__( self.create_store("bs", None) self.create_store("dos", None) + @property + def _sub_layouts(self) -> dict[str, Component]: + # defaults + state = {"label-select": "sc", "dos-select": "ap"} + + fig = PhononBandstructureAndDosComponent.get_figure(None, None) + # Main plot + graph = dcc.Graph( + figure=fig, + config={"displayModeBar": False}, + responsive=False, + id=self.id("ph-bsdos-graph"), + ) + + # Brillouin zone + zone_scene = self.get_brillouin_zone_scene(None) + zone = CrystalToolkitScene( + data=zone_scene.to_json(), sceneSize="500px", id=self.id("zone") + ) + + # Hide by default if not loaded by mpid, switching between k-paths + # on-the-fly only supported for bandstructures retrieved from MP + show_path_options = bool(self.initial_data["default"]["mpid"]) + + options = [ + {"label": "Latimer-Munro", "value": "lm"}, + {"label": "Hinuma et al.", "value": "hin"}, + {"label": "Setyawan-Curtarolo", "value": "sc"}, + ] + # Convention selection for band structure + convention = html.Div( + [ + self.get_choice_input( + kwarg_label="path-convention", + state=state, + label="Path convention", + help_str="Convention to choose path in k-space", + options=options, + ) + ], + style=( + {"width": "200px"} + if show_path_options + else {"maxWidth": "200", "display": "none"} + ), + id=self.id("path-container"), + ) + + # Equivalent labels across band structure conventions + label_select = html.Div( + [ + self.get_choice_input( + kwarg_label="label-select", + state=state, + label="Label convention", + help_str="Convention to choose labels for path in k-space", + options=options, + ) + ], + style=( + {"width": "200px"} + if show_path_options + else {"width": "200px", "display": "none"} + ), + id=self.id("label-container"), + ) + + # Density of states data selection + dos_select = self.get_choice_input( + kwarg_label="dos-select", + state=state, + label="Projection", + help_str="Choose projection", + options=[{"label": "Atom Projected", "value": "ap"}], + style={"width": "200px"}, + ) + + summary_dict = self._get_data_list_dict(None, None) + summary_table = get_data_list(summary_dict) + + # crystal visualization + + tip = html.P( + "Click different q-points and bands in the dispersion diagram to see the crystal vibration.", + id=self.id("crystal-tip"), + style={ + "margin": "0 0 12px", + "fontSize": "16px", + "color": "#555", + "textAlign": "center", + }, + ) + + crystal_animation = CrystalToolkitAnimationScene( + data={}, + sceneSize="200px", + id=self.id("crystal-animation"), + settings={"defaultZoom": 1.5}, + axisView="SW", + showControls=False, # disable download for now + ) + + return { + "graph": graph, + "convention": convention, + "dos-select": dos_select, + "label-select": label_select, + "zone": zone, + "table": summary_table, + "crystal-animation": crystal_animation, + "tip": tip, + } + + def layout(self) -> html.Div: + sub_layouts = self._sub_layouts + crystal_animation = Columns( + [Column([sub_layouts["tip"], sub_layouts["crystal-animation"]])] + ) + graph = Columns([Column([sub_layouts["graph"]])]) + controls = Columns( + [ + Column( + [ + sub_layouts["convention"], + sub_layouts["label-select"], + sub_layouts["dos-select"], + ] + ) + ] + ) + brillouin_zone = Columns( + [ + Column([Label("Summary"), sub_layouts["table"]], id=self.id("table")), + Column([Label("Brillouin Zone"), sub_layouts["zone"]]), + ] + ) + + return html.Div([graph, crystal_animation, controls, brillouin_zone]) + + @staticmethod + def _get_eigendisplacement( + ph_bs: BandStructureSymmLine, + json_data: dict, + band: int = 0, + qpoint: int = 0, + precision: int = 15, + magnitude: int = MAX_MAGNITUDE / 2, + ) -> dict: + if not ph_bs or not json_data: + return {} + + assert json_data["contents"][0]["name"] == "atoms" + assert json_data["contents"][1]["name"] == "bonds" + rdata = deepcopy(json_data) + + def calc_max_displacement(idx: int) -> list: + """ + Retrieve the eigendisplacement for a given atom index from `ph_bs` and compute its maximum displacement. + + Parameters: + idx (int): The atom index. + + Returns: + list: The maximum displacement vector in the form [x_max_displacement, y_max_displacement, z_max_displacement] + + This function extracts the real component of the atom's eigendisplacement, + scales it by the specified magnitude, and returns the resulting vector. + """ + return [ + round(complex(vec).real * magnitude, precision) + for vec in ph_bs.eigendisplacements[band][qpoint][idx] + ] + + def calc_animation_step(max_displacement: list, coef: int) -> list: + """ + Calculate the displacement for an animation frame based on the given coefficient. + + Parameters: + max_displacement (list): A list of maximum displacements along each axis, + formatted as [x_max_displacement, y_max_displacement, z_max_displacement]. + coef (int): A coefficient indicating the motion direction. + - 0: no movement + - 1: forward movement + - -1: backward movement + + Returns: + list: The displacement vector [x_displacement, y_displacement, z_displacement]. + + This function generates oscillatory motion by scaling the maximum displacement + with the provided coefficient. + """ + return [round(coef * md, precision) for md in max_displacement] + + # Compute per-frame atomic motion. + # `rcontent["animate"]` stores the displacement (distance difference) from the previous coordinates. + contents0 = json_data["contents"][0]["contents"] + for cidx, content in enumerate(contents0): + max_displacement = calc_max_displacement(content["_meta"][0]) + rcontent = rdata["contents"][0]["contents"][cidx] + # put animation frame to the given atom index + rcontent["animate"] = [ + calc_animation_step(max_displacement, coef) for coef in DISPLACE_COEF + ] + rcontent["keyframes"] = list(range(len(DISPLACE_COEF))) + rcontent["animateType"] = "displacement" + # Compute per-frame bonding motion. + # Explanation: + # Each bond connects two atoms, `u` and `v`, represented as (u)----(v) + # To model the bond motion, it is divided into two segments: + # from `u` to the midpoint and from the midpoint to `v`, i.e., (u)--(mid)--(v) + # Thus, two cylinders are created: one for (u)--(mid) and another for (v)--(mid). + # For each cylinder, displacements are assigned to the endpoints — for example, + # the (u)--(mid) cylinder uses: + # [ + # [u_x_displacement, u_y_displacement, u_z_displacement], + # [mid_x_displacement, mid_y_displacement, mid_z_displacement] + # ]. + contents1 = json_data["contents"][1]["contents"] + + for cidx, content in enumerate(contents1): + bond_animation = [] + assert len(content["_meta"]) == len(content["positionPairs"]) + + for atom_idx_pair in content["_meta"]: + max_displacements = list( + map(calc_max_displacement, atom_idx_pair) + ) # max displacement for u and v + + u_to_middle_bond_animation = [] + + for coef in DISPLACE_COEF: + # Calculate the midpoint displacement between atom u and v for each animation frame. + u_displacement, v_displacement = [ + np.array(calc_animation_step(max_displacement, coef)) + for max_displacement in max_displacements + ] + middle_end_displacement = np.add(u_displacement, v_displacement) / 2 + + u_to_middle_bond_animation.append( + [ + u_displacement, # u atom displacement + [ + round(dis, precision) for dis in middle_end_displacement + ], # middle point displacement + ] + ) + + bond_animation.append(u_to_middle_bond_animation) + + rdata["contents"][1]["contents"][cidx]["animate"] = bond_animation + rdata["contents"][1]["contents"][cidx]["keyframes"] = list( + range(len(DISPLACE_COEF)) + ) + rdata["contents"][1]["contents"][cidx]["animateType"] = "displacement" + + # remove unused sense + for i in range(2, 4): + rdata["contents"][i]["visible"] = False + + return rdata + + @staticmethod + def _get_ph_bs_dos( + data: dict[str, Any] | None, + ) -> tuple[PhononBandStructureSymmLine, CompletePhononDos]: + data = data or {} + + # this component can be loaded either from mpid or + # directly from BandStructureSymmLine or CompleteDos objects + # if mpid is supplied, it takes precedence + + mpid = data.get("mpid") + bandstructure_symm_line = data.get("bandstructure_symm_line") + density_of_states = data.get("density_of_states") + + if not mpid and (bandstructure_symm_line is None or density_of_states is None): + return None, None + + if mpid: + with MPRester() as mpr: + try: + bandstructure_symm_line = ( + mpr.get_phonon_bandstructure_by_material_id(mpid) + ) + except Exception as exc: + print(exc) + bandstructure_symm_line = None + + try: + density_of_states = mpr.get_phonon_dos_by_material_id(mpid) + except Exception as exc: + print(exc) + density_of_states = None + + else: + if bandstructure_symm_line and isinstance(bandstructure_symm_line, dict): + bandstructure_symm_line = PhononBandStructureSymmLine.from_dict( + bandstructure_symm_line + ) + + if density_of_states and isinstance(density_of_states, dict): + density_of_states = CompletePhononDos.from_dict(density_of_states) + + return bandstructure_symm_line, density_of_states + + @staticmethod + def get_brillouin_zone_scene(bs: PhononBandStructureSymmLine) -> Scene: + if not bs: + return Scene(name="brillouin_zone", contents=[]) + + # TODO: from BSPlotter, merge back into BSPlotter + # Brillouin zone + bz_lattice = bs.structure.lattice.reciprocal_lattice + bz = bz_lattice.get_wigner_seitz_cell() + lines = [] + for iface in range(len(bz)): # pylint: disable=C0200 + for line in itertools.combinations(bz[iface], 2): + for jface in range(len(bz)): + if ( + iface < jface + and any(np.all(line[0] == x) for x in bz[jface]) + and any(np.all(line[1] == x) for x in bz[jface]) + ): + lines += [list(line[0]), list(line[1])] + + zone_lines = Lines(positions=lines) + zone_surface = Convex(positions=lines, opacity=0.05, color="#000000") + + labels = {} + for qpt in bs.qpoints: + if qpt.label: + label = qpt.label + for orig, new in pretty_labels.items(): + label = label.replace(orig, new) + labels[label] = bz_lattice.get_cartesian_coords(qpt.frac_coords) + label_list = [ + Spheres(positions=[coords], tooltip=label, radius=0.03, color="#5EB1BF") + for label, coords in labels.items() + ] + + path = [] + cylinder_pairs = [] + for b in bs.branches: + start = bz_lattice.get_cartesian_coords( + bs.qpoints[b["start_index"]].frac_coords + ) + end = bz_lattice.get_cartesian_coords( + bs.qpoints[b["end_index"]].frac_coords + ) + path += [start, end] + cylinder_pairs += [[start, end]] + # path_lines = Lines(positions=path, color="#ff4b5c") + path_lines = Cylinders( + positionPairs=cylinder_pairs, color="#5EB1BF", radius=0.01 + ) + ibz_region = Convex(positions=path, opacity=0.2, color="#5EB1BF") + + contents = [zone_lines, zone_surface, path_lines, ibz_region, *label_list] + + return Scene(name="brillouin_zone", contents=contents) + + @staticmethod + def get_ph_bandstructure_traces(bs, freq_range): + bs_reg_plot = PhononBSPlotter(bs) + + bs_data = bs_reg_plot.bs_plot_data() + + bands = [] + for band_num in range(bs.nb_bands): + for segment in bs_data["frequency"]: + if any(v <= freq_range[1] for v in segment[band_num]) and any( + v >= freq_range[0] for v in segment[band_num] + ): + bands.append(band_num) # noqa: PERF401 + + bs_traces = [] + + for d, dist_val in enumerate(bs_data["distances"]): + x_dat = dist_val + + traces_for_segment = [] + + segment = bs_data["frequency"][d] + + traces_for_segment += [ + { + "x": x_dat, + "y": segment[band_num], + "mode": "lines", + "line": {"color": "#1f77b4"}, + "hoverinfo": "skip", + "name": "Total", + "customdata": [[di, band_num] for di in range(len(x_dat))], + "hovertemplate": "%{y:.2f} THz", + "showlegend": False, + "xaxis": "x", + "yaxis": "y", + } + for band_num in bands + ] + + bs_traces += traces_for_segment + + for entry_num in range(len(bs_data["ticks"]["label"])): + for key in pretty_labels: + if key in bs_data["ticks"]["label"][entry_num]: + bs_data["ticks"]["label"][entry_num] = bs_data["ticks"]["label"][ + entry_num + ].replace(key, pretty_labels[key]) + + # Vertical lines for disjointed segments + for dist_val, tick_label in zip( + bs_data["ticks"]["distance"], bs_data["ticks"]["label"] + ): + vert_trace = [ + { + "x": [dist_val, dist_val], + "y": freq_range, + "mode": "lines", + "marker": { + "color": "#F5F5F5" if "|" not in tick_label else "white" + }, + "line": {"width": 0.5 if "|" not in tick_label else 2}, + "hoverinfo": "skip", + "showlegend": False, + "xaxis": "x", + "yaxis": "y", + } + ] + + bs_traces += vert_trace + + return bs_traces, bs_data + + @staticmethod + def _get_data_list_dict( + bs: PhononBandStructureSymmLine, dos: CompletePhononDos + ) -> dict[str, str | bool | int]: + if (not bs) and (not dos): + return {} + + bs_minpoint, bs_min_freq = bs.min_freq() + min_freq_report = ( + f"{bs_min_freq:.2f} THz at frac. coords. {bs_minpoint.frac_coords}" + ) + if bs_minpoint.label is not None: + label = f" ({bs_minpoint.label})" + for orig, new in pretty_labels.items(): + label = label.replace(orig, new) + min_freq_report += label + + f" at q-point=${bs_minpoint.label}$ (frac. coords. = {bs_minpoint.frac_coords})" + + summary_dict: dict[str, str | bool | int] = { + "Number of bands": f"{bs.nb_bands:,}", + "Number of q-points": f"{bs.nb_qpoints:,}", + # for NAC see https://phonopy.github.io/phonopy/formulation.html#non-analytical-term-correction + Label( + [ + "Has ", + html.A( + "NAC", + href="https://phonopy.github.io/phonopy/formulation.html#non-analytical-term-correction", + target="blank", + ), + ] + ): ("Yes" if bs.has_nac else "No"), + "Has imaginary frequencies": "Yes" if bs.has_imaginary_freq() else "No", + "Has eigen-displacements": "Yes" if bs.has_eigendisplacements else "No", + "Min frequency": min_freq_report, + "max frequency": f"{max(dos.frequencies):.2f} THz", + } + + return summary_dict + + @staticmethod + def get_ph_dos_traces(dos: CompletePhononDos, freq_range: tuple[float, float]): + dos_traces = [] + + dos_max = np.abs(dos.frequencies - freq_range[1]).argmin() + dos_min = np.abs(dos.frequencies - freq_range[0]).argmin() + + tdos_label = "Total DOS" + + # Total DOS + trace_tdos = { + "x": dos.densities[dos_min:dos_max], + "y": dos.frequencies[dos_min:dos_max], + "mode": "lines", + "name": tdos_label, + "line": go.scatter.Line(color="#444444"), + "fill": "tozerox", + "fillcolor": "#C4C4C4", + "legendgroup": "spinup", + "xaxis": "x2", + "yaxis": "y2", + } + + dos_traces.append(trace_tdos) + + # Projected DOS + if isinstance(dos, CompletePhononDos): + colors = [ + "#d62728", # brick red + "#2ca02c", # cooked asparagus green + "#17becf", # blue-teal + "#bcbd22", # curry yellow-green + "#9467bd", # muted purple + "#8c564b", # chestnut brown + "#e377c2", # raspberry yogurt pink + ] + + ele_dos = dos.get_element_dos() # project DOS onto elements + for count, label in enumerate(ele_dos): + spin_up_label = str(label) + + trace = { + "x": ele_dos[label].densities[dos_min:dos_max], + "y": dos.frequencies[dos_min:dos_max], + "mode": "lines", + "name": spin_up_label, + "line": dict(width=2, color=colors[count]), + "xaxis": "x2", + "yaxis": "y2", + } + + dos_traces.append(trace) + + return dos_traces + + @staticmethod + def get_figure( + ph_bs: PhononBandStructureSymmLine | None = None, + ph_dos: CompletePhononDos | None = None, + freq_range: tuple[float | None, float | None] = (None, None), + ) -> go.Figure: + if (not ph_dos) and (not ph_bs): + empty_plot_style = { + "height": 500, + "xaxis": {"visible": False}, + "yaxis": {"visible": False}, + "paper_bgcolor": "rgba(0,0,0,0)", + "plot_bgcolor": "rgba(0,0,0,0)", + } + + return go.Figure(layout=empty_plot_style) + + if freq_range[0] is None: + freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) + + if freq_range[1] is None: + freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) + + if ph_bs: + ( + bs_traces, + bs_data, + ) = PhononBandstructureAndDosComponent.get_ph_bandstructure_traces( + ph_bs, freq_range=freq_range + ) + + if ph_dos: + dos_traces = PhononBandstructureAndDosComponent.get_ph_dos_traces( + ph_dos, freq_range=freq_range + ) + + # TODO: add logic to handle if bs_traces and/or dos_traces not present + + rmax_list = [ + max(dos_traces[0]["x"]), + abs(min(dos_traces[0]["x"])), + ] + if len(dos_traces) > 1 and "x" in dos_traces[1] and dos_traces[1]["x"].any(): + rmax_list += [ + max(dos_traces[1]["x"]), + abs(min(dos_traces[1]["x"])), + ] + + rmax = max(rmax_list) + + # -- Add trace data to plots + + in_common_axis_styles = dict( + gridcolor="white", + linecolor="rgb(71,71,71)", + linewidth=2, + showgrid=False, + showline=True, + tickfont=dict(size=16), + ticks="inside", + tickwidth=2, + ) + + xaxis_style = dict( + **in_common_axis_styles, + tickmode="array", + mirror=True, + range=[0, bs_data["ticks"]["distance"][-1]], + ticktext=bs_data["ticks"]["label"], + tickvals=bs_data["ticks"]["distance"], + title=dict(text="Wave Vector", font=dict(size=16)), + zeroline=False, + ) + + yaxis_style = dict( + **in_common_axis_styles, + mirror="ticks", + range=freq_range, + title=dict(text="Frequency (THz)", font=dict(size=16)), + zeroline=True, + zerolinecolor="white", + zerolinewidth=2, + ) + + xaxis_style_dos = dict( + **in_common_axis_styles, + title=dict(text="Density of States", font=dict(size=16)), + zeroline=False, + mirror=True, + range=[0, rmax * 1.1], + zerolinecolor="white", + zerolinewidth=2, + ) + + yaxis_style_dos = dict( + **in_common_axis_styles, + zeroline=True, + showticklabels=False, + mirror="ticks", + zerolinewidth=2, + range=freq_range, + zerolinecolor="white", + matches="y", + anchor="x2", + ) + + layout = dict( + title="", + xaxis1=xaxis_style, + xaxis2=xaxis_style_dos, + yaxis=yaxis_style, + yaxis2=yaxis_style_dos, + showlegend=True, + height=500, + width=1000, + hovermode="closest", + paper_bgcolor="rgba(0,0,0,0)", + plot_bgcolor="rgba(230,230,230,230)", + margin=dict(l=60, b=50, t=50, pad=0, r=30), + clickmode="event+select", + ) + + figure = {"data": bs_traces + dos_traces, "layout": layout} + + legend = dict( + x=1.02, + y=1.005, + xanchor="left", + yanchor="top", + bordercolor="#333", + borderwidth=2, + traceorder="normal", + ) + + figure["layout"]["legend"] = legend + + figure["layout"]["xaxis1"]["domain"] = [0.0, 0.7] + figure["layout"]["xaxis2"]["domain"] = [0.73, 1.0] + + return figure + + def generate_callbacks(self, app, cache) -> None: + @app.callback( + Output(self.id("ph-bsdos-graph"), "figure"), + Output(self.id("zone"), "data"), + Output(self.id("table"), "children"), + Input(self.id("ph_bs"), "data"), + Input(self.id("ph_dos"), "data"), + ) + def update_graph(bs, dos): + if isinstance(bs, dict): + bs = PhononBandStructureSymmLine.from_dict(bs) + if isinstance(dos, dict): + dos = CompletePhononDos.from_dict(dos) + + figure = self.get_figure(bs, dos) + + zone_scene = self.get_brillouin_zone_scene(bs) + + summary_dict = self._get_data_list_dict(bs, dos) + summary_table = get_data_list(summary_dict) + + return figure, zone_scene.to_json(), summary_table + + @app.callback( + Output(self.id("brillouin-zone"), "data"), + Input(self.id("ph-bsdos-graph"), "hoverData"), + Input(self.id("ph-bsdos-graph"), "clickData"), + ) + def highlight_bz_on_hover_bs(hover_data, click_data, label_select): + """Highlight the corresponding point/edge of the Brillouin Zone when hovering the band + structure plot. + """ + # TODO: figure out what to return (CSS?) to highlight BZ edge/point + return + + @app.callback( + Output(self.id("crystal-animation"), "data"), + Input(self.id("ph-bsdos-graph"), "clickData"), + Input(self.id("ph_bs"), "data"), + # prevent_initial_call=True + ) + def update_crystal_animation(cd, bs): + if not bs: + raise PreventUpdate + + if isinstance(bs, dict): + bs = PhononBandStructureSymmLine.from_dict(bs) + + struc_graph = StructureGraph.from_local_env_strategy( + bs.structure, CrystalNN() + ) + scene = struc_graph.get_scene( + draw_image_atoms=False, + bonded_sites_outside_unit_cell=False, + site_get_scene_kwargs={"retain_atom_idx": True}, + ) + json_data = scene.to_json() + + qpoint = 0 + band_num = 0 + + if cd and cd.get("points"): + pt = cd["points"][0] + qpoint, band_num = pt.get("customdata", [0, 0]) + + return PhononBandstructureAndDosComponent._get_eigendisplacement( + ph_bs=bs, + json_data=json_data, + band=band_num, + qpoint=qpoint, + ) + + +class PhononBandstructureAndDosComponent_v2(MPComponent): + def __init__( + self, + mpid: str | None = None, + bandstructure_symm_line: BandStructureSymmLine | None = None, + density_of_states: CompleteDos | None = None, + id: str | None = None, + **kwargs, + ) -> None: + # this is a compound component, can be fed by mpid or + # by the BandStructure itself + super().__init__( + id=id, + default_data={ + "mpid": mpid, + "bandstructure_symm_line": bandstructure_symm_line, + "density_of_states": density_of_states, + }, + **kwargs, + ) + + bs, _ = PhononBandstructureAndDosComponent._get_ph_bs_dos( + self.initial_data["default"] + ) + self.create_store("bs-store", bs) + self.create_store("bs", None) + self.create_store("dos", None) + @property def _sub_layouts(self) -> dict[str, Component]: # defaults @@ -165,6 +937,8 @@ def _sub_layouts(self) -> dict[str, Component]: sceneSize="500px", id=self.id("crystal-animation"), settings={"defaultZoom": 1.2}, + axisView="SW", + showControls=False, # disable download for now ), style={"width": "60%"}, ) @@ -205,7 +979,7 @@ def _sub_layouts(self) -> dict[str, Component]: ), html.Button( "Update", - id=self.id("controls-btn"), + id=self.id("supercell-controls-btn"), style={"height": "40px"}, ), ], @@ -282,8 +1056,8 @@ def _get_eigendisplacement( band: int = 0, qpoint: int = 0, precision: int = 15, - magnitude: int = 225, - total_repeat_cell_cnt: int | None = None, + magnitude: int = MAX_MAGNITUDE / 2, + total_repeat_cell_cnt: int = 1, ) -> dict: if not ph_bs or not json_data: return {} @@ -886,35 +1660,68 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select): Output(self.id("crystal-animation"), "data"), Input(self.id("ph-bsdos-graph"), "clickData"), Input(self.id("ph_bs"), "data"), - Input(self.id("controls-btn"), "n_clicks"), - Input(self.get_all_kwargs_id(), "value"), + Input(self.id("supercell-controls-btn"), "n_clicks"), + Input( + { + "component_id": "CTPhononSection_phonon_bs", + "kwarg_label": "magnitude", + "idx": "False", + "hint": "slider", + }, + "value", + ), + State( + { + "component_id": "CTPhononSection_phonon_bs", + "kwarg_label": "scale-x", + "idx": "()", + "hint": "()", + }, + "value", + ), + State( + { + "component_id": "CTPhononSection_phonon_bs", + "kwarg_label": "scale-y", + "idx": "()", + "hint": "()", + }, + "value", + ), + State( + { + "component_id": "CTPhononSection_phonon_bs", + "kwarg_label": "scale-z", + "idx": "()", + "hint": "()", + }, + "value", + ), # prevent_initial_call=True ) - def update_crystal_animation(cd, bs, update, kwargs): + def update_crystal_animation( + cd, bs, sueprcell_update, magnitude_fraction, scale_x, scale_y, scale_z + ): + # Avoids using `get_all_kwargs_id` for all `Input`; instead, uses `State` to prevent flickering when users modify `scale_x`, `scale_y`, or `scale_z` fields, + # ensuring updates occur only after the `supercell-controls-btn`` is clicked. + if not bs: raise PreventUpdate if isinstance(bs, dict): bs = PhononBandStructureSymmLine.from_dict(bs) - kwargs = self.reconstruct_kwargs_from_state() - - # animation control - scale_x, scale_y, scale_z = ( - int(kwargs["scale-x"]), - int(kwargs["scale-y"]), - int(kwargs["scale-z"]), - ) - magnitude_fraction = kwargs["magnitude"] - magnitude = ( - MAX_MAGNITUDE - MIN_MAGNITUDE - ) * magnitude_fraction + MIN_MAGNITUDE + struct = bs.structure + total_repeat_cell_cnt = 1 + # update structure if the controls got triggered + if sueprcell_update: + total_repeat_cell_cnt = scale_x * scale_y * scale_z - # create supercell - trans = SupercellTransformation( - ((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z)) - ) - struct = trans.apply_transformation(bs.structure) + # create supercell + trans = SupercellTransformation( + ((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z)) + ) + struct = trans.apply_transformation(struct) struc_graph = StructureGraph.from_local_env_strategy(struct, CrystalNN()) scene = struc_graph.get_scene( @@ -931,12 +1738,17 @@ def update_crystal_animation(cd, bs, update, kwargs): pt = cd["points"][0] qpoint, band_num = pt.get("customdata", [0, 0]) - return PhononBandstructureAndDosComponent._get_eigendisplacement( + # magnitude + magnitude = ( + MAX_MAGNITUDE - MIN_MAGNITUDE + ) * magnitude_fraction + MIN_MAGNITUDE + + return PhononBandstructureAndDosComponent_v2._get_eigendisplacement( ph_bs=bs, json_data=json_data, band=band_num, qpoint=qpoint, - total_repeat_cell_cnt=scale_x * scale_y * scale_z, + total_repeat_cell_cnt=total_repeat_cell_cnt, magnitude=magnitude, ) From afde7883e43389f276b802cf37704deab1478104 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Mon, 27 Oct 2025 16:27:19 -0700 Subject: [PATCH 18/24] update input and make it more clear --- crystal_toolkit/components/phonon.py | 51 ++++++++-------------------- 1 file changed, 15 insertions(+), 36 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 6231ceaf..dde0fb2f 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -1661,42 +1661,10 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select): Input(self.id("ph-bsdos-graph"), "clickData"), Input(self.id("ph_bs"), "data"), Input(self.id("supercell-controls-btn"), "n_clicks"), - Input( - { - "component_id": "CTPhononSection_phonon_bs", - "kwarg_label": "magnitude", - "idx": "False", - "hint": "slider", - }, - "value", - ), - State( - { - "component_id": "CTPhononSection_phonon_bs", - "kwarg_label": "scale-x", - "idx": "()", - "hint": "()", - }, - "value", - ), - State( - { - "component_id": "CTPhononSection_phonon_bs", - "kwarg_label": "scale-y", - "idx": "()", - "hint": "()", - }, - "value", - ), - State( - { - "component_id": "CTPhononSection_phonon_bs", - "kwarg_label": "scale-z", - "idx": "()", - "hint": "()", - }, - "value", - ), + Input(self.get_kwarg_id("magnitude"), "value"), + State(self.get_kwarg_id("scale-x"), "value"), + State(self.get_kwarg_id("scale-y"), "value"), + State(self.get_kwarg_id("scale-z"), "value"), # prevent_initial_call=True ) def update_crystal_animation( @@ -1708,6 +1676,15 @@ def update_crystal_animation( if not bs: raise PreventUpdate + # Since `self.get_kwarg_id()` uses dash.dependencies.ALL, it returns a list of values. + # Although we could use `magnitude_fraction = magnitude_fraction[0]` to get the first value, + # this approach provides better clarity and readability. + kwargs = self.reconstruct_kwargs_from_state() + magnitude_fraction = kwargs.get("magnitude") + scale_x = kwargs.get("scale-x") + scale_y = kwargs.get("scale-y") + scale_z = kwargs.get("scale-z") + if isinstance(bs, dict): bs = PhononBandStructureSymmLine.from_dict(bs) @@ -1738,6 +1715,8 @@ def update_crystal_animation( pt = cd["points"][0] qpoint, band_num = pt.get("customdata", [0, 0]) + print(scale_x, scale_y, scale_z) + print(magnitude_fraction) # magnitude magnitude = ( MAX_MAGNITUDE - MIN_MAGNITUDE From e10b27d3be76524c24c6f3cb6483d89284287021 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Mon, 27 Oct 2025 16:50:55 -0700 Subject: [PATCH 19/24] remove print --- crystal_toolkit/components/phonon.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index dde0fb2f..ac3e54f7 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -1715,8 +1715,6 @@ def update_crystal_animation( pt = cd["points"][0] qpoint, band_num = pt.get("customdata", [0, 0]) - print(scale_x, scale_y, scale_z) - print(magnitude_fraction) # magnitude magnitude = ( MAX_MAGNITUDE - MIN_MAGNITUDE From e6e3d8dc1262653645da31d1c19bdf0f077da8cd Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Tue, 28 Oct 2025 11:06:43 -0700 Subject: [PATCH 20/24] make sure index is integer --- crystal_toolkit/components/phonon.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index ac3e54f7..25c00614 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -1082,8 +1082,9 @@ def calc_max_displacement(idx: int) -> list: # get the atom index assert total_repeat_cell_cnt != 0 + modified_idx = ( - (idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx + int(idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx ) return [ From a8133c78631b05161fc007513bc73ee962b29fd0 Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Wed, 5 Nov 2025 12:30:49 -0800 Subject: [PATCH 21/24] add a dedicated function for access control in web --- crystal_toolkit/components/phonon.py | 784 +-------------------------- 1 file changed, 7 insertions(+), 777 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 25c00614..5ef76d90 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -73,778 +73,6 @@ def __init__( self.create_store("bs", None) self.create_store("dos", None) - @property - def _sub_layouts(self) -> dict[str, Component]: - # defaults - state = {"label-select": "sc", "dos-select": "ap"} - - fig = PhononBandstructureAndDosComponent.get_figure(None, None) - # Main plot - graph = dcc.Graph( - figure=fig, - config={"displayModeBar": False}, - responsive=False, - id=self.id("ph-bsdos-graph"), - ) - - # Brillouin zone - zone_scene = self.get_brillouin_zone_scene(None) - zone = CrystalToolkitScene( - data=zone_scene.to_json(), sceneSize="500px", id=self.id("zone") - ) - - # Hide by default if not loaded by mpid, switching between k-paths - # on-the-fly only supported for bandstructures retrieved from MP - show_path_options = bool(self.initial_data["default"]["mpid"]) - - options = [ - {"label": "Latimer-Munro", "value": "lm"}, - {"label": "Hinuma et al.", "value": "hin"}, - {"label": "Setyawan-Curtarolo", "value": "sc"}, - ] - # Convention selection for band structure - convention = html.Div( - [ - self.get_choice_input( - kwarg_label="path-convention", - state=state, - label="Path convention", - help_str="Convention to choose path in k-space", - options=options, - ) - ], - style=( - {"width": "200px"} - if show_path_options - else {"maxWidth": "200", "display": "none"} - ), - id=self.id("path-container"), - ) - - # Equivalent labels across band structure conventions - label_select = html.Div( - [ - self.get_choice_input( - kwarg_label="label-select", - state=state, - label="Label convention", - help_str="Convention to choose labels for path in k-space", - options=options, - ) - ], - style=( - {"width": "200px"} - if show_path_options - else {"width": "200px", "display": "none"} - ), - id=self.id("label-container"), - ) - - # Density of states data selection - dos_select = self.get_choice_input( - kwarg_label="dos-select", - state=state, - label="Projection", - help_str="Choose projection", - options=[{"label": "Atom Projected", "value": "ap"}], - style={"width": "200px"}, - ) - - summary_dict = self._get_data_list_dict(None, None) - summary_table = get_data_list(summary_dict) - - # crystal visualization - - tip = html.P( - "Click different q-points and bands in the dispersion diagram to see the crystal vibration.", - id=self.id("crystal-tip"), - style={ - "margin": "0 0 12px", - "fontSize": "16px", - "color": "#555", - "textAlign": "center", - }, - ) - - crystal_animation = CrystalToolkitAnimationScene( - data={}, - sceneSize="200px", - id=self.id("crystal-animation"), - settings={"defaultZoom": 1.5}, - axisView="SW", - showControls=False, # disable download for now - ) - - return { - "graph": graph, - "convention": convention, - "dos-select": dos_select, - "label-select": label_select, - "zone": zone, - "table": summary_table, - "crystal-animation": crystal_animation, - "tip": tip, - } - - def layout(self) -> html.Div: - sub_layouts = self._sub_layouts - crystal_animation = Columns( - [Column([sub_layouts["tip"], sub_layouts["crystal-animation"]])] - ) - graph = Columns([Column([sub_layouts["graph"]])]) - controls = Columns( - [ - Column( - [ - sub_layouts["convention"], - sub_layouts["label-select"], - sub_layouts["dos-select"], - ] - ) - ] - ) - brillouin_zone = Columns( - [ - Column([Label("Summary"), sub_layouts["table"]], id=self.id("table")), - Column([Label("Brillouin Zone"), sub_layouts["zone"]]), - ] - ) - - return html.Div([graph, crystal_animation, controls, brillouin_zone]) - - @staticmethod - def _get_eigendisplacement( - ph_bs: BandStructureSymmLine, - json_data: dict, - band: int = 0, - qpoint: int = 0, - precision: int = 15, - magnitude: int = MAX_MAGNITUDE / 2, - ) -> dict: - if not ph_bs or not json_data: - return {} - - assert json_data["contents"][0]["name"] == "atoms" - assert json_data["contents"][1]["name"] == "bonds" - rdata = deepcopy(json_data) - - def calc_max_displacement(idx: int) -> list: - """ - Retrieve the eigendisplacement for a given atom index from `ph_bs` and compute its maximum displacement. - - Parameters: - idx (int): The atom index. - - Returns: - list: The maximum displacement vector in the form [x_max_displacement, y_max_displacement, z_max_displacement] - - This function extracts the real component of the atom's eigendisplacement, - scales it by the specified magnitude, and returns the resulting vector. - """ - return [ - round(complex(vec).real * magnitude, precision) - for vec in ph_bs.eigendisplacements[band][qpoint][idx] - ] - - def calc_animation_step(max_displacement: list, coef: int) -> list: - """ - Calculate the displacement for an animation frame based on the given coefficient. - - Parameters: - max_displacement (list): A list of maximum displacements along each axis, - formatted as [x_max_displacement, y_max_displacement, z_max_displacement]. - coef (int): A coefficient indicating the motion direction. - - 0: no movement - - 1: forward movement - - -1: backward movement - - Returns: - list: The displacement vector [x_displacement, y_displacement, z_displacement]. - - This function generates oscillatory motion by scaling the maximum displacement - with the provided coefficient. - """ - return [round(coef * md, precision) for md in max_displacement] - - # Compute per-frame atomic motion. - # `rcontent["animate"]` stores the displacement (distance difference) from the previous coordinates. - contents0 = json_data["contents"][0]["contents"] - for cidx, content in enumerate(contents0): - max_displacement = calc_max_displacement(content["_meta"][0]) - rcontent = rdata["contents"][0]["contents"][cidx] - # put animation frame to the given atom index - rcontent["animate"] = [ - calc_animation_step(max_displacement, coef) for coef in DISPLACE_COEF - ] - rcontent["keyframes"] = list(range(len(DISPLACE_COEF))) - rcontent["animateType"] = "displacement" - # Compute per-frame bonding motion. - # Explanation: - # Each bond connects two atoms, `u` and `v`, represented as (u)----(v) - # To model the bond motion, it is divided into two segments: - # from `u` to the midpoint and from the midpoint to `v`, i.e., (u)--(mid)--(v) - # Thus, two cylinders are created: one for (u)--(mid) and another for (v)--(mid). - # For each cylinder, displacements are assigned to the endpoints — for example, - # the (u)--(mid) cylinder uses: - # [ - # [u_x_displacement, u_y_displacement, u_z_displacement], - # [mid_x_displacement, mid_y_displacement, mid_z_displacement] - # ]. - contents1 = json_data["contents"][1]["contents"] - - for cidx, content in enumerate(contents1): - bond_animation = [] - assert len(content["_meta"]) == len(content["positionPairs"]) - - for atom_idx_pair in content["_meta"]: - max_displacements = list( - map(calc_max_displacement, atom_idx_pair) - ) # max displacement for u and v - - u_to_middle_bond_animation = [] - - for coef in DISPLACE_COEF: - # Calculate the midpoint displacement between atom u and v for each animation frame. - u_displacement, v_displacement = [ - np.array(calc_animation_step(max_displacement, coef)) - for max_displacement in max_displacements - ] - middle_end_displacement = np.add(u_displacement, v_displacement) / 2 - - u_to_middle_bond_animation.append( - [ - u_displacement, # u atom displacement - [ - round(dis, precision) for dis in middle_end_displacement - ], # middle point displacement - ] - ) - - bond_animation.append(u_to_middle_bond_animation) - - rdata["contents"][1]["contents"][cidx]["animate"] = bond_animation - rdata["contents"][1]["contents"][cidx]["keyframes"] = list( - range(len(DISPLACE_COEF)) - ) - rdata["contents"][1]["contents"][cidx]["animateType"] = "displacement" - - # remove unused sense - for i in range(2, 4): - rdata["contents"][i]["visible"] = False - - return rdata - - @staticmethod - def _get_ph_bs_dos( - data: dict[str, Any] | None, - ) -> tuple[PhononBandStructureSymmLine, CompletePhononDos]: - data = data or {} - - # this component can be loaded either from mpid or - # directly from BandStructureSymmLine or CompleteDos objects - # if mpid is supplied, it takes precedence - - mpid = data.get("mpid") - bandstructure_symm_line = data.get("bandstructure_symm_line") - density_of_states = data.get("density_of_states") - - if not mpid and (bandstructure_symm_line is None or density_of_states is None): - return None, None - - if mpid: - with MPRester() as mpr: - try: - bandstructure_symm_line = ( - mpr.get_phonon_bandstructure_by_material_id(mpid) - ) - except Exception as exc: - print(exc) - bandstructure_symm_line = None - - try: - density_of_states = mpr.get_phonon_dos_by_material_id(mpid) - except Exception as exc: - print(exc) - density_of_states = None - - else: - if bandstructure_symm_line and isinstance(bandstructure_symm_line, dict): - bandstructure_symm_line = PhononBandStructureSymmLine.from_dict( - bandstructure_symm_line - ) - - if density_of_states and isinstance(density_of_states, dict): - density_of_states = CompletePhononDos.from_dict(density_of_states) - - return bandstructure_symm_line, density_of_states - - @staticmethod - def get_brillouin_zone_scene(bs: PhononBandStructureSymmLine) -> Scene: - if not bs: - return Scene(name="brillouin_zone", contents=[]) - - # TODO: from BSPlotter, merge back into BSPlotter - # Brillouin zone - bz_lattice = bs.structure.lattice.reciprocal_lattice - bz = bz_lattice.get_wigner_seitz_cell() - lines = [] - for iface in range(len(bz)): # pylint: disable=C0200 - for line in itertools.combinations(bz[iface], 2): - for jface in range(len(bz)): - if ( - iface < jface - and any(np.all(line[0] == x) for x in bz[jface]) - and any(np.all(line[1] == x) for x in bz[jface]) - ): - lines += [list(line[0]), list(line[1])] - - zone_lines = Lines(positions=lines) - zone_surface = Convex(positions=lines, opacity=0.05, color="#000000") - - labels = {} - for qpt in bs.qpoints: - if qpt.label: - label = qpt.label - for orig, new in pretty_labels.items(): - label = label.replace(orig, new) - labels[label] = bz_lattice.get_cartesian_coords(qpt.frac_coords) - label_list = [ - Spheres(positions=[coords], tooltip=label, radius=0.03, color="#5EB1BF") - for label, coords in labels.items() - ] - - path = [] - cylinder_pairs = [] - for b in bs.branches: - start = bz_lattice.get_cartesian_coords( - bs.qpoints[b["start_index"]].frac_coords - ) - end = bz_lattice.get_cartesian_coords( - bs.qpoints[b["end_index"]].frac_coords - ) - path += [start, end] - cylinder_pairs += [[start, end]] - # path_lines = Lines(positions=path, color="#ff4b5c") - path_lines = Cylinders( - positionPairs=cylinder_pairs, color="#5EB1BF", radius=0.01 - ) - ibz_region = Convex(positions=path, opacity=0.2, color="#5EB1BF") - - contents = [zone_lines, zone_surface, path_lines, ibz_region, *label_list] - - return Scene(name="brillouin_zone", contents=contents) - - @staticmethod - def get_ph_bandstructure_traces(bs, freq_range): - bs_reg_plot = PhononBSPlotter(bs) - - bs_data = bs_reg_plot.bs_plot_data() - - bands = [] - for band_num in range(bs.nb_bands): - for segment in bs_data["frequency"]: - if any(v <= freq_range[1] for v in segment[band_num]) and any( - v >= freq_range[0] for v in segment[band_num] - ): - bands.append(band_num) # noqa: PERF401 - - bs_traces = [] - - for d, dist_val in enumerate(bs_data["distances"]): - x_dat = dist_val - - traces_for_segment = [] - - segment = bs_data["frequency"][d] - - traces_for_segment += [ - { - "x": x_dat, - "y": segment[band_num], - "mode": "lines", - "line": {"color": "#1f77b4"}, - "hoverinfo": "skip", - "name": "Total", - "customdata": [[di, band_num] for di in range(len(x_dat))], - "hovertemplate": "%{y:.2f} THz", - "showlegend": False, - "xaxis": "x", - "yaxis": "y", - } - for band_num in bands - ] - - bs_traces += traces_for_segment - - for entry_num in range(len(bs_data["ticks"]["label"])): - for key in pretty_labels: - if key in bs_data["ticks"]["label"][entry_num]: - bs_data["ticks"]["label"][entry_num] = bs_data["ticks"]["label"][ - entry_num - ].replace(key, pretty_labels[key]) - - # Vertical lines for disjointed segments - for dist_val, tick_label in zip( - bs_data["ticks"]["distance"], bs_data["ticks"]["label"] - ): - vert_trace = [ - { - "x": [dist_val, dist_val], - "y": freq_range, - "mode": "lines", - "marker": { - "color": "#F5F5F5" if "|" not in tick_label else "white" - }, - "line": {"width": 0.5 if "|" not in tick_label else 2}, - "hoverinfo": "skip", - "showlegend": False, - "xaxis": "x", - "yaxis": "y", - } - ] - - bs_traces += vert_trace - - return bs_traces, bs_data - - @staticmethod - def _get_data_list_dict( - bs: PhononBandStructureSymmLine, dos: CompletePhononDos - ) -> dict[str, str | bool | int]: - if (not bs) and (not dos): - return {} - - bs_minpoint, bs_min_freq = bs.min_freq() - min_freq_report = ( - f"{bs_min_freq:.2f} THz at frac. coords. {bs_minpoint.frac_coords}" - ) - if bs_minpoint.label is not None: - label = f" ({bs_minpoint.label})" - for orig, new in pretty_labels.items(): - label = label.replace(orig, new) - min_freq_report += label - - f" at q-point=${bs_minpoint.label}$ (frac. coords. = {bs_minpoint.frac_coords})" - - summary_dict: dict[str, str | bool | int] = { - "Number of bands": f"{bs.nb_bands:,}", - "Number of q-points": f"{bs.nb_qpoints:,}", - # for NAC see https://phonopy.github.io/phonopy/formulation.html#non-analytical-term-correction - Label( - [ - "Has ", - html.A( - "NAC", - href="https://phonopy.github.io/phonopy/formulation.html#non-analytical-term-correction", - target="blank", - ), - ] - ): ("Yes" if bs.has_nac else "No"), - "Has imaginary frequencies": "Yes" if bs.has_imaginary_freq() else "No", - "Has eigen-displacements": "Yes" if bs.has_eigendisplacements else "No", - "Min frequency": min_freq_report, - "max frequency": f"{max(dos.frequencies):.2f} THz", - } - - return summary_dict - - @staticmethod - def get_ph_dos_traces(dos: CompletePhononDos, freq_range: tuple[float, float]): - dos_traces = [] - - dos_max = np.abs(dos.frequencies - freq_range[1]).argmin() - dos_min = np.abs(dos.frequencies - freq_range[0]).argmin() - - tdos_label = "Total DOS" - - # Total DOS - trace_tdos = { - "x": dos.densities[dos_min:dos_max], - "y": dos.frequencies[dos_min:dos_max], - "mode": "lines", - "name": tdos_label, - "line": go.scatter.Line(color="#444444"), - "fill": "tozerox", - "fillcolor": "#C4C4C4", - "legendgroup": "spinup", - "xaxis": "x2", - "yaxis": "y2", - } - - dos_traces.append(trace_tdos) - - # Projected DOS - if isinstance(dos, CompletePhononDos): - colors = [ - "#d62728", # brick red - "#2ca02c", # cooked asparagus green - "#17becf", # blue-teal - "#bcbd22", # curry yellow-green - "#9467bd", # muted purple - "#8c564b", # chestnut brown - "#e377c2", # raspberry yogurt pink - ] - - ele_dos = dos.get_element_dos() # project DOS onto elements - for count, label in enumerate(ele_dos): - spin_up_label = str(label) - - trace = { - "x": ele_dos[label].densities[dos_min:dos_max], - "y": dos.frequencies[dos_min:dos_max], - "mode": "lines", - "name": spin_up_label, - "line": dict(width=2, color=colors[count]), - "xaxis": "x2", - "yaxis": "y2", - } - - dos_traces.append(trace) - - return dos_traces - - @staticmethod - def get_figure( - ph_bs: PhononBandStructureSymmLine | None = None, - ph_dos: CompletePhononDos | None = None, - freq_range: tuple[float | None, float | None] = (None, None), - ) -> go.Figure: - if (not ph_dos) and (not ph_bs): - empty_plot_style = { - "height": 500, - "xaxis": {"visible": False}, - "yaxis": {"visible": False}, - "paper_bgcolor": "rgba(0,0,0,0)", - "plot_bgcolor": "rgba(0,0,0,0)", - } - - return go.Figure(layout=empty_plot_style) - - if freq_range[0] is None: - freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) - - if freq_range[1] is None: - freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) - - if ph_bs: - ( - bs_traces, - bs_data, - ) = PhononBandstructureAndDosComponent.get_ph_bandstructure_traces( - ph_bs, freq_range=freq_range - ) - - if ph_dos: - dos_traces = PhononBandstructureAndDosComponent.get_ph_dos_traces( - ph_dos, freq_range=freq_range - ) - - # TODO: add logic to handle if bs_traces and/or dos_traces not present - - rmax_list = [ - max(dos_traces[0]["x"]), - abs(min(dos_traces[0]["x"])), - ] - if len(dos_traces) > 1 and "x" in dos_traces[1] and dos_traces[1]["x"].any(): - rmax_list += [ - max(dos_traces[1]["x"]), - abs(min(dos_traces[1]["x"])), - ] - - rmax = max(rmax_list) - - # -- Add trace data to plots - - in_common_axis_styles = dict( - gridcolor="white", - linecolor="rgb(71,71,71)", - linewidth=2, - showgrid=False, - showline=True, - tickfont=dict(size=16), - ticks="inside", - tickwidth=2, - ) - - xaxis_style = dict( - **in_common_axis_styles, - tickmode="array", - mirror=True, - range=[0, bs_data["ticks"]["distance"][-1]], - ticktext=bs_data["ticks"]["label"], - tickvals=bs_data["ticks"]["distance"], - title=dict(text="Wave Vector", font=dict(size=16)), - zeroline=False, - ) - - yaxis_style = dict( - **in_common_axis_styles, - mirror="ticks", - range=freq_range, - title=dict(text="Frequency (THz)", font=dict(size=16)), - zeroline=True, - zerolinecolor="white", - zerolinewidth=2, - ) - - xaxis_style_dos = dict( - **in_common_axis_styles, - title=dict(text="Density of States", font=dict(size=16)), - zeroline=False, - mirror=True, - range=[0, rmax * 1.1], - zerolinecolor="white", - zerolinewidth=2, - ) - - yaxis_style_dos = dict( - **in_common_axis_styles, - zeroline=True, - showticklabels=False, - mirror="ticks", - zerolinewidth=2, - range=freq_range, - zerolinecolor="white", - matches="y", - anchor="x2", - ) - - layout = dict( - title="", - xaxis1=xaxis_style, - xaxis2=xaxis_style_dos, - yaxis=yaxis_style, - yaxis2=yaxis_style_dos, - showlegend=True, - height=500, - width=1000, - hovermode="closest", - paper_bgcolor="rgba(0,0,0,0)", - plot_bgcolor="rgba(230,230,230,230)", - margin=dict(l=60, b=50, t=50, pad=0, r=30), - clickmode="event+select", - ) - - figure = {"data": bs_traces + dos_traces, "layout": layout} - - legend = dict( - x=1.02, - y=1.005, - xanchor="left", - yanchor="top", - bordercolor="#333", - borderwidth=2, - traceorder="normal", - ) - - figure["layout"]["legend"] = legend - - figure["layout"]["xaxis1"]["domain"] = [0.0, 0.7] - figure["layout"]["xaxis2"]["domain"] = [0.73, 1.0] - - return figure - - def generate_callbacks(self, app, cache) -> None: - @app.callback( - Output(self.id("ph-bsdos-graph"), "figure"), - Output(self.id("zone"), "data"), - Output(self.id("table"), "children"), - Input(self.id("ph_bs"), "data"), - Input(self.id("ph_dos"), "data"), - ) - def update_graph(bs, dos): - if isinstance(bs, dict): - bs = PhononBandStructureSymmLine.from_dict(bs) - if isinstance(dos, dict): - dos = CompletePhononDos.from_dict(dos) - - figure = self.get_figure(bs, dos) - - zone_scene = self.get_brillouin_zone_scene(bs) - - summary_dict = self._get_data_list_dict(bs, dos) - summary_table = get_data_list(summary_dict) - - return figure, zone_scene.to_json(), summary_table - - @app.callback( - Output(self.id("brillouin-zone"), "data"), - Input(self.id("ph-bsdos-graph"), "hoverData"), - Input(self.id("ph-bsdos-graph"), "clickData"), - ) - def highlight_bz_on_hover_bs(hover_data, click_data, label_select): - """Highlight the corresponding point/edge of the Brillouin Zone when hovering the band - structure plot. - """ - # TODO: figure out what to return (CSS?) to highlight BZ edge/point - return - - @app.callback( - Output(self.id("crystal-animation"), "data"), - Input(self.id("ph-bsdos-graph"), "clickData"), - Input(self.id("ph_bs"), "data"), - # prevent_initial_call=True - ) - def update_crystal_animation(cd, bs): - if not bs: - raise PreventUpdate - - if isinstance(bs, dict): - bs = PhononBandStructureSymmLine.from_dict(bs) - - struc_graph = StructureGraph.from_local_env_strategy( - bs.structure, CrystalNN() - ) - scene = struc_graph.get_scene( - draw_image_atoms=False, - bonded_sites_outside_unit_cell=False, - site_get_scene_kwargs={"retain_atom_idx": True}, - ) - json_data = scene.to_json() - - qpoint = 0 - band_num = 0 - - if cd and cd.get("points"): - pt = cd["points"][0] - qpoint, band_num = pt.get("customdata", [0, 0]) - - return PhononBandstructureAndDosComponent._get_eigendisplacement( - ph_bs=bs, - json_data=json_data, - band=band_num, - qpoint=qpoint, - ) - - -class PhononBandstructureAndDosComponent_v2(MPComponent): - def __init__( - self, - mpid: str | None = None, - bandstructure_symm_line: BandStructureSymmLine | None = None, - density_of_states: CompleteDos | None = None, - id: str | None = None, - **kwargs, - ) -> None: - # this is a compound component, can be fed by mpid or - # by the BandStructure itself - super().__init__( - id=id, - default_data={ - "mpid": mpid, - "bandstructure_symm_line": bandstructure_symm_line, - "density_of_states": density_of_states, - }, - **kwargs, - ) - - bs, _ = PhononBandstructureAndDosComponent._get_ph_bs_dos( - self.initial_data["default"] - ) - self.create_store("bs-store", bs) - self.create_store("bs", None) - self.create_store("dos", None) - @property def _sub_layouts(self) -> dict[str, Component]: # defaults @@ -1010,13 +238,12 @@ def _sub_layouts(self) -> dict[str, Component]: "crystal-animation-controls": crystal_animation_controls, } - def layout(self) -> html.Div: + def _get_animation_panel(self): sub_layouts = self._sub_layouts - crystal_animation = Columns( + return Columns( [ Column( [ - # sub_layouts["tip"], Columns( [ sub_layouts["crystal-animation"], @@ -1025,9 +252,12 @@ def layout(self) -> html.Div: ) ] ), - # Column([sub_layouts["crystal-animation-controls"]]) ] ) + + def layout(self) -> html.Div: + sub_layouts = self._sub_layouts + crystal_animation = self._get_animation_panel() graph = Columns([Column([sub_layouts["graph"]])]) controls = Columns( [ @@ -1721,7 +951,7 @@ def update_crystal_animation( MAX_MAGNITUDE - MIN_MAGNITUDE ) * magnitude_fraction + MIN_MAGNITUDE - return PhononBandstructureAndDosComponent_v2._get_eigendisplacement( + return PhononBandstructureAndDosComponent._get_eigendisplacement( ph_bs=bs, json_data=json_data, band=band_num, From b31a1f607afbfd8a86ebde77e16915fa1b7babce Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Wed, 5 Nov 2025 12:32:10 -0800 Subject: [PATCH 22/24] remove PhononBandstructureAndDosComponent_v2 --- crystal_toolkit/components/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/crystal_toolkit/components/__init__.py b/crystal_toolkit/components/__init__.py index d905e2aa..1b2cc6a6 100644 --- a/crystal_toolkit/components/__init__.py +++ b/crystal_toolkit/components/__init__.py @@ -14,7 +14,6 @@ ) from crystal_toolkit.components.phonon import ( PhononBandstructureAndDosComponent, - PhononBandstructureAndDosComponent_v2, PhononBandstructureAndDosPanelComponent, ) from crystal_toolkit.components.pourbaix import PourbaixDiagramComponent From 3331e3d57e3fbba854788ba75bdd294f556ce2ca Mon Sep 17 00:00:00 2001 From: Chiu Peter Date: Wed, 5 Nov 2025 12:37:31 -0800 Subject: [PATCH 23/24] pre rebase --- crystal_toolkit/components/phonon.py | 163 ++++++++++++++------------- 1 file changed, 83 insertions(+), 80 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 962764e0..5ef76d90 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -7,7 +7,7 @@ import numpy as np import plotly.graph_objects as go from dash import dcc, html -from dash.dependencies import Component, Input, Output +from dash.dependencies import Component, Input, Output, State from dash.exceptions import PreventUpdate from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene @@ -34,7 +34,7 @@ MARKER_COLOR = "red" MARKER_SIZE = 12 MARKER_SHAPE = "x" -MAX_MAGNITUDE = 400 +MAX_MAGNITUDE = 300 MIN_MAGNITUDE = 0 # TODOs: @@ -165,6 +165,8 @@ def _sub_layouts(self) -> dict[str, Component]: sceneSize="500px", id=self.id("crystal-animation"), settings={"defaultZoom": 1.2}, + axisView="SW", + showControls=False, # disable download for now ), style={"width": "60%"}, ) @@ -184,6 +186,7 @@ def _sub_layouts(self) -> dict[str, Component]: default=1, is_int=True, label="x", + min=1, style={"width": "5rem"}, ), self.get_numerical_input( @@ -191,6 +194,7 @@ def _sub_layouts(self) -> dict[str, Component]: default=1, is_int=True, label="y", + min=1, style={"width": "5rem"}, ), self.get_numerical_input( @@ -198,11 +202,12 @@ def _sub_layouts(self) -> dict[str, Component]: default=1, is_int=True, label="z", + min=1, style={"width": "5rem"}, ), html.Button( "Update", - id=self.id("controls-btn"), + id=self.id("supercell-controls-btn"), style={"height": "40px"}, ), ], @@ -233,13 +238,12 @@ def _sub_layouts(self) -> dict[str, Component]: "crystal-animation-controls": crystal_animation_controls, } - def layout(self) -> html.Div: + def _get_animation_panel(self): sub_layouts = self._sub_layouts - crystal_animation = Columns( + return Columns( [ Column( [ - # sub_layouts["tip"], Columns( [ sub_layouts["crystal-animation"], @@ -248,9 +252,12 @@ def layout(self) -> html.Div: ) ] ), - # Column([sub_layouts["crystal-animation-controls"]]) ] ) + + def layout(self) -> html.Div: + sub_layouts = self._sub_layouts + crystal_animation = self._get_animation_panel() graph = Columns([Column([sub_layouts["graph"]])]) controls = Columns( [ @@ -279,8 +286,8 @@ def _get_eigendisplacement( band: int = 0, qpoint: int = 0, precision: int = 15, - magnitude: int = 225, - total_repeat_cell_cnt: int | None = None, + magnitude: int = MAX_MAGNITUDE / 2, + total_repeat_cell_cnt: int = 1, ) -> dict: if not ph_bs or not json_data: return {} @@ -305,8 +312,9 @@ def calc_max_displacement(idx: int) -> list: # get the atom index assert total_repeat_cell_cnt != 0 + modified_idx = ( - (idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx + int(idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx ) return [ @@ -793,27 +801,7 @@ def get_figure( clickmode="event+select", ) - default_red_dot = [ - { - "type": "scatter", - "mode": "markers", - "x": [0], - "y": [0], - "marker": { - "color": MARKER_COLOR, - "size": MARKER_SIZE, - "symbol": MARKER_SHAPE, - }, - "name": "click-marker", - "showlegend": False, - "customdata": [[0, 0]], - "hovertemplate": ( - "band: %{customdata[1]}
q-point: %{customdata[0]}
" - ), - } - ] - - figure = {"data": bs_traces + dos_traces + default_red_dot, "layout": layout} + figure = {"data": bs_traces + dos_traces, "layout": layout} legend = dict( x=1.02, @@ -848,37 +836,37 @@ def update_graph(bs, dos, nclick): dos = CompletePhononDos.from_dict(dos) figure = self.get_figure(bs, dos) - if nclick and nclick.get("points"): - # remove marker if there is one - figure["data"] = [ - t for t in figure["data"] if t.get("name") != "click-marker" - ] - x_click = nclick["points"][0]["x"] - y_click = nclick["points"][0]["y"] + # remove marker if there is one + figure["data"] = [ + t for t in figure["data"] if t.get("name") != "click-marker" + ] - pt = nclick["points"][0] - qpoint, band_num = pt.get("customdata", [0, 0]) + x_click = nclick["points"][0]["x"] if nclick else 0 + y_click = nclick["points"][0]["y"] if nclick else 0 + pt = nclick["points"][0] if nclick else {} - figure["data"].append( - { - "type": "scatter", - "mode": "markers", - "x": [x_click], - "y": [y_click], - "marker": { - "color": MARKER_COLOR, - "size": MARKER_SIZE, - "symbol": MARKER_SHAPE, - }, - "name": "click-marker", - "showlegend": False, - "customdata": [[qpoint, band_num]], - "hovertemplate": ( - "band: %{customdata[1]}
q-point: %{customdata[0]}
" - ), - } - ) + qpoint, band_num = pt.get("customdata", [0, 0]) + + figure["data"].append( + { + "type": "scatter", + "mode": "markers", + "x": [x_click], + "y": [y_click], + "marker": { + "color": MARKER_COLOR, + "size": MARKER_SIZE, + "symbol": MARKER_SHAPE, + }, + "name": "click-marker", + "showlegend": False, + "customdata": [[qpoint, band_num]], + "hovertemplate": ( + "band: %{customdata[1]}
q-point: %{customdata[0]}
" + ), + } + ) zone_scene = self.get_brillouin_zone_scene(bs) @@ -903,35 +891,45 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select): Output(self.id("crystal-animation"), "data"), Input(self.id("ph-bsdos-graph"), "clickData"), Input(self.id("ph_bs"), "data"), - Input(self.id("controls-btn"), "n_clicks"), - Input(self.get_all_kwargs_id(), "value"), + Input(self.id("supercell-controls-btn"), "n_clicks"), + Input(self.get_kwarg_id("magnitude"), "value"), + State(self.get_kwarg_id("scale-x"), "value"), + State(self.get_kwarg_id("scale-y"), "value"), + State(self.get_kwarg_id("scale-z"), "value"), # prevent_initial_call=True ) - def update_crystal_animation(cd, bs, update, kwargs): + def update_crystal_animation( + cd, bs, sueprcell_update, magnitude_fraction, scale_x, scale_y, scale_z + ): + # Avoids using `get_all_kwargs_id` for all `Input`; instead, uses `State` to prevent flickering when users modify `scale_x`, `scale_y`, or `scale_z` fields, + # ensuring updates occur only after the `supercell-controls-btn`` is clicked. + if not bs: raise PreventUpdate + # Since `self.get_kwarg_id()` uses dash.dependencies.ALL, it returns a list of values. + # Although we could use `magnitude_fraction = magnitude_fraction[0]` to get the first value, + # this approach provides better clarity and readability. + kwargs = self.reconstruct_kwargs_from_state() + magnitude_fraction = kwargs.get("magnitude") + scale_x = kwargs.get("scale-x") + scale_y = kwargs.get("scale-y") + scale_z = kwargs.get("scale-z") + if isinstance(bs, dict): bs = PhononBandStructureSymmLine.from_dict(bs) - kwargs = self.reconstruct_kwargs_from_state() + struct = bs.structure + total_repeat_cell_cnt = 1 + # update structure if the controls got triggered + if sueprcell_update: + total_repeat_cell_cnt = scale_x * scale_y * scale_z - # animation control - scale_x, scale_y, scale_z = ( - int(kwargs["scale-x"]), - int(kwargs["scale-y"]), - int(kwargs["scale-z"]), - ) - magnitude_fraction = kwargs["magnitude"] - magnitude = ( - MAX_MAGNITUDE - MIN_MAGNITUDE - ) * magnitude_fraction + MIN_MAGNITUDE - - # create supercell - trans = SupercellTransformation( - ((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z)) - ) - struct = trans.apply_transformation(bs.structure) + # create supercell + trans = SupercellTransformation( + ((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z)) + ) + struct = trans.apply_transformation(struct) struc_graph = StructureGraph.from_local_env_strategy(struct, CrystalNN()) scene = struc_graph.get_scene( @@ -948,12 +946,17 @@ def update_crystal_animation(cd, bs, update, kwargs): pt = cd["points"][0] qpoint, band_num = pt.get("customdata", [0, 0]) + # magnitude + magnitude = ( + MAX_MAGNITUDE - MIN_MAGNITUDE + ) * magnitude_fraction + MIN_MAGNITUDE + return PhononBandstructureAndDosComponent._get_eigendisplacement( ph_bs=bs, json_data=json_data, band=band_num, qpoint=qpoint, - total_repeat_cell_cnt=scale_x * scale_y * scale_z, + total_repeat_cell_cnt=total_repeat_cell_cnt, magnitude=magnitude, ) From baf6393656420b3f5e935ed0f600ed458b7a09ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Nov 2025 22:23:09 +0000 Subject: [PATCH 24/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- crystal_toolkit/components/phonon.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 1a521a4f..3348d9b7 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -18,14 +18,7 @@ from crystal_toolkit.core.mpcomponent import MPComponent from crystal_toolkit.core.panelcomponent import PanelComponent from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres -from crystal_toolkit.helpers.layouts import ( - Column, - Columns, - Label, - MessageBody, - MessageContainer, - get_data_list, -) +from crystal_toolkit.helpers.layouts import Column, Columns, Label, get_data_list from crystal_toolkit.helpers.pretty_labels import pretty_labels if TYPE_CHECKING: