From 1cabfde5191e4af151a6117fd566ca9d552dd573 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 12 Apr 2025 18:45:34 -0400 Subject: [PATCH 01/12] fea: swap to use nn.embedding for roost. --- aviary/core.py | 7 +- aviary/roost/data.py | 32 +-- aviary/roost/model.py | 23 +- examples/notebooks/Roost.ipynb | 423 ++++++++++++++++++++++++++++++++- examples/roost-example.py | 13 +- tests/test_roost.py | 8 +- 6 files changed, 448 insertions(+), 58 deletions(-) diff --git a/aviary/core.py b/aviary/core.py index 983d41a9..a023f022 100644 --- a/aviary/core.py +++ b/aviary/core.py @@ -32,6 +32,7 @@ def __init__( epoch: int = 0, device: str | None = None, best_val_scores: dict[str, float] | None = None, + **kwargs, ) -> None: """Store core model parameters. @@ -47,6 +48,7 @@ def __init__( device (str, optional): Device to store the model parameters on. best_val_scores (dict[str, float], optional): Validation score to use for early stopping. Defaults to None. + **kwargs: Additional keyword arguments. """ super().__init__() self.task_dict = task_dict @@ -299,8 +301,9 @@ def evaluate( preds = output.squeeze(1) loss = loss_func(preds, targets) - z_scored_error = preds - targets - error = normalizer.std * z_scored_error.data.cpu() + denormed_preds = normalizer.denorm(preds) + denormed_targets = normalizer.denorm(targets) + error = denormed_preds - denormed_targets target_metrics["MAE"].append(float(error.abs().mean())) target_metrics["MSE"].append(float(error.pow(2).mean())) diff --git a/aviary/roost/data.py b/aviary/roost/data.py index 3f84abb0..e5803f92 100644 --- a/aviary/roost/data.py +++ b/aviary/roost/data.py @@ -1,4 +1,3 @@ -import json from collections.abc import Sequence from functools import cache from typing import Any @@ -10,8 +9,6 @@ from torch import LongTensor, Tensor from torch.utils.data import Dataset -from aviary import PKG_DIR - class CompositionData(Dataset): """Dataset class for the Roost composition model.""" @@ -20,7 +17,6 @@ def __init__( self, df: pd.DataFrame, task_dict: dict[str, str], - elem_embedding: str = "matscholar200", inputs: str = "composition", identifiers: Sequence[str] = ("material_id", "composition"), ): @@ -47,14 +43,6 @@ def __init__( self.identifiers = list(identifiers) self.df = df - if elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]: - elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json" - - with open(elem_embedding) as file: - self.elem_features = json.load(file) - - self.elem_emb_len = len(next(iter(self.elem_features.values()))) - self.n_targets = [] for target, task in self.task_dict.items(): if task == "regression": @@ -88,24 +76,12 @@ def __getitem__(self, idx: int): composition = row[self.inputs] material_ids = row[self.identifiers].to_list() - comp_dict = Composition(composition).get_el_amt_dict() - elements = list(comp_dict) - + comp_dict = Composition(composition).fractional_composition weights = list(comp_dict.values()) weights = np.atleast_2d(weights).T / np.sum(weights) + elem_fea = [elem.Z for elem in comp_dict] - try: - elem_fea = np.vstack([self.elem_features[element] for element in elements]) - except AssertionError as exc: - raise AssertionError( - f"{material_ids} contains element types not in embedding" - ) from exc - except ValueError as exc: - raise ValueError( - f"{material_ids} composition cannot be parsed into elements" - ) from exc - - n_elems = len(elements) + n_elems = len(comp_dict) self_idx = [] nbr_idx = [] for elem_idx in range(n_elems): @@ -114,7 +90,7 @@ def __getitem__(self, idx: 
int):
 
         # convert all data to tensors
         elem_weights = Tensor(weights)
-        elem_fea = Tensor(elem_fea)
+        elem_fea = LongTensor(elem_fea)
         self_idx = LongTensor(self_idx)
         nbr_idx = LongTensor(nbr_idx)
 
diff --git a/aviary/roost/model.py b/aviary/roost/model.py
index fec22196..b1c3cbcb 100644
--- a/aviary/roost/model.py
+++ b/aviary/roost/model.py
@@ -1,10 +1,13 @@
+import json
 from collections.abc import Sequence
 
 import torch
 import torch.nn.functional as F
+from pymatgen.core import Element
 from pymatgen.util.due import Doi, due
 from torch import LongTensor, Tensor, nn
 
+from aviary import PKG_DIR
 from aviary.core import BaseModelClass
 from aviary.networks import ResidualNetwork, SimpleNetwork
 from aviary.segments import MessageLayer, WeightedAttentionPooling
@@ -25,7 +28,7 @@ def __init__(
         self,
         robust: bool,
         n_targets: Sequence[int],
-        elem_emb_len: int,
+        elem_embedding: str = "matscholar200",
         elem_fea_len: int = 64,
         n_graph: int = 3,
         elem_heads: int = 3,
@@ -41,6 +44,21 @@ def __init__(
         """Composition-only model."""
         super().__init__(robust=robust, **kwargs)
 
+        if elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
+            elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json"
+
+        with open(elem_embedding) as file:
+            self.elem_features = json.load(file)
+
+        max_z = max(Element(elem).Z for elem in self.elem_features)
+        elem_emb_len = len(next(iter(self.elem_features.values())))
+        elem_feature_matrix = torch.zeros((max_z + 1, elem_emb_len))
+        for elem, feature in self.elem_features.items():
+            elem_feature_matrix[Element(elem).Z] = torch.tensor(feature)
+
+        self.elem_embedding = nn.Embedding(max_z + 1, elem_emb_len)
+        self.elem_embedding.weight.data.copy_(elem_feature_matrix)
+
         desc_dict = {
             "elem_emb_len": elem_emb_len,
             "elem_fea_len": elem_fea_len,
@@ -60,6 +78,7 @@ def __init__(
             "n_targets": n_targets,
             "out_hidden": out_hidden,
             "trunk_hidden": trunk_hidden,
+            "elem_embedding": elem_embedding,
             **desc_dict,
         }
         self.model_params.update(model_params)
@@ -83,6 +102,8 @@ def forward(
         cry_elem_idx: LongTensor,
     ) -> tuple[Tensor, ...]:
         """Forward pass through the material_nn and output_nn."""
+        elem_fea = self.elem_embedding(elem_fea)
+
         crys_fea = self.material_nn(
             elem_weights, elem_fea, self_idx, nbr_idx, cry_elem_idx
         )
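To make the swap above concrete for review: the model now owns element featurization by seeding one `nn.Embedding` row per atomic number, so datasets only ship `LongTensor`s of Z values instead of float feature matrices. A minimal self-contained sketch of the pattern, with a made-up `toy_features` dict in place of `aviary/embeddings/element/<name>.json` and a hand-written `z_of` map standing in for pymatgen's `Element(elem).Z`:

```python
# Sketch only -- `toy_features` and `z_of` are illustrative stand-ins,
# not part of the patch.
import torch
from torch import nn

toy_features = {"H": [0.1, 0.2], "He": [0.3, 0.4], "Li": [0.5, 0.6]}
z_of = {"H": 1, "He": 2, "Li": 3}  # stand-in for pymatgen's Element(elem).Z

max_z = max(z_of.values())
emb_len = len(next(iter(toy_features.values())))
weight = torch.zeros((max_z + 1, emb_len))  # row 0 unused; rows indexed by Z
for elem, feature in toy_features.items():
    weight[z_of[elem]] = torch.tensor(feature)

embedding = nn.Embedding(max_z + 1, emb_len)
embedding.weight.data.copy_(weight)

# A composition is now a LongTensor of atomic numbers, e.g. LiH -> [3, 1],
# and the lookup reproduces the old np.vstack featurization row-for-row:
elem_fea = torch.tensor([3, 1])
assert torch.allclose(embedding(elem_fea), weight[elem_fea])
```

`nn.Embedding.from_pretrained(weight, freeze=False)` would be an equivalent one-liner; either way the rows start from the published features but stay trainable.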
"spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n", + "spglib: ssm_get_exact_positions failed.\n", + "spglib: get_bravais_exact_positions_and_lattice failed.\n" + ] + } + ], "source": [ "with gzip.open(\"taata.json.gz\", \"r\") as fin:\n", " json_bytes = fin.read()\n", @@ -149,7 +204,330 @@ "execution_count": null, "id": "4", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "using 0.2 of training set as test set\n", + "No validation set used, using test set for evaluation purposes\n", + "Total Number of Trainable Parameters: 973,777\n", + "Dummy MAE: 1.2757\n", + "Epoch: [0/99]\n", + " train: E_vasp_per_atom N 10 MAE 1.27 Loss 1.12 RMSE 1.59 \n", + " evaluate: E_vasp_per_atom N 1 MAE 1.29 Loss 1.13 RMSE 1.59 \n", + "Epoch: [1/99]\n", + " train: E_vasp_per_atom N 10 MAE 1.25 Loss 1.11 RMSE 1.59 \n", + " evaluate: E_vasp_per_atom N 1 MAE 1.25 Loss 1.10 RMSE 1.55 \n", + "Epoch: [2/99]\n", + " train: E_vasp_per_atom N 10 MAE 1.17 Loss 1.03 RMSE 1.50 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.98 Loss 0.87 RMSE 1.30 \n", + "Epoch: [3/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.85 Loss 0.74 RMSE 1.23 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.62 Loss 0.49 RMSE 0.87 \n", + "Epoch: [4/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.59 Loss 0.40 RMSE 0.89 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.44 Loss 0.13 RMSE 0.59 \n", + "Epoch: [5/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.45 Loss 0.10 RMSE 0.69 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.39 Loss -0.06 RMSE 0.54 \n", + "Epoch: [6/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.44 Loss 0.03 RMSE 0.68 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.37 Loss -0.10 RMSE 0.50 \n", + "Epoch: [7/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.40 Loss -0.03 RMSE 0.63 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.36 Loss -0.14 RMSE 0.50 \n", + "Epoch: [8/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.38 Loss -0.10 RMSE 0.61 \n", + " evaluate: E_vasp_per_atom N 1 MAE 
0.38 Loss -0.07 RMSE 0.49 \n", + "Epoch: [9/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.38 Loss -0.09 RMSE 0.61 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.36 Loss -0.13 RMSE 0.46 \n", + "Epoch: [10/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.36 Loss -0.15 RMSE 0.57 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.30 Loss -0.31 RMSE 0.42 \n", + "Epoch: [11/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.35 Loss -0.20 RMSE 0.56 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.32 Loss -0.26 RMSE 0.42 \n", + "Epoch: [12/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.36 Loss -0.17 RMSE 0.57 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.33 RMSE 0.43 \n", + "Epoch: [13/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.35 Loss -0.18 RMSE 0.56 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.30 Loss -0.32 RMSE 0.42 \n", + "Epoch: [14/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.27 RMSE 0.53 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.38 RMSE 0.41 \n", + "Epoch: [15/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.27 RMSE 0.53 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.31 Loss -0.32 RMSE 0.41 \n", + "Epoch: [16/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.27 RMSE 0.53 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.30 Loss -0.33 RMSE 0.41 \n", + "Epoch: [17/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.27 RMSE 0.53 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.35 RMSE 0.42 \n", + "Epoch: [18/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.26 RMSE 0.53 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.35 RMSE 0.40 \n", + "Epoch: [19/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.32 Loss -0.31 RMSE 0.51 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.39 RMSE 0.39 \n", + "Epoch: [20/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.32 Loss -0.29 RMSE 0.52 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.35 RMSE 0.39 \n", + "Epoch: [21/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.27 RMSE 0.50 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.35 RMSE 0.40 \n", + "Epoch: [22/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.32 Loss -0.30 RMSE 0.50 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.34 Loss -0.20 RMSE 0.42 \n", + "Epoch: [23/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.35 Loss -0.20 RMSE 0.51 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.38 RMSE 0.41 \n", + "Epoch: [24/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.28 RMSE 0.50 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.38 RMSE 0.41 \n", + "Epoch: [25/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.34 RMSE 0.48 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.34 RMSE 0.42 \n", + "Epoch: [26/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.31 RMSE 0.48 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.37 RMSE 0.39 \n", + "Epoch: [27/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.34 RMSE 0.48 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.37 RMSE 0.39 \n", + "Epoch: [28/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.34 RMSE 0.47 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.37 RMSE 0.39 \n", + "Epoch: [29/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.34 RMSE 0.47 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.41 RMSE 0.38 \n", + "Epoch: [30/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.34 RMSE 0.46 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.30 Loss -0.33 RMSE 0.39 \n", + "Epoch: [31/99]\n", + " train: E_vasp_per_atom 
N 10 MAE 0.31 Loss -0.33 RMSE 0.46 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.44 RMSE 0.39 \n", + "Epoch: [32/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.31 RMSE 0.46 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.41 RMSE 0.38 \n", + "Epoch: [33/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.32 RMSE 0.45 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.41 RMSE 0.41 \n", + "Epoch: [34/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.32 Loss -0.30 RMSE 0.45 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.35 Loss -0.20 RMSE 0.42 \n", + "Epoch: [35/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.32 Loss -0.30 RMSE 0.45 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.37 RMSE 0.41 \n", + "Epoch: [36/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.27 RMSE 0.47 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.42 RMSE 0.40 \n", + "Epoch: [37/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.30 Loss -0.35 RMSE 0.44 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.46 RMSE 0.37 \n", + "Epoch: [38/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.29 Loss -0.39 RMSE 0.42 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.41 RMSE 0.36 \n", + "Epoch: [39/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.30 Loss -0.36 RMSE 0.43 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.42 RMSE 0.39 \n", + "Epoch: [40/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.29 Loss -0.39 RMSE 0.42 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.42 RMSE 0.36 \n", + "Epoch: [41/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.42 RMSE 0.40 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.48 RMSE 0.36 \n", + "Epoch: [42/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.44 RMSE 0.40 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.40 RMSE 0.36 \n", + "Epoch: [43/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.30 Loss -0.34 RMSE 0.42 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.31 Loss -0.33 RMSE 0.43 \n", + "Epoch: [44/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.29 Loss -0.37 RMSE 0.41 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.47 RMSE 0.36 \n", + "Epoch: [45/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.29 Loss -0.39 RMSE 0.41 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.33 RMSE 0.37 \n", + "Epoch: [46/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.29 Loss -0.37 RMSE 0.41 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.44 RMSE 0.40 \n", + "Epoch: [47/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.42 RMSE 0.41 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.43 RMSE 0.36 \n", + "Epoch: [48/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.44 RMSE 0.40 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.49 RMSE 0.36 \n", + "Epoch: [49/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.45 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.50 RMSE 0.36 \n", + "Epoch: [50/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.46 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.47 RMSE 0.35 \n", + "Epoch: [51/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.45 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.51 RMSE 0.35 \n", + "Epoch: [52/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.47 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.48 RMSE 0.37 \n", + "Epoch: [53/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.46 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 
0.27 Loss -0.45 RMSE 0.36 \n", + "Epoch: [54/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.42 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.45 RMSE 0.39 \n", + "Epoch: [55/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.44 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.46 RMSE 0.36 \n", + "Epoch: [56/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.41 RMSE 0.40 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.40 RMSE 0.36 \n", + "Epoch: [57/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.41 RMSE 0.40 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.45 RMSE 0.35 \n", + "Epoch: [58/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.43 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.52 RMSE 0.36 \n", + "Epoch: [59/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.45 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.49 RMSE 0.34 \n", + "Epoch: [60/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.43 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.49 RMSE 0.35 \n", + "Epoch: [61/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.44 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.47 RMSE 0.39 \n", + "Epoch: [62/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.42 RMSE 0.40 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.42 RMSE 0.36 \n", + "Epoch: [63/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.29 Loss -0.39 RMSE 0.40 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.31 Loss -0.28 RMSE 0.44 \n", + "Epoch: [64/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.30 Loss -0.36 RMSE 0.42 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.46 RMSE 0.35 \n", + "Epoch: [65/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.29 Loss -0.37 RMSE 0.40 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.42 RMSE 0.36 \n", + "Epoch: [66/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.30 Loss -0.35 RMSE 0.41 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.36 RMSE 0.42 \n", + "Epoch: [67/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.42 RMSE 0.40 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.48 RMSE 0.38 \n", + "Epoch: [68/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.43 RMSE 0.40 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.33 RMSE 0.37 \n", + "Epoch: [69/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.44 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.50 RMSE 0.36 \n", + "Epoch: [70/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.48 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.50 RMSE 0.35 \n", + "Epoch: [71/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.49 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.48 RMSE 0.37 \n", + "Epoch: [72/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.47 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.50 RMSE 0.36 \n", + "Epoch: [73/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.47 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.49 RMSE 0.35 \n", + "Epoch: [74/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.43 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.50 RMSE 0.37 \n", + "Epoch: [75/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.49 RMSE 0.37 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.53 RMSE 0.35 \n", + "Epoch: [76/99]\n", + " train: 
E_vasp_per_atom N 10 MAE 0.26 Loss -0.52 RMSE 0.37 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.52 RMSE 0.35 \n", + "Epoch: [77/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.49 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.48 RMSE 0.35 \n", + "Epoch: [78/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.48 RMSE 0.37 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.51 RMSE 0.37 \n", + "Epoch: [79/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.50 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.54 RMSE 0.35 \n", + "Epoch: [80/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.52 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.49 RMSE 0.38 \n", + "Epoch: [81/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.50 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.24 Loss -0.54 RMSE 0.35 \n", + "Epoch: [82/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.47 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.51 RMSE 0.36 \n", + "Epoch: [83/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.48 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.24 Loss -0.53 RMSE 0.35 \n", + "Epoch: [84/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.47 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.46 RMSE 0.35 \n", + "Epoch: [85/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.41 RMSE 0.40 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.47 RMSE 0.38 \n", + "Epoch: [86/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.50 RMSE 0.37 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.48 RMSE 0.35 \n", + "Epoch: [87/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.49 RMSE 0.37 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.49 RMSE 0.36 \n", + "Epoch: [88/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.45 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.52 RMSE 0.36 \n", + "Epoch: [89/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.47 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.52 RMSE 0.36 \n", + "Epoch: [90/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.48 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.49 RMSE 0.35 \n", + "Epoch: [91/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.51 RMSE 0.37 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.49 RMSE 0.35 \n", + "Epoch: [92/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.49 RMSE 0.37 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.50 RMSE 0.36 \n", + "Epoch: [93/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.49 RMSE 0.37 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.38 RMSE 0.36 \n", + "Epoch: [94/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.49 RMSE 0.37 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.53 RMSE 0.35 \n", + "Epoch: [95/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.51 RMSE 0.37 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.53 RMSE 0.36 \n", + "Epoch: [96/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.49 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.50 RMSE 0.34 \n", + "Epoch: [97/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.43 RMSE 0.39 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.42 RMSE 0.40 \n", + "Epoch: [98/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.44 RMSE 0.39 \n", + " evaluate: 
E_vasp_per_atom N 1 MAE 0.24 Loss -0.53 RMSE 0.35 \n", + "Epoch: [99/99]\n", + " train: E_vasp_per_atom N 10 MAE 0.26 Loss -0.47 RMSE 0.38 \n", + " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.44 RMSE 0.36 \n" + ] + }, + { + "ename": "TypeError", + "evalue": "results_multitask() got an unexpected keyword argument 'test_set'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 73\u001b[0m\n\u001b[1;32m 53\u001b[0m train_ensemble(\n\u001b[1;32m 54\u001b[0m model_class\u001b[38;5;241m=\u001b[39mRoost,\n\u001b[1;32m 55\u001b[0m model_name\u001b[38;5;241m=\u001b[39mmodel_name,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 65\u001b[0m loss_dict\u001b[38;5;241m=\u001b[39mloss_dict,\n\u001b[1;32m 66\u001b[0m )\n\u001b[1;32m 68\u001b[0m test_loader \u001b[38;5;241m=\u001b[39m DataLoader(\n\u001b[1;32m 69\u001b[0m test_set,\n\u001b[1;32m 70\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m{\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdata_params, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;241m64\u001b[39m \u001b[38;5;241m*\u001b[39m data_params[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshuffle\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mFalse\u001b[39;00m},\n\u001b[1;32m 71\u001b[0m )\n\u001b[0;32m---> 73\u001b[0m roost_results_dict \u001b[38;5;241m=\u001b[39m \u001b[43mresults_multitask\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 74\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_class\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mRoost\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 76\u001b[0m \u001b[43m \u001b[49m\u001b[43mrun_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrun_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 77\u001b[0m \u001b[43m \u001b[49m\u001b[43mensemble_folds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensemble\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 78\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest_set\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_set\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[43m \u001b[49m\u001b[43mrobust\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrobust\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 81\u001b[0m \u001b[43m \u001b[49m\u001b[43mtask_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtask_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 82\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 83\u001b[0m \u001b[43m \u001b[49m\u001b[43meval_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcheckpoint\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 84\u001b[0m \u001b[43m 
\u001b[49m\u001b[43msave_results\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 85\u001b[0m \u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mTypeError\u001b[0m: results_multitask() got an unexpected keyword argument 'test_set'" + ] + } + ], "source": [ "torch.manual_seed(0) # ensure reproducible results\n", "\n", @@ -161,11 +539,9 @@ "\n", "dataset = CompositionData(\n", " df=df,\n", - " elem_embedding=elem_embedding,\n", " task_dict=task_dict,\n", ")\n", "n_targets = dataset.n_targets\n", - "elem_emb_len = dataset.elem_emb_len\n", "\n", "train_idx = list(range(len(dataset)))\n", "\n", @@ -192,7 +568,7 @@ " \"task_dict\": task_dict,\n", " \"robust\": robust,\n", " \"n_targets\": n_targets,\n", - " \"elem_emb_len\": elem_emb_len,\n", + " \"elem_embedding\": elem_embedding,\n", " \"elem_fea_len\": 64,\n", " \"n_graph\": 3,\n", " \"elem_heads\": 3,\n", @@ -211,16 +587,40 @@ " run_id=run_id,\n", " ensemble_folds=ensemble,\n", " epochs=epochs,\n", - " train_set=train_set,\n", - " val_set=val_set,\n", + " train_loader=train_loader,\n", + " val_loader=val_loader,\n", " log=log,\n", - " data_params=data_params,\n", " setup_params=setup_params,\n", " restart_params=restart_params,\n", " model_params=model_params,\n", " loss_dict=loss_dict,\n", - ")\n", - "\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + "------------Evaluate model on Test Set------------\n", + "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + "\n", + "Evaluating Model\n", + "\n", + "Task: target_name='E_vasp_per_atom' on test set\n", + "Model Performance Metrics:\n", + "R2 Score: 0.9494 \n", + "MAE: 0.2701\n", + "RMSE: 0.3576\n" + ] + } + ], + "source": [ "test_loader = DataLoader(\n", " test_set,\n", " **{**data_params, \"batch_size\": 64 * data_params[\"batch_size\"], \"shuffle\": False},\n", @@ -231,8 +631,7 @@ " model_name=model_name,\n", " run_id=run_id,\n", " ensemble_folds=ensemble,\n", - " test_set=test_set,\n", - " data_params=data_params,\n", + " test_loader=test_loader,\n", " robust=robust,\n", " task_dict=task_dict,\n", " device=device,\n", diff --git a/examples/roost-example.py b/examples/roost-example.py index effd0716..c8410e77 100644 --- a/examples/roost-example.py +++ b/examples/roost-example.py @@ -86,9 +86,8 @@ def main( # NOTE do not use default_na as "NaN" is a valid material df = pd.read_csv(data_path, keep_default_na=False, 
na_values=[]) - dataset = CompositionData(df=df, elem_embedding=elem_embedding, task_dict=task_dict) + dataset = CompositionData(df=df, task_dict=task_dict) n_targets = dataset.n_targets - elem_emb_len = dataset.elem_emb_len train_idx = list(range(len(dataset))) @@ -99,9 +98,7 @@ def main( df = pd.read_csv(test_path, keep_default_na=False, na_values=[]) print(f"using independent test set: {test_path}") - test_set = CompositionData( - df=df, elem_embedding=elem_embedding, task_dict=task_dict - ) + test_set = CompositionData(df=df, task_dict=task_dict) test_set = torch.utils.data.Subset(test_set, range(len(test_set))) elif test_size == 0: raise ValueError("test-size must be non-zero to evaluate model") @@ -119,9 +116,7 @@ def main( df = pd.read_csv(val_path, keep_default_na=False, na_values=[]) print(f"using independent validation set: {val_path}") - val_set = CompositionData( - df=df, elem_embedding=elem_embedding, task_dict=task_dict - ) + val_set = CompositionData(df=df, task_dict=task_dict) val_set = torch.utils.data.Subset(val_set, range(len(val_set))) elif val_size == 0 and evaluate: print("No validation set used, using test set for evaluation purposes") @@ -172,7 +167,7 @@ def main( "task_dict": task_dict, "robust": robust, "n_targets": n_targets, - "elem_emb_len": elem_emb_len, + "elem_embedding": elem_embedding, "elem_fea_len": elem_fea_len, "n_graph": n_graph, "elem_heads": 3, diff --git a/tests/test_roost.py b/tests/test_roost.py index 7ff74ec4..8be8b594 100644 --- a/tests/test_roost.py +++ b/tests/test_roost.py @@ -69,11 +69,9 @@ def test_roost_regression( dataset = CompositionData( df=df_matbench_phonons, - elem_embedding=base_config["elem_embedding"], task_dict=task_dict, ) n_targets = dataset.n_targets - elem_emb_len = dataset.elem_emb_len train_idx = list(range(len(dataset))) train_idx, test_idx = split( @@ -111,7 +109,7 @@ def test_roost_regression( "task_dict": task_dict, "robust": base_config["robust"], "n_targets": n_targets, - "elem_emb_len": elem_emb_len, + "elem_embedding": base_config["elem_embedding"], **model_architecture, # unpack all model architecture parameters } @@ -179,11 +177,9 @@ def test_roost_clf(df_matbench_phonons, base_config, model_architecture, trainin dataset = CompositionData( df=df_matbench_phonons, - elem_embedding=base_config["elem_embedding"], task_dict=task_dict, ) n_targets = dataset.n_targets - elem_emb_len = dataset.elem_emb_len train_idx = list(range(len(dataset))) train_idx, test_idx = split( @@ -221,7 +217,7 @@ def test_roost_clf(df_matbench_phonons, base_config, model_architecture, trainin "task_dict": task_dict, "robust": base_config["robust"], "n_targets": n_targets, - "elem_emb_len": elem_emb_len, + "elem_embedding": base_config["elem_embedding"], **model_architecture, # unpack all model architecture parameters } From 1576e903d976515f6cf32292d2e343c401dc3832 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 12 Apr 2025 18:53:25 -0400 Subject: [PATCH 02/12] fea: swap wren to use nn.embedding --- aviary/wren/data.py | 68 ++---- aviary/wren/model.py | 47 +++- examples/notebooks/Wren.ipynb | 407 ++++++++++++++++++++++++++++++++-- examples/wren-example.py | 16 +- tests/test_wren.py | 18 +- 5 files changed, 470 insertions(+), 86 deletions(-) diff --git a/aviary/wren/data.py b/aviary/wren/data.py index fe16937c..b38c7a21 100644 --- a/aviary/wren/data.py +++ b/aviary/wren/data.py @@ -1,4 +1,5 @@ import json +from collections import defaultdict from collections.abc import Sequence from functools import cache from itertools import groupby 
@@ -13,11 +14,23 @@ WYCKOFF_MULTIPLICITY_DICT, WYCKOFF_POSITION_RELAB_DICT, ) +from pymatgen.core import Element from torch import LongTensor, Tensor from torch.utils.data import Dataset from aviary import PKG_DIR +with open(f"{PKG_DIR}/embeddings/wyckoff/bra-alg-off.json") as f: + sym_embeddings = json.load(f) +WYCKOFF_SPG_LETTER_MAP: dict[str, dict[str, int]] = defaultdict(dict) +i = 0 +for spg_num, embeddings in sym_embeddings.items(): + for wyckoff_letter in embeddings: + WYCKOFF_SPG_LETTER_MAP[spg_num][wyckoff_letter] = i + i += 1 + +del sym_embeddings + class WyckoffData(Dataset): """Wyckoff dataset class for the Wren model.""" @@ -26,8 +39,6 @@ def __init__( self, df: pd.DataFrame, task_dict: dict[str, str], - elem_embedding: str = "matscholar200", - sym_emb: str = "bra-alg-off", inputs: str = "protostructure", identifiers: Sequence[str] = ("material_id", "composition", "protostructure"), ): @@ -37,11 +48,6 @@ def __init__( df (pd.DataFrame): Pandas dataframe holding input and target values. task_dict (dict[str, "regression" | "classification"]): Map from target names to task type for multi-task learning. - elem_embedding (str, optional): One of "matscholar200", "cgcnn92", - "megnet16", "onehot112" or path to a file with custom element - embeddings. Defaults to "matscholar200". - sym_emb (str): Symmetry embedding. One of "bra-alg-off" (default) or - "spg-alg-off" or path to a file with custom symmetry embeddings. inputs (str, optional): df columns to be used for featurization. Defaults to "protostructure". identifiers (list, optional): df columns for distinguishing data points. @@ -56,24 +62,6 @@ def __init__( self.identifiers = list(identifiers) self.df = df - if elem_embedding in ("matscholar200", "cgcnn92", "megnet16", "onehot112"): - elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json" - - with open(elem_embedding) as emb_file: - self.elem_features = json.load(emb_file) - - self.elem_emb_len = len(next(iter(self.elem_features.values()))) - - if sym_emb in ("bra-alg-off", "spg-alg-off"): - sym_emb = f"{PKG_DIR}/embeddings/wyckoff/{sym_emb}.json" - - with open(sym_emb) as sym_file: - self.sym_features = json.load(sym_file) - - self.sym_emb_len = len( - next(iter(next(iter(self.sym_features.values())).values())) - ) - self.n_targets = [] for target, task in self.task_dict.items(): if task == "regression": @@ -113,23 +101,13 @@ def __getitem__(self, idx: int): wyk_site_multiplcities ) - try: - element_features = np.vstack([self.elem_features[el] for el in elements]) - except AssertionError: - print(f"Failed to process elements for {material_ids}") - raise - - try: - symmetry_features = np.vstack( - [ - self.sym_features[spg_num][wyk_site] - for wyckoff_sites in augmented_wyks - for wyk_site in wyckoff_sites - ] - ) - except AssertionError: - print(f"Failed to process Wyckoff positions for {material_ids}") - raise + element_features = [Element(el).Z for el in elements] + + symmetry_features = [ + WYCKOFF_SPG_LETTER_MAP[spg_num][wyk_site] + for wyckoff_sites in augmented_wyks + for wyk_site in wyckoff_sites + ] n_wyks = len(elements) self_idx = [] @@ -147,8 +125,8 @@ def __getitem__(self, idx: int): # convert all data to tensors wyckoff_weights = Tensor(wyk_site_multiplcities) - element_features = Tensor(element_features) - symmetry_features = Tensor(symmetry_features) + element_features = LongTensor(element_features) + symmetry_features = LongTensor(symmetry_features) self_idx = LongTensor(self_aug_fea_idx) nbr_idx = LongTensor(nbr_aug_fea_idx) @@ -198,7 +176,7 @@ 
def collate_batch(
         # batch the features together
         batch_mult_weights.append(mult_weights.repeat((n_aug, 1)))
-        batch_elem_fea.append(elem_fea.repeat((n_aug, 1)))
+        batch_elem_fea.append(elem_fea.repeat(n_aug))
         batch_sym_fea.append(sym_fea)
 
         # mappings from bonds to atoms
diff --git a/aviary/wren/model.py b/aviary/wren/model.py
index 6e6da581..aba96c0c 100644
--- a/aviary/wren/model.py
+++ b/aviary/wren/model.py
@@ -1,10 +1,13 @@
+import json
 from collections.abc import Sequence
 
 import torch
 import torch.nn.functional as F
+from pymatgen.core import Element
 from pymatgen.util.due import Doi, due
 from torch import LongTensor, Tensor, nn
 
+from aviary import PKG_DIR
 from aviary.core import BaseModelClass
 from aviary.networks import ResidualNetwork, SimpleNetwork
 from aviary.scatter import scatter_reduce
@@ -26,8 +29,8 @@ def __init__(
         self,
         robust: bool,
         n_targets: Sequence[int],
-        elem_emb_len: int,
-        sym_emb_len: int,
+        elem_embedding: str = "matscholar200",
+        sym_embedding: str = "bra-alg-off",
         elem_fea_len: int = 32,
         sym_fea_len: int = 32,
         n_graph: int = 3,
@@ -44,6 +47,42 @@ def __init__(
         """Protostructure based model."""
         super().__init__(robust=robust, **kwargs)
 
+        # load the element embedding
+        if elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
+            elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json"
+
+        with open(elem_embedding) as file:
+            self.elem_features = json.load(file)
+
+        max_z = max(Element(elem).Z for elem in self.elem_features)
+        elem_emb_len = len(next(iter(self.elem_features.values())))
+        elem_feature_matrix = torch.zeros((max_z + 1, elem_emb_len))
+        for elem, feature in self.elem_features.items():
+            elem_feature_matrix[Element(elem).Z] = torch.tensor(feature)
+
+        self.elem_embedding = nn.Embedding(max_z + 1, elem_emb_len)
+        self.elem_embedding.weight.data.copy_(elem_feature_matrix)
+
+        # load the Wyckoff embedding
+        if sym_embedding in ("bra-alg-off", "spg-alg-off"):
+            sym_embedding = f"{PKG_DIR}/embeddings/wyckoff/{sym_embedding}.json"
+
+        with open(sym_embedding) as sym_file:
+            self.sym_features = json.load(sym_file)
+
+        sym_emb_len = len(next(iter(next(iter(self.sym_features.values())).values())))
+
+        len_sym_features = sum(len(feature) for feature in self.sym_features.values())
+        sym_feature_matrix = torch.zeros((len_sym_features, sym_emb_len))
+        sym_idx = 0
+        for embeddings in self.sym_features.values():
+            for feature in embeddings.values():
+                sym_feature_matrix[sym_idx] = torch.tensor(feature)
+                sym_idx += 1
+
+        self.sym_embedding = nn.Embedding(len_sym_features, sym_emb_len)
+        self.sym_embedding.weight.data.copy_(sym_feature_matrix)
+
         desc_dict = {
             "elem_emb_len": elem_emb_len,
             "elem_fea_len": elem_fea_len,
@@ -62,6 +101,8 @@ def __init__(
 
         model_params = {
             "robust": robust,
+            "elem_embedding": elem_embedding,
+            "sym_embedding": sym_embedding,
             "n_targets": n_targets,
             "out_hidden": out_hidden,
             "trunk_hidden": trunk_hidden,
@@ -92,6 +133,8 @@ def forward(
         aug_cry_idx: LongTensor,
     ) -> tuple[Tensor, ...]:
         """Forward pass through the material_nn and output_nn."""
+        elem_fea = self.elem_embedding(elem_fea)
+        sym_fea = self.sym_embedding(sym_fea)
         crys_fea = self.material_nn(
             elem_weights,
             elem_fea,
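One ordering contract worth flagging for review: the flat `(spacegroup, Wyckoff letter) -> index` map built at import time in `wren/data.py` and the `sym_feature_matrix` rows built here must be enumerated in the same order. Both iterate the insertion order of the same JSON dict, so they line up, but `data.py` hardcodes `bra-alg-off.json` while the model still accepts `sym_embedding="spg-alg-off"`, which would silently misalign indices and rows. A minimal sketch of the shared ordering, with a made-up `toy_sym_features` standing in for `aviary/embeddings/wyckoff/bra-alg-off.json`:

```python
# Sketch only -- `toy_sym_features` is an illustrative stand-in for the
# real {spacegroup: {wyckoff_letter: feature_vector}} JSON table.
from collections import defaultdict

import torch
from torch import nn

toy_sym_features = {
    "1": {"a": [1.0, 0.0]},
    "2": {"a": [0.0, 1.0], "b": [1.0, 1.0]},
}

# data.py side: one flat index per (spacegroup, letter), in dict insertion order
letter_to_idx: dict[str, dict[str, int]] = defaultdict(dict)
rows = []
for spg_num, embeddings in toy_sym_features.items():
    for wyckoff_letter, feature in embeddings.items():
        letter_to_idx[spg_num][wyckoff_letter] = len(rows)
        rows.append(feature)

# model.py side: embedding rows copied in the SAME iteration order
sym_embedding = nn.Embedding(len(rows), len(rows[0]))
sym_embedding.weight.data.copy_(torch.tensor(rows))

# a Wyckoff site "b" in spacegroup "2" resolves to the right feature vector
idx = torch.tensor([letter_to_idx["2"]["b"]])
assert torch.allclose(sym_embedding(idx)[0], torch.tensor([1.0, 1.0]))
```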
diff --git a/examples/notebooks/Wren.ipynb b/examples/notebooks/Wren.ipynb
index 3a41df2c..72364bb3 100644
--- a/examples/notebooks/Wren.ipynb
+++ b/examples/notebooks/Wren.ipynb
@@ -56,7 +56,62 @@
    "execution_count": null,
    "id": "2",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "spglib: ssm_get_exact_positions failed.\n",
+      "spglib: get_bravais_exact_positions_and_lattice failed.\n",
+      "[... the same pair of spglib warnings repeats ~25 times in total; remainder elided ...]\n"
+     ]
+    }
+   ],
    "source": [
     "with gzip.open(\"taata.json.gz\", \"r\") as fin:\n",
     "    json_bytes = fin.read()\n",
@@ -75,7 +130,12 @@
     "df = df[df.protostructure.map(count_wyckoff_positions) < 16]\n",
     "df[\"n_sites\"] = df.final_structure.map(len)\n",
     "df = df[df.n_sites < 64]\n",
-    "df = df[df.volume_per_atom < 500]"
+    "df = df[df.volume_per_atom < 500]\n",
+    "\n",
+    "# NOTE for wren we keep only the lowest-energy structure for each protostructure\n",
+    "df = df.sort_values([\"protostructure\", \"E_vasp_per_atom\"]).drop_duplicates(\n",
+    "    \"protostructure\", keep=\"first\"\n",
+    ")"
   ]
  },
 {
@@ -144,23 +204,310 @@
    "execution_count": null,
    "id": "4",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "using 0.2 of training set as test set\n",
+      "No validation set used, using test set for evaluation purposes\n",
+      "Total Number of Trainable Parameters: 1,154,330\n",
+      "Dummy MAE: 0.9670\n",
+      "Epoch: [0/99]\n",
+      "    train: E_vasp_per_atom  N 57  MAE 0.95  Loss 1.11  RMSE 1.19\n",
+      "  evaluate: E_vasp_per_atom  N 1  MAE 0.84  Loss 0.97  RMSE 1.07\n",
+      "[... epochs 1-92 elided: train MAE falls from 0.54 to ~0.14; the committed output itself breaks off at the epoch 93 header ...]\n",
+      "Epoch: [93/99]\n"
+     ]
+    }
+   ],
    "source": [
     "torch.manual_seed(0)  # ensure reproducible results\n",
     "\n",
     "elem_embedding = \"matscholar200\"\n",
-    "sym_emb = \"bra-alg-off\"\n",
+    "sym_embedding = \"bra-alg-off\"\n",
     "model_name = \"wren-reg-test\"\n",
     "\n",
     "data_params[\"collate_fn\"] = wren_cb\n",
     "data_params[\"shuffle\"] = True\n",
     "\n",
-    "dataset = WyckoffData(\n",
-    "    df=df, elem_embedding=elem_embedding, sym_emb=sym_emb, task_dict=task_dict\n",
-    ")\n",
+    "dataset = WyckoffData(df=df, task_dict=task_dict)\n",
     "n_targets = dataset.n_targets\n",
-    "elem_emb_len = dataset.elem_emb_len\n",
-    "sym_emb_len = dataset.sym_emb_len\n",
     "\n",
     "train_idx = list(range(len(dataset)))\n",
     "\n",
@@ -187,9 +534,9 @@
     "    \"task_dict\": task_dict,\n",
     "    \"robust\": robust,\n",
     "    \"n_targets\": n_targets,\n",
-    "    \"elem_emb_len\": elem_emb_len,\n",
+    "    \"elem_embedding\": elem_embedding,\n",
     "    \"elem_fea_len\": 32,\n",
-    "    \"sym_emb_len\": sym_emb_len,\n",
+    "    \"sym_embedding\": sym_embedding,\n",
     "    \"sym_fea_len\": 32,\n",
     "    \"n_graph\": 3,\n",
     "    \"elem_heads\": 1,\n",
@@ -215,8 +562,40 @@
     "    restart_params=restart_params,\n",
     "    model_params=model_params,\n",
     "    loss_dict=loss_dict,\n",
-    ")\n",
-    "\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
+      "------------Evaluate model on Test Set------------\n",
+      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
+      "\n",
+      "Evaluating Model\n"
+     ]
+    },
+    {
+     "ename": "ValueError",
+     "evalue": "task_dict {'last phdos peak': 'regression'} of checkpoint resume='/Users/radical-rhys/Radical/aviary/models/wren-reg-test/checkpoint-r1.pth.tar' does not match provided task_dict={'E_vasp_per_atom': 'regression'}",
+     "output_type": "error",
+     "traceback": [
+      "---------------------------------------------------------------------------",
+      "ValueError                                Traceback (most recent call last)",
\u001b[49m\u001b[43mmodel_class\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mWren\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mrun_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrun_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43mensemble_folds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mensemble\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43mtest_loader\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_loader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43mrobust\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrobust\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[43m \u001b[49m\u001b[43mtask_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtask_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 14\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 15\u001b[0m \u001b[43m \u001b[49m\u001b[43meval_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcheckpoint\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[43m \u001b[49m\u001b[43msave_results\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Radical/aviary/aviary/utils.py:490\u001b[0m, in \u001b[0;36mresults_multitask\u001b[0;34m(model_class, model_name, run_id, ensemble_folds, test_loader, robust, task_dict, device, eval_type, print_results, save_results)\u001b[0m\n\u001b[1;32m 488\u001b[0m chkpt_task_dict \u001b[38;5;241m=\u001b[39m checkpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_params\u001b[39m\u001b[38;5;124m\"\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtask_dict\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 489\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m chkpt_task_dict \u001b[38;5;241m!=\u001b[39m task_dict:\n\u001b[0;32m--> 490\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 491\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtask_dict 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mchkpt_task_dict\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m of checkpoint \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresume\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m does not match \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 492\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprovided \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtask_dict\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 493\u001b[0m )\n\u001b[1;32m 495\u001b[0m model: BaseModelClass \u001b[38;5;241m=\u001b[39m model_class(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mcheckpoint[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel_params\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 496\u001b[0m model\u001b[38;5;241m.\u001b[39mto(device)\n", + "\u001b[0;31mValueError\u001b[0m: task_dict {'last phdos peak': 'regression'} of checkpoint resume='/Users/radical-rhys/Radical/aviary/models/wren-reg-test/checkpoint-r1.pth.tar' does not match provided task_dict={'E_vasp_per_atom': 'regression'}" + ] + } + ], + "source": [ "test_loader = DataLoader(\n", " test_set,\n", " **{**data_params, \"batch_size\": 64 * data_params[\"batch_size\"], \"shuffle\": False},\n", @@ -253,7 +632,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.12.9" }, "vscode": { "interpreter": { diff --git a/examples/wren-example.py b/examples/wren-example.py index 73031596..de4ceba0 100644 --- a/examples/wren-example.py +++ b/examples/wren-example.py @@ -18,7 +18,7 @@ def main( losses, robust, elem_embedding="matscholar200", - sym_emb="bra-alg-off", + sym_embedding="bra-alg-off", model_name="wren", sym_fea_len=32, elem_fea_len=32, @@ -90,12 +90,8 @@ def main( # NOTE do not use default_na as "NaN" is a valid material composition df = pd.read_csv(data_path, keep_default_na=False, na_values=[]) - dataset = WyckoffData( - df=df, elem_embedding=elem_embedding, sym_emb=sym_emb, task_dict=task_dict - ) + dataset = WyckoffData(df=df, task_dict=task_dict) n_targets = dataset.n_targets - elem_emb_len = dataset.elem_emb_len - sym_emb_len = dataset.sym_emb_len train_idx = list(range(len(dataset))) @@ -108,8 +104,6 @@ def main( print(f"using independent test set: {test_path}") test_set = WyckoffData( df=df, - elem_embedding=elem_embedding, - sym_emb=sym_emb, task_dict=task_dict, ) test_set = torch.utils.data.Subset(test_set, range(len(test_set))) @@ -131,8 +125,6 @@ def main( print(f"using independent validation set: {val_path}") val_set = WyckoffData( df=df, - elem_embedding=elem_embedding, - sym_emb=sym_emb, task_dict=task_dict, ) val_set = torch.utils.data.Subset(val_set, range(len(val_set))) @@ -184,8 +176,8 @@ def main( "task_dict": task_dict, "robust": robust, "n_targets": n_targets, - "elem_emb_len": elem_emb_len, - "sym_emb_len": sym_emb_len, + "elem_embedding": elem_embedding, + "sym_embedding": sym_embedding, "elem_fea_len": elem_fea_len, "sym_fea_len": sym_fea_len, "n_graph": n_graph, diff --git a/tests/test_wren.py b/tests/test_wren.py index bd391391..dffd715b 100644 --- a/tests/test_wren.py +++ b/tests/test_wren.py @@ -13,7 +13,7 @@ def base_config(): return { "elem_embedding": "matscholar200", - "sym_emb": "bra-alg-off", + "sym_embedding": "bra-alg-off", "robust": True, "ensemble": 2, "run_id": 1, @@ -71,13 +71,9 @@ def test_wren_regression( dataset = WyckoffData( df=df_matbench_phonons_wyckoff, - elem_embedding=base_config["elem_embedding"], - 
sym_emb=base_config["sym_emb"], task_dict=task_dict, ) n_targets = dataset.n_targets - elem_emb_len = dataset.elem_emb_len - sym_emb_len = dataset.sym_emb_len train_idx = list(range(len(dataset))) train_idx, test_idx = split( @@ -122,8 +118,8 @@ def test_wren_regression( "task_dict": task_dict, "robust": base_config["robust"], "n_targets": n_targets, - "elem_emb_len": elem_emb_len, - "sym_emb_len": sym_emb_len, + "elem_embedding": base_config["elem_embedding"], + "sym_embedding": base_config["sym_embedding"], **model_architecture, } @@ -186,13 +182,9 @@ def test_wren_clf( dataset = WyckoffData( df=df_matbench_phonons_wyckoff, - elem_embedding=base_config["elem_embedding"], - sym_emb=base_config["sym_emb"], task_dict=task_dict, ) n_targets = dataset.n_targets - elem_emb_len = dataset.elem_emb_len - sym_emb_len = dataset.sym_emb_len train_idx = list(range(len(dataset))) train_idx, test_idx = split( @@ -237,8 +229,8 @@ def test_wren_clf( "task_dict": task_dict, "robust": base_config["robust"], "n_targets": n_targets, - "elem_emb_len": elem_emb_len, - "sym_emb_len": sym_emb_len, + "elem_embedding": base_config["elem_embedding"], + "sym_embedding": base_config["sym_embedding"], **model_architecture, } From c26b7bba633160755e01932bc319b79f564c87d6 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 13 Apr 2025 10:51:40 -0400 Subject: [PATCH 03/12] clean: refactor into utils functions --- aviary/roost/model.py | 21 ++------------ aviary/utils.py | 64 +++++++++++++++++++++++++++++++++++++++++-- aviary/wren/model.py | 42 ++++------------------------ 3 files changed, 70 insertions(+), 57 deletions(-) diff --git a/aviary/roost/model.py b/aviary/roost/model.py index b1c3cbcb..3b84f96e 100644 --- a/aviary/roost/model.py +++ b/aviary/roost/model.py @@ -1,16 +1,14 @@ -import json from collections.abc import Sequence import torch import torch.nn.functional as F -from pymatgen.core import Element from pymatgen.util.due import Doi, due from torch import LongTensor, Tensor, nn -from aviary import PKG_DIR from aviary.core import BaseModelClass from aviary.networks import ResidualNetwork, SimpleNetwork from aviary.segments import MessageLayer, WeightedAttentionPooling +from aviary.utils import get_element_embedding @due.dcite(Doi("10.1038/s41467-020-19964-7"), description="Roost model") @@ -44,21 +42,8 @@ def __init__( """Composition-only model.""" super().__init__(robust=robust, **kwargs) - if elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]: - elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json" - - with open(elem_embedding) as file: - self.elem_features = json.load(file) - - max_z = max(Element(elem).Z for elem in self.elem_features) - elem_emb_len = len(next(iter(self.elem_features.values()))) - elem_feature_matrix = torch.zeros((max_z + 1, elem_emb_len)) - for elem, feature in self.elem_features.items(): - elem_feature_matrix[Element(elem).Z] = torch.tensor(feature) - - self.elem_embedding = nn.Embedding(max_z + 1, elem_emb_len) - self.elem_embedding.weight.data.copy_(elem_feature_matrix) - + self.elem_embedding = get_element_embedding(elem_embedding) + elem_emb_len = self.elem_embedding.weight.shape[1] desc_dict = { "elem_emb_len": elem_emb_len, "elem_fea_len": elem_fea_len, diff --git a/aviary/utils.py b/aviary/utils.py index 1b924092..45d794ca 100644 --- a/aviary/utils.py +++ b/aviary/utils.py @@ -1,3 +1,4 @@ +import json import os import sys import time @@ -13,6 +14,7 @@ import pandas as pd import torch import wandb +from pymatgen.core import Element from 
sklearn.metrics import (
     accuracy_score,
     balanced_accuracy_score,
@@ -22,13 +24,13 @@
     roc_auc_score,
 )
 from torch import LongTensor, Tensor
-from torch.nn import CrossEntropyLoss, L1Loss, MSELoss, NLLLoss
+from torch.nn import CrossEntropyLoss, Embedding, L1Loss, MSELoss, NLLLoss
 from torch.optim import SGD, Adam, AdamW, Optimizer
 from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler
 from torch.utils.data import DataLoader, Subset
 from torch.utils.tensorboard import SummaryWriter

-from aviary import ROOT
+from aviary import PKG_DIR, ROOT
 from aviary.core import BaseModelClass, Normalizer, TaskType, sampled_softmax
 from aviary.data import InMemoryDataLoader
 from aviary.losses import robust_l1_loss, robust_l2_loss
@@ -799,6 +801,64 @@ def get_metrics(
     return {key: round(float(val), prec) for key, val in metrics.items()}


+def get_element_embedding(elem_embedding: str) -> Embedding:
+    """Get an element embedding from a preset name or a JSON file.
+
+    Args:
+        elem_embedding (str): One of matscholar200, cgcnn92, megnet16, onehot112
+            or the path to a JSON file with custom element embeddings.
+
+    Returns:
+        Embedding: The element embedding.
+    """
+    if elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
+        elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json"
+
+    with open(elem_embedding) as file:
+        elem_features = json.load(file)
+
+    max_z = max(Element(elem).Z for elem in elem_features)
+    elem_emb_len = len(next(iter(elem_features.values())))
+    elem_feature_matrix = torch.zeros((max_z + 1, elem_emb_len))
+    for elem, feature in elem_features.items():
+        elem_feature_matrix[Element(elem).Z] = torch.tensor(feature)
+
+    embedding = Embedding(max_z + 1, elem_emb_len)
+    embedding.weight.data.copy_(elem_feature_matrix)
+
+    return embedding
+
+
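A minimal usage sketch for the helper above (illustrative only, outside the diff; it assumes just the bundled matscholar200 table plus pymatgen):

import torch
from pymatgen.core import Element

from aviary.utils import get_element_embedding

# Rows of the returned nn.Embedding are indexed by atomic number Z,
# so a composition can be featurized straight from Z values.
embedding = get_element_embedding("matscholar200")
z_values = torch.tensor([Element("Na").Z, Element("Cl").Z])  # [11, 17]
features = embedding(z_values)
print(features.shape)  # torch.Size([2, 200]); the table stays trainable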
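The get_sym_embedding helper defined next follows the same pattern but must first flatten a nested {spacegroup: {wyckoff letter: vector}} mapping into a single weight matrix, one row per (spacegroup, letter) pair. A toy sketch of that flattening with made-up two-dimensional features (illustrative only, outside the diff):

import torch
from torch.nn import Embedding

sym_features = {"1": {"a": [1.0, 0.0]}, "2": {"a": [0.0, 1.0], "b": [1.0, 1.0]}}

# Flatten in insertion order: spacegroups first, then Wyckoff letters.
rows = [vec for wyckoffs in sym_features.values() for vec in wyckoffs.values()]
embedding = Embedding.from_pretrained(torch.tensor(rows), freeze=False)
print(embedding(torch.tensor([2])))  # tensor([[1., 1.]]) -> spacegroup "2", letter "b"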
+ """ + if sym_embedding in ("bra-alg-off", "spg-alg-off"): + sym_embedding = f"{PKG_DIR}/embeddings/wyckoff/{sym_embedding}.json" + + with open(sym_embedding) as sym_file: + sym_features = json.load(sym_file) + + sym_emb_len = len(next(iter(next(iter(sym_features.values())).values()))) + + len_sym_features = sum(len(feature) for feature in sym_features.values()) + sym_feature_matrix = torch.zeros((len_sym_features, sym_emb_len)) + sym_idx = 0 + for embeddings in sym_features.values(): + for feature in embeddings.values(): + sym_feature_matrix[sym_idx] = torch.tensor(feature) + sym_idx += 1 + + embedding = Embedding(len_sym_features, sym_emb_len) + embedding.weight.data.copy_(sym_feature_matrix) + + return embedding + + def as_dict_handler(obj: Any) -> dict[str, Any] | None: """Pass this func as json.dump(handler=) or as pandas.to_json(default_handler=).""" try: diff --git a/aviary/wren/model.py b/aviary/wren/model.py index aba96c0c..f89df0c6 100644 --- a/aviary/wren/model.py +++ b/aviary/wren/model.py @@ -1,17 +1,15 @@ -import json from collections.abc import Sequence import torch import torch.nn.functional as F -from pymatgen.core import Element from pymatgen.util.due import Doi, due from torch import LongTensor, Tensor, nn -from aviary import PKG_DIR from aviary.core import BaseModelClass from aviary.networks import ResidualNetwork, SimpleNetwork from aviary.scatter import scatter_reduce from aviary.segments import MessageLayer, WeightedAttentionPooling +from aviary.utils import get_element_embedding, get_sym_embedding @due.dcite(Doi("10.1126/sciadv.abn4117"), description="Wren model") @@ -47,41 +45,11 @@ def __init__( """Protostructure based model.""" super().__init__(robust=robust, **kwargs) - # load the element embedding - if elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]: - elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json" + self.elem_embedding = get_element_embedding(elem_embedding) + elem_emb_len = self.elem_embedding.weight.shape[1] - with open(elem_embedding) as file: - self.elem_features = json.load(file) - - max_z = max(Element(elem).Z for elem in self.elem_features) - elem_emb_len = len(next(iter(self.elem_features.values()))) - elem_feature_matrix = torch.zeros((max_z + 1, elem_emb_len)) - for elem, feature in self.elem_features.items(): - elem_feature_matrix[Element(elem).Z] = torch.tensor(feature) - - self.elem_embedding = nn.Embedding(max_z + 1, elem_emb_len) - self.elem_embedding.weight.data.copy_(elem_feature_matrix) - - # load the Wyckoff embedding - if sym_embedding in ("bra-alg-off", "spg-alg-off"): - sym_embedding = f"{PKG_DIR}/embeddings/wyckoff/{sym_embedding}.json" - - with open(sym_embedding) as sym_file: - self.sym_features = json.load(sym_file) - - sym_emb_len = len(next(iter(next(iter(self.sym_features.values())).values()))) - - len_sym_features = sum(len(feature) for feature in self.sym_features.values()) - sym_feature_matrix = torch.zeros((len_sym_features, sym_emb_len)) - sym_idx = 0 - for embeddings in self.sym_features.values(): - for feature in embeddings.values(): - sym_feature_matrix[sym_idx] = torch.tensor(feature) - sym_idx += 1 - - self.sym_embedding = nn.Embedding(len_sym_features, sym_emb_len) - self.sym_embedding.weight.data.copy_(sym_feature_matrix) + self.sym_embedding = get_sym_embedding(sym_embedding) + sym_emb_len = self.sym_embedding.weight.shape[1] desc_dict = { "elem_emb_len": elem_emb_len, From 9a3f34224aa4bf65aab0a81258dbbc2b05a1a15b Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 13 
Apr 2025 11:11:04 -0400 Subject: [PATCH 04/12] test: add basic tests to the embedding functions --- aviary/utils.py | 14 ++++- tests/test_utils.py | 139 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 2 deletions(-) create mode 100644 tests/test_utils.py diff --git a/aviary/utils.py b/aviary/utils.py index 45d794ca..e1a096b9 100644 --- a/aviary/utils.py +++ b/aviary/utils.py @@ -797,6 +797,8 @@ def get_metrics( metrics["F1"] = f1_score(targets, pred_labels) class1_probas = predictions[:, 1] metrics["ROCAUC"] = roc_auc_score(targets, class1_probas) + else: + raise ValueError(f"Invalid task type: {type}") return {key: round(float(val), prec) for key, val in metrics.items()} @@ -810,8 +812,12 @@ def get_element_embedding(elem_embedding: str) -> Embedding: Returns: Embedding: The element embedding. """ - if elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]: + if os.path.isfile(elem_embedding): + pass + elif elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]: elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json" + else: + raise ValueError(f"Invalid element embedding: {elem_embedding}") with open(elem_embedding) as file: elem_features = json.load(file) @@ -837,8 +843,12 @@ def get_sym_embedding(sym_embedding: str) -> Embedding: Returns: Embedding: The symmetry embedding. """ - if sym_embedding in ("bra-alg-off", "spg-alg-off"): + if os.path.isfile(sym_embedding): + pass + elif sym_embedding in ("bra-alg-off", "spg-alg-off"): sym_embedding = f"{PKG_DIR}/embeddings/wyckoff/{sym_embedding}.json" + else: + raise ValueError(f"Invalid symmetry embedding: {sym_embedding}") with open(sym_embedding) as sym_file: sym_features = json.load(sym_file) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..1a7623ea --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,139 @@ +import json + +import numpy as np +import pandas as pd +import pytest +import torch + +from aviary.utils import get_element_embedding, get_metrics, get_sym_embedding + + +@pytest.fixture +def temp_element_embedding(tmp_path): + embedding_data = { + "H": [1.0, 2.0], + "He": [3.0, 4.0], + "Li": [5.0, 6.0], + } + path = tmp_path / "test_elem_embedding.json" + with open(path, "w") as f: + json.dump(embedding_data, f) + return str(path) + + +@pytest.fixture +def temp_sym_embedding(tmp_path): + embedding_data = { + "1": {"a": [1.0, 2.0], "b": [3.0, 4.0]}, + "2": {"c": [5.0, 6.0]}, + } + path = tmp_path / "test_sym_embedding.json" + with open(path, "w") as f: + json.dump(embedding_data, f) + return str(path) + + +def test_get_element_embedding_custom(temp_element_embedding): + embedding = get_element_embedding(temp_element_embedding) + assert isinstance(embedding, torch.nn.Embedding) + assert embedding.weight.shape == (3 + 1, 2) # max_Z + 1, embedding_dim + assert torch.allclose(embedding.weight[1], torch.tensor([1.0, 2.0])) # H + assert torch.allclose(embedding.weight[2], torch.tensor([3.0, 4.0])) # He + + +def test_get_element_embedding_builtin(): + embedding = get_element_embedding("matscholar200") + assert isinstance(embedding, torch.nn.Embedding) + assert embedding.weight.shape[1] == 200 + + +def test_get_element_embedding_invalid(): + with pytest.raises(ValueError, match="Invalid element embedding: invalid_embedding"): + get_element_embedding("invalid_embedding") + + +def test_get_sym_embedding_custom(temp_sym_embedding): + embedding = get_sym_embedding(temp_sym_embedding) + assert isinstance(embedding, 
torch.nn.Embedding) + assert embedding.weight.shape == (3, 2) # total features, embedding_dim + assert torch.allclose(embedding.weight[0], torch.tensor([1.0, 2.0])) + assert torch.allclose(embedding.weight[1], torch.tensor([3.0, 4.0])) + + +def test_get_sym_embedding_builtin(): + embedding = get_sym_embedding("bra-alg-off") + assert isinstance(embedding, torch.nn.Embedding) + assert isinstance(embedding.weight, torch.Tensor) + + +def test_get_sym_embedding_invalid(): + with pytest.raises(ValueError, match="Invalid symmetry embedding: invalid_embedding"): + get_sym_embedding("invalid_embedding") + + +def test_regression_metrics(): + targets = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + predictions = np.array([1.1, 2.1, 3.1, 4.1, 5.1]) + + metrics = get_metrics(targets, predictions, "regression") + + assert set(metrics.keys()) == {"MAE", "RMSE", "R2"} + assert metrics["MAE"] == pytest.approx(0.1, abs=1e-4) + assert metrics["RMSE"] == pytest.approx(0.1, abs=1e-4) + assert metrics["R2"] == pytest.approx(0.995, abs=1e-4) + + +def test_classification_metrics(): + targets = np.array([0, 1, 0, 1, 0]) + # Probabilities for class 0 and 1 + predictions = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]]) + + metrics = get_metrics(targets, predictions, "classification") + + assert set(metrics.keys()) == {"accuracy", "balanced_accuracy", "F1", "ROCAUC"} + assert metrics["accuracy"] == 1.0 + assert metrics["balanced_accuracy"] == 1.0 + assert metrics["F1"] == 1.0 + assert metrics["ROCAUC"] == 1.0 + + +def test_nan_handling(): + targets = np.array([1.0, np.nan, 3.0, 4.0]) + predictions = np.array([1.1, 2.1, np.nan, 4.1]) + + metrics = get_metrics(targets, predictions, "regression") + assert not np.isnan(metrics["MAE"]) + assert not np.isnan(metrics["RMSE"]) + assert not np.isnan(metrics["R2"]) + + +def test_pandas_input(): + targets = pd.Series([1.0, 2.0, 3.0]) + predictions = pd.Series([1.1, 2.1, 3.1]) + + metrics = get_metrics(targets, predictions, "regression") + assert set(metrics.keys()) == {"MAE", "RMSE", "R2"} + + +def test_precision(): + targets = np.array([1.0, 2.0, 3.0]) + predictions = np.array([1.12345, 2.12345, 3.12345]) + + metrics = get_metrics(targets, predictions, "regression", prec=2) + assert all(len(str(v).split(".")[-1]) <= 2 for v in metrics.values()) + + +def test_invalid_type(): + targets = np.array([1.0, 2.0]) + predictions = np.array([1.1, 2.1]) + + with pytest.raises(ValueError, match="Invalid task type: invalid_type"): + get_metrics(targets, predictions, "invalid_type") + + +def test_mismatched_shapes(): + targets = np.array([0, 1, 0]) + predictions = np.array([[0.9, 0.1], [0.1, 0.9]]) # Wrong shape + + with pytest.raises(ValueError): # noqa: PT011 + get_metrics(targets, predictions, "classification") From 4cc130d71f1859033bdca9ce7da50069918844fa Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 13 Apr 2025 11:30:17 -0400 Subject: [PATCH 05/12] fea: swap cgcnn to do the embedding inside the forward pass --- aviary/cgcnn/data.py | 59 ++++++++------------------------------- aviary/cgcnn/model.py | 30 +++++++++++++++++--- examples/cgcnn-example.py | 30 ++++++++++---------- tests/test_cgcnn.py | 54 +++++++++++++++++------------------ 4 files changed, 79 insertions(+), 94 deletions(-) diff --git a/aviary/cgcnn/data.py b/aviary/cgcnn/data.py index d426e260..7d731e2e 100644 --- a/aviary/cgcnn/data.py +++ b/aviary/cgcnn/data.py @@ -1,5 +1,4 @@ import itertools -import json from collections.abc import Sequence from functools import cache from typing import Any @@ 
-12,8 +11,6 @@ from torch.utils.data import Dataset from tqdm import tqdm -from aviary import PKG_DIR - class CrystalGraphData(Dataset): """Dataset class for the CGCNN structure model.""" @@ -22,13 +19,10 @@ def __init__( self, df: pd.DataFrame, task_dict: dict[str, str], - elem_embedding: str = "cgcnn92", structure_col: str = "structure", identifiers: Sequence[str] = (), - radius: float = 5, + radius_cutoff: float = 5, max_num_nbr: int = 12, - dmin: float = 0, - step: float = 0.2, ): """Featurize crystal structures into neighborhood graphs with this data class for CGCNN. @@ -36,44 +30,21 @@ def __init__( Args: df (pd.Dataframe): Pandas dataframe holding input and target values. task_dict ({target: task}): task dict for multi-task learning - elem_embedding (str, optional): One of matscholar200, cgcnn92, megnet16, - onehot112 or path to a file with custom element embeddings. - Defaults to matscholar200. structure_col (str, optional): df column holding pymatgen Structure objects as input. identifiers (list[str], optional): df columns for distinguishing data points. Will be copied over into the model's output CSV. Defaults to (). - radius (float, optional): Cut-off radius for neighborhood. Defaults to 5. + radius_cutoff (float, optional): Cut-off radius for neighborhood. + Defaults to 5. max_num_nbr (int, optional): maximum number of neighbors to consider. Defaults to 12. - dmin (float, optional): minimum distance in Gaussian basis. Defaults to 0. - step (float, optional): increment size of Gaussian basis. Defaults to 0.2. """ self.task_dict = task_dict self.identifiers = list(identifiers) - self.radius = radius + self.radius_cutoff = radius_cutoff self.max_num_nbr = max_num_nbr - if elem_embedding in ("matscholar200", "cgcnn92", "megnet16", "onehot112"): - elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json" - - with open(elem_embedding) as file: - self.elem_features = json.load(file) - - for key, value in self.elem_features.items(): - self.elem_features[key] = np.array(value, dtype=float) - if not hasattr(self, "elem_emb_len"): - self.elem_emb_len = len(value) - elif self.elem_emb_len != len(value): - raise ValueError( - f"Element embedding length mismatch: len({key})=" - f"{len(value)}, expected {self.elem_emb_len}" - ) - - self.gaussian_dist_func = GaussianDistance(dmin=dmin, dmax=radius, step=step) - self.nbr_fea_dim = self.gaussian_dist_func.embedding_size - self.df = df self.structure_col = structure_col @@ -84,7 +55,7 @@ def __init__( self.df[structure_col].items(), total=len(df), desc=desc, disable=None ): self_idx, nbr_idx, _ = get_structure_neighbor_info( - struct, radius, max_num_nbr + struct, self.radius_cutoff, self.max_num_nbr ) material_ids = [idx, *self.df.loc[idx][self.identifiers]] if 0 in (len(self_idx), len(nbr_idx)): @@ -140,16 +111,10 @@ def __getitem__(self, idx: int): material_ids = [self.df.index[idx], *row[self.identifiers]] # atom features for disordered sites - site_atoms = [atom.species.as_dict() for atom in struct] - atom_features = np.vstack( - [ - np.sum([self.elem_features[el] * amt for el, amt in site.items()], axis=0) - for site in site_atoms - ] - ) + atom_features = [atom.specie.Z for atom in struct] self_idx, nbr_idx, nbr_dist = get_structure_neighbor_info( - struct, self.radius, self.max_num_nbr + struct, self.radius_cutoff, self.max_num_nbr ) if len(self_idx) == 0: @@ -161,9 +126,7 @@ def __getitem__(self, idx: int): if set(self_idx) != set(range(len(struct))): raise ValueError(f"At least one atom in {material_ids} is isolated") - 
nbr_dist = self.gaussian_dist_func.expand(nbr_dist)
-
-        atom_fea_t = Tensor(atom_features)
+        atom_fea_t = LongTensor(atom_features)
         nbr_dist_t = Tensor(nbr_dist)
         self_idx_t = LongTensor(self_idx)
         nbr_idx_t = LongTensor(nbr_idx)
@@ -278,7 +241,7 @@ def __init__(
                 "Max radii below minimum radii + step size - please increase dmax."
             )

-        self.filter = np.arange(dmin, dmax + step, step)
+        self.filter = torch.arange(dmin, dmax + step, step)
         self.embedding_size = len(self.filter)

         if var is None:
@@ -296,9 +259,9 @@ def expand(self, distances: np.ndarray) -> np.ndarray:
             np.ndarray: Expanded distance matrix with the last dimension of
                 length len(self.filter)
         """
-        distances = np.array(distances)
+        distances = torch.tensor(distances)

-        return np.exp(-((distances[..., None] - self.filter) ** 2) / self.var**2)
+        return torch.exp(-((distances[..., None] - self.filter) ** 2) / self.var**2)


 def get_structure_neighbor_info(
diff --git a/aviary/cgcnn/model.py b/aviary/cgcnn/model.py
index e96389cb..ff62ef99 100644
--- a/aviary/cgcnn/model.py
+++ b/aviary/cgcnn/model.py
@@ -5,9 +5,11 @@
 from pymatgen.util.due import Doi, due
 from torch import LongTensor, Tensor, nn

+from aviary.cgcnn.data import GaussianDistance
 from aviary.core import BaseModelClass
 from aviary.networks import SimpleNetwork
 from aviary.scatter import scatter_reduce
+from aviary.utils import get_element_embedding


 @due.dcite(Doi("10.1103/PhysRevLett.120.145301"), description="CGCNN model")
@@ -25,8 +27,10 @@ def __init__(
         self,
         robust: bool,
         n_targets: Sequence[int],
-        elem_emb_len: int,
-        nbr_fea_len: int,
+        elem_embedding: str = "cgcnn92",
+        radius_cutoff: float = 5.0,
+        radius_min: float = 0.0,
+        radius_step: float = 0.2,
         elem_fea_len: int = 64,
         n_graph: int = 4,
         h_fea_len: int = 128,
@@ -42,8 +46,15 @@ def __init__(
                 (uncertainty inherent to the sample) which can be used with a robust
                 loss function to attenuate the weighting of uncertain samples.
             n_targets (list[int]): Number of targets to train on
-            elem_emb_len (int): Number of atom features in the input.
-            nbr_fea_len (int): Number of bond features.
+            elem_embedding (str, optional): One of matscholar200, cgcnn92, megnet16,
+                onehot112 or path to a file with custom element embeddings.
+                Defaults to cgcnn92.
+            radius_cutoff (float, optional): Cut-off radius for neighborhood.
+                Defaults to 5.
+            radius_min (float, optional): Minimum distance in the Gaussian basis.
+                Defaults to 0.
+            radius_step (float, optional): Increment size of the Gaussian basis.
+                Defaults to 0.2.
             elem_fea_len (int, optional): Number of hidden atom features in the
                 convolutional layers. Defaults to 64.
             n_graph (int, optional): Number of convolutional layers. Defaults to 4.
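Since the Gaussian smearing now runs inside the model's forward pass, here is a standalone sketch of what GaussianDistance.expand computes with the defaults above (dmin=0, dmax=5, step=0.2), assuming the variance falls back to the step size as in the original CGCNN featurizer (illustrative only, outside the diff):

import torch

dmin, dmax, step = 0.0, 5.0, 0.2
filters = torch.arange(dmin, dmax + step, step)  # 26 evenly spaced Gaussian centers
var = step  # assumed default: variance falls back to the step size

# Each raw neighbor distance becomes a soft one-hot over the centers,
# which is the nbr_fea representation the graph convolutions consume.
distances = torch.tensor([1.23, 2.5])
expanded = torch.exp(-((distances[..., None] - filters) ** 2) / var**2)
print(expanded.shape)  # torch.Size([2, 26]) -> nbr_fea_len == len(filters)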
@@ -57,6 +68,14 @@ def __init__( """ super().__init__(robust=robust, **kwargs) + self.elem_embedding = get_element_embedding(elem_embedding) + elem_emb_len = self.elem_embedding.weight.shape[1] + + self.gaussian_dist_func = GaussianDistance( + dmin=radius_min, dmax=radius_cutoff, step=radius_step + ) + nbr_fea_len = self.gaussian_dist_func.embedding_size + desc_dict = { "elem_emb_len": elem_emb_len, "nbr_fea_len": nbr_fea_len, @@ -107,6 +126,9 @@ def forward( Returns: tuple[Tensor, ...]: tuple of predictions for all targets """ + nbr_fea = self.gaussian_dist_func.expand(nbr_fea) + atom_fea = self.elem_embedding(atom_fea) + atom_fea = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx) crys_fea = scatter_reduce(atom_fea, crystal_atom_idx, dim=0, reduce="mean") diff --git a/examples/cgcnn-example.py b/examples/cgcnn-example.py index aeb4cf57..7b97d288 100644 --- a/examples/cgcnn-example.py +++ b/examples/cgcnn-example.py @@ -89,24 +89,18 @@ def main( task_dict = dict(zip(targets, tasks, strict=False)) loss_dict = dict(zip(targets, losses, strict=False)) - dist_dict = { - "radius": radius, - "max_num_nbr": max_num_nbr, - "dmin": dmin, - "step": step, - } - # NOTE make sure to use dense datasets, here do not use the default na # as they can clash with "NaN" which is a valid material df = pd.read_json(data_path) df["structure"] = df.structure.map(Structure.from_dict) dataset = CrystalGraphData( - df=df, elem_embedding=elem_embedding, task_dict=task_dict, **dist_dict + df=df, + task_dict=task_dict, + max_num_nbr=max_num_nbr, + radius_cutoff=radius, ) n_targets = dataset.n_targets - elem_emb_len = dataset.elem_emb_len - nbr_fea_len = dataset.nbr_fea_dim train_idx = list(range(len(dataset))) @@ -119,7 +113,10 @@ def main( print(f"using independent test set: {test_path}") test_set = CrystalGraphData( - df=df, elem_embedding=elem_embedding, task_dict=task_dict, **dist_dict + df=df, + task_dict=task_dict, + max_num_nbr=max_num_nbr, + radius_cutoff=radius, ) test_set = torch.utils.data.Subset(test_set, range(len(test_set))) elif test_size == 0: @@ -140,7 +137,10 @@ def main( print(f"using independent validation set: {val_path}") val_set = CrystalGraphData( - df=df, elem_embedding=elem_embedding, task_dict=task_dict, **dist_dict + df=df, + task_dict=task_dict, + max_num_nbr=max_num_nbr, + radius_cutoff=radius, ) val_set = torch.utils.data.Subset(val_set, range(len(val_set))) elif val_size == 0 and evaluate: @@ -192,8 +192,10 @@ def main( "task_dict": task_dict, "robust": robust, "n_targets": n_targets, - "elem_emb_len": elem_emb_len, - "nbr_fea_len": nbr_fea_len, + "elem_embedding": elem_embedding, + "radius_cutoff": radius, + "radius_min": dmin, + "radius_step": step, "elem_fea_len": elem_fea_len, "n_graph": n_graph, "h_fea_len": h_fea_len, diff --git a/tests/test_cgcnn.py b/tests/test_cgcnn.py index e79d2fb6..4f99cca8 100644 --- a/tests/test_cgcnn.py +++ b/tests/test_cgcnn.py @@ -20,6 +20,11 @@ def base_config(): "log": False, "sample": 1, "test_size": 0.2, + "radius": 5, + "max_num_nbr": 12, + "dmin": 0, + "step": 0.2, + "patience": None, } @@ -63,12 +68,11 @@ def test_cgcnn_regression( dataset = CrystalGraphData( df=df_matbench_phonons, - elem_embedding=base_config["elem_embedding"], task_dict=task_dict, + max_num_nbr=base_config["max_num_nbr"], + radius_cutoff=base_config["radius"], ) n_targets = dataset.n_targets - elem_emb_len = dataset.elem_emb_len - nbr_fea_len = dataset.nbr_fea_dim train_idx = list(range(len(dataset))) train_idx, test_idx = split( @@ -112,8 +116,10 @@ def 
test_cgcnn_regression( "task_dict": task_dict, "robust": base_config["robust"], "n_targets": n_targets, - "elem_emb_len": elem_emb_len, - "nbr_fea_len": nbr_fea_len, + "elem_embedding": base_config["elem_embedding"], + "radius_cutoff": base_config["radius"], + "radius_min": base_config["dmin"], + "radius_step": base_config["step"], **model_architecture, } @@ -123,6 +129,7 @@ def test_cgcnn_regression( run_id=base_config["run_id"], ensemble_folds=base_config["ensemble"], epochs=epochs, + patience=base_config["patience"], train_loader=train_loader, val_loader=val_loader, log=base_config["log"], @@ -154,12 +161,12 @@ def test_cgcnn_regression( targets = results_dict[target_name]["targets"] y_ens = np.mean(preds, axis=0) - mae, rmse, r2 = get_metrics(targets, y_ens, task).values() + metrics = get_metrics(targets, y_ens, task) assert len(targets) == len(test_set) == len(test_idx) - assert r2 > 0.7 - assert mae < 150 - assert rmse < 300 + assert metrics["R2"] > 0.7 + assert metrics["MAE"] < 150 + assert metrics["RMSE"] < 300 def test_cgcnn_clf(df_matbench_phonons, base_config, model_architecture, training_config): @@ -174,30 +181,20 @@ def test_cgcnn_clf(df_matbench_phonons, base_config, model_architecture, trainin dataset = CrystalGraphData( df=df_matbench_phonons, - elem_embedding=base_config["elem_embedding"], task_dict=task_dict, + max_num_nbr=base_config["max_num_nbr"], + radius_cutoff=base_config["radius"], ) n_targets = dataset.n_targets - elem_emb_len = dataset.elem_emb_len - nbr_fea_len = dataset.nbr_fea_dim train_idx = list(range(len(dataset))) - - print(f"using {base_config['test_size']} of training set as test set") train_idx, test_idx = split( train_idx, random_state=base_config["data_seed"], test_size=base_config["test_size"], ) test_set = torch.utils.data.Subset(dataset, test_idx) - - print("No validation set used, using test set for evaluation purposes") - # NOTE that when using this option care must be taken not to - # peak at the test-set. The only valid model to use is the one - # obtained after the final epoch where the epoch count is - # decided in advance of the experiment. 
val_set = test_set - train_set = torch.utils.data.Subset(dataset, train_idx[0 :: base_config["sample"]]) data_params = { @@ -232,8 +229,10 @@ def test_cgcnn_clf(df_matbench_phonons, base_config, model_architecture, trainin "task_dict": task_dict, "robust": base_config["robust"], "n_targets": n_targets, - "elem_emb_len": elem_emb_len, - "nbr_fea_len": nbr_fea_len, + "elem_embedding": base_config["elem_embedding"], + "radius_cutoff": base_config["radius"], + "radius_min": base_config["dmin"], + "radius_step": base_config["step"], **model_architecture, } @@ -243,6 +242,7 @@ def test_cgcnn_clf(df_matbench_phonons, base_config, model_architecture, trainin run_id=base_config["run_id"], ensemble_folds=base_config["ensemble"], epochs=epochs, + patience=base_config["patience"], train_loader=train_loader, val_loader=val_loader, log=base_config["log"], @@ -273,14 +273,12 @@ def test_cgcnn_clf(df_matbench_phonons, base_config, model_architecture, trainin logits = results_dict["phdos_clf"]["logits"] targets = results_dict["phdos_clf"]["targets"] - # calculate metrics and errors with associated errors for ensembles ens_logits = np.mean(logits, axis=0) - - ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values() + metrics = get_metrics(targets, ens_logits, task) assert len(targets) == len(test_set) == len(test_idx) - assert ens_acc > 0.85 - assert ens_roc_auc > 0.9 + assert metrics["accuracy"] > 0.85 + assert metrics["ROCAUC"] > 0.9 if __name__ == "__main__": From 17f279b288ff26d35229ac85966828195ec32faa Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 13 Apr 2025 12:01:19 -0400 Subject: [PATCH 06/12] clean: clear notebook outputs --- .pre-commit-config.yaml | 2 +- examples/notebooks/Roost.ipynb | 404 +--------------------------- examples/notebooks/Wren.ipynb | 379 +------------------------- examples/notebooks/Wrenformer.ipynb | 110 ++------ 4 files changed, 27 insertions(+), 868 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dc43aece..5f627e19 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,4 +45,4 @@ repos: rev: 0.8.1 hooks: - id: nbstripout - args: [--drop-empty-cells, --keep-output] + args: [--drop-empty-cells] diff --git a/examples/notebooks/Roost.ipynb b/examples/notebooks/Roost.ipynb index 6a9f7989..b95a83f2 100644 --- a/examples/notebooks/Roost.ipynb +++ b/examples/notebooks/Roost.ipynb @@ -56,62 +56,7 @@ "execution_count": null, "id": "2", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - 
"spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n", - "spglib: ssm_get_exact_positions failed.\n", - "spglib: get_bravais_exact_positions_and_lattice failed.\n" - ] - } - ], + "outputs": [], "source": [ "with gzip.open(\"taata.json.gz\", \"r\") as fin:\n", " json_bytes = fin.read()\n", @@ -168,7 +113,7 @@ "\n", "ensemble = 1\n", "run_id = 1\n", - "epochs = 100\n", + "epochs = 1\n", "log = False\n", "\n", "# NOTE setting workers to zero means that the data is loaded in the main\n", @@ -204,330 +149,7 @@ "execution_count": null, "id": "4", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "using 0.2 of training set as test set\n", - "No validation set used, using test set for evaluation purposes\n", - "Total Number of Trainable Parameters: 973,777\n", - "Dummy MAE: 1.2757\n", - "Epoch: [0/99]\n", - " train: E_vasp_per_atom N 10 MAE 1.27 Loss 1.12 RMSE 1.59 \n", - " evaluate: E_vasp_per_atom N 1 MAE 1.29 Loss 1.13 RMSE 1.59 \n", - "Epoch: [1/99]\n", - " train: E_vasp_per_atom N 10 MAE 1.25 Loss 1.11 RMSE 1.59 \n", - " evaluate: E_vasp_per_atom N 1 MAE 1.25 Loss 1.10 RMSE 1.55 \n", - "Epoch: [2/99]\n", - " train: E_vasp_per_atom N 10 MAE 1.17 Loss 1.03 RMSE 1.50 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.98 Loss 0.87 RMSE 1.30 \n", - "Epoch: [3/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.85 Loss 0.74 RMSE 1.23 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.62 Loss 0.49 RMSE 0.87 \n", - "Epoch: [4/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.59 Loss 0.40 RMSE 0.89 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.44 Loss 0.13 RMSE 0.59 \n", - "Epoch: [5/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.45 Loss 0.10 RMSE 0.69 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.39 Loss -0.06 RMSE 0.54 \n", - "Epoch: [6/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.44 Loss 0.03 RMSE 0.68 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.37 Loss -0.10 RMSE 0.50 \n", - "Epoch: [7/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.40 Loss -0.03 RMSE 0.63 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.36 Loss -0.14 RMSE 
0.50 \n", - "Epoch: [8/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.38 Loss -0.10 RMSE 0.61 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.38 Loss -0.07 RMSE 0.49 \n", - "Epoch: [9/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.38 Loss -0.09 RMSE 0.61 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.36 Loss -0.13 RMSE 0.46 \n", - "Epoch: [10/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.36 Loss -0.15 RMSE 0.57 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.30 Loss -0.31 RMSE 0.42 \n", - "Epoch: [11/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.35 Loss -0.20 RMSE 0.56 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.32 Loss -0.26 RMSE 0.42 \n", - "Epoch: [12/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.36 Loss -0.17 RMSE 0.57 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.33 RMSE 0.43 \n", - "Epoch: [13/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.35 Loss -0.18 RMSE 0.56 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.30 Loss -0.32 RMSE 0.42 \n", - "Epoch: [14/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.27 RMSE 0.53 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.38 RMSE 0.41 \n", - "Epoch: [15/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.27 RMSE 0.53 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.31 Loss -0.32 RMSE 0.41 \n", - "Epoch: [16/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.27 RMSE 0.53 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.30 Loss -0.33 RMSE 0.41 \n", - "Epoch: [17/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.27 RMSE 0.53 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.35 RMSE 0.42 \n", - "Epoch: [18/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.26 RMSE 0.53 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.35 RMSE 0.40 \n", - "Epoch: [19/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.32 Loss -0.31 RMSE 0.51 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.39 RMSE 0.39 \n", - "Epoch: [20/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.32 Loss -0.29 RMSE 0.52 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.35 RMSE 0.39 \n", - "Epoch: [21/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.27 RMSE 0.50 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.35 RMSE 0.40 \n", - "Epoch: [22/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.32 Loss -0.30 RMSE 0.50 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.34 Loss -0.20 RMSE 0.42 \n", - "Epoch: [23/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.35 Loss -0.20 RMSE 0.51 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.38 RMSE 0.41 \n", - "Epoch: [24/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.28 RMSE 0.50 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.38 RMSE 0.41 \n", - "Epoch: [25/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.34 RMSE 0.48 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.34 RMSE 0.42 \n", - "Epoch: [26/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.31 RMSE 0.48 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.37 RMSE 0.39 \n", - "Epoch: [27/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.34 RMSE 0.48 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.37 RMSE 0.39 \n", - "Epoch: [28/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.34 RMSE 0.47 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.37 RMSE 0.39 \n", - "Epoch: [29/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.34 RMSE 0.47 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.41 RMSE 0.38 \n", - "Epoch: [30/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.31 Loss 
-0.34 RMSE 0.46 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.30 Loss -0.33 RMSE 0.39 \n", - "Epoch: [31/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.33 RMSE 0.46 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.44 RMSE 0.39 \n", - "Epoch: [32/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.31 RMSE 0.46 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.41 RMSE 0.38 \n", - "Epoch: [33/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.31 Loss -0.32 RMSE 0.45 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.41 RMSE 0.41 \n", - "Epoch: [34/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.32 Loss -0.30 RMSE 0.45 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.35 Loss -0.20 RMSE 0.42 \n", - "Epoch: [35/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.32 Loss -0.30 RMSE 0.45 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.37 RMSE 0.41 \n", - "Epoch: [36/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.33 Loss -0.27 RMSE 0.47 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.42 RMSE 0.40 \n", - "Epoch: [37/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.30 Loss -0.35 RMSE 0.44 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.46 RMSE 0.37 \n", - "Epoch: [38/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.29 Loss -0.39 RMSE 0.42 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.41 RMSE 0.36 \n", - "Epoch: [39/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.30 Loss -0.36 RMSE 0.43 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.42 RMSE 0.39 \n", - "Epoch: [40/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.29 Loss -0.39 RMSE 0.42 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.42 RMSE 0.36 \n", - "Epoch: [41/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.42 RMSE 0.40 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.48 RMSE 0.36 \n", - "Epoch: [42/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.44 RMSE 0.40 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.28 Loss -0.40 RMSE 0.36 \n", - "Epoch: [43/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.30 Loss -0.34 RMSE 0.42 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.31 Loss -0.33 RMSE 0.43 \n", - "Epoch: [44/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.29 Loss -0.37 RMSE 0.41 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.47 RMSE 0.36 \n", - "Epoch: [45/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.29 Loss -0.39 RMSE 0.41 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.29 Loss -0.33 RMSE 0.37 \n", - "Epoch: [46/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.29 Loss -0.37 RMSE 0.41 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.44 RMSE 0.40 \n", - "Epoch: [47/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.28 Loss -0.42 RMSE 0.41 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.27 Loss -0.43 RMSE 0.36 \n", - "Epoch: [48/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.44 RMSE 0.40 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.49 RMSE 0.36 \n", - "Epoch: [49/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.45 RMSE 0.39 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.50 RMSE 0.36 \n", - "Epoch: [50/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.46 RMSE 0.39 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.47 RMSE 0.35 \n", - "Epoch: [51/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.45 RMSE 0.39 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.25 Loss -0.51 RMSE 0.35 \n", - "Epoch: [52/99]\n", - " train: E_vasp_per_atom N 10 MAE 0.27 Loss -0.47 RMSE 0.39 \n", - " evaluate: E_vasp_per_atom N 1 MAE 0.26 Loss -0.48 
RMSE 0.37 \n",
-     "... [stdout truncated: per-epoch Roost training log for epochs 53-99; train and evaluate MAE/Loss/RMSE plateau around train MAE 0.26-0.28, evaluate MAE 0.24-0.27] ...\n"
-    ]
-   },
-   {
-    "ename": "TypeError",
-    "evalue": "results_multitask() got an unexpected keyword argument 'test_set'",
-    "output_type": "error",
-    "traceback": [
-     "... [ANSI traceback truncated: the call to results_multitask(..., test_set=test_set, ...) in cell 7 raised the TypeError above] ..."
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "torch.manual_seed(0) # ensure reproducible results\n",
@@ -601,25 +223,7 @@
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "... [stdout truncated: 'Evaluate model on Test Set' banner; E_vasp_per_atom test metrics R2 0.9494, MAE 0.2701, RMSE 0.3576] ...\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "test_loader = DataLoader(\n",
    "    test_set,\n",
diff --git a/examples/notebooks/Wren.ipynb b/examples/notebooks/Wren.ipynb
index 72364bb3..01b32c32 100644
--- a/examples/notebooks/Wren.ipynb
+++ b/examples/notebooks/Wren.ipynb
@@ -56,62 +56,7 @@
   "execution_count": null,
   "id": "2",
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "... [stderr truncated: repeated pairs of 'spglib: ssm_get_exact_positions failed.' and 'spglib: get_bravais_exact_positions_and_lattice failed.' warnings] ...\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "with gzip.open(\"taata.json.gz\", \"r\") as fin:\n",
    "    json_bytes = fin.read()\n",
@@ -168,7 +113,7 @@
    "\n",
    "ensemble = 1\n",
    "run_id = 1\n",
-    "epochs = 100\n",
+    "epochs = 1\n",
    "log = False\n",
    "\n",
    "# NOTE setting workers to zero means that the data is loaded in the main\n",
@@ -204,298 +149,7 @@
   "execution_count": null,
   "id": "4",
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "... [stdout truncated: Wren training log -- 'using 0.2 of training set as test set', 'No validation set used, using test set for evaluation purposes', 1,154,330 trainable parameters, Dummy MAE 0.9670; epochs 0-93 with train MAE falling 0.95 -> 0.14 and evaluate MAE 0.84 -> ~0.19] ...\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "torch.manual_seed(0) # ensure reproducible results\n",
@@ -569,32 +223,7 @@
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "... [stdout truncated: 'Evaluate model on Test Set' banner] ...\n"
-    ]
-   },
-   {
-    "ename": "ValueError",
-    "evalue": "task_dict {'last phdos peak': 'regression'} of checkpoint resume='/Users/radical-rhys/Radical/aviary/models/wren-reg-test/checkpoint-r1.pth.tar' does not match provided task_dict={'E_vasp_per_atom': 'regression'}",
-    "output_type": "error",
-    "traceback": [
-     "... [ANSI traceback truncated: raised from results_multitask in aviary/utils.py while resuming the checkpoint] ..."
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "test_loader = DataLoader(\n",
    "    test_set,\n",
diff --git a/examples/notebooks/Wrenformer.ipynb b/examples/notebooks/Wrenformer.ipynb
index 86780193..11703c61 100644
--- a/examples/notebooks/Wrenformer.ipynb
+++ b/examples/notebooks/Wrenformer.ipynb
@@ -55,62 +55,7 @@
   "execution_count": null,
   "id": "2",
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stderr",
-    "output_type": "stream",
-    "text": [
-     "... [stderr truncated: repeated spglib symmetry-detection warnings, as in the notebooks above] ...\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "with gzip.open(\"taata.json.gz\", \"r\") as fin:\n",
    "    json_bytes = fin.read()\n",
@@ -129,7 +74,12 @@
    "df = df[df.protostructure.map(count_wyckoff_positions) < 16]\n",
    "df[\"n_sites\"] = df.final_structure.map(len)\n",
    "df = df[df.n_sites < 64]\n",
-    "df = df[df.volume_per_atom < 500]"
+    "df = df[df.volume_per_atom < 500]\n",
+    "\n",
+    "# NOTE for roost we keep only the lowest lying structures for each composition\n",
+    "df = df.sort_values([\"protostructure\", \"E_vasp_per_atom\"]).drop_duplicates(\n",
+    "    \"protostructure\", keep=\"first\"\n",
+    ")"
   ]
  },
@@ -162,7 +112,7 @@
    "\n",
    "ensemble = 1\n",
    "run_id = 1\n",
-    "epochs = 3\n",
+    "epochs = 1\n",
    "log = False\n",
    "\n",
    "# NOTE setting workers to zero means that the data is loaded in the main\n",
@@ -198,38 +148,7 @@
   "execution_count": null,
   "id": "4",
   "metadata": {},
-  "outputs": [
-   {
-    "name": "stdout",
-    "output_type": "stream",
-    "text": [
-     "... [stdout truncated: Wrenformer training log -- 5,166,658 trainable parameters, Dummy MAE 0.9223, epochs 0-2; then 'Evaluate model on Test Set' with E_vasp_per_atom metrics R2 0.7842, MAE 0.3913, RMSE 0.5500] ...\n"
-    ]
-   }
-  ],
+  "outputs": [],
   "source": [
    "torch.manual_seed(0) # ensure reproducible results\n",
@@ -303,8 +222,15 @@
    "
restart_params=restart_params,\n", " model_params=model_params,\n", " loss_dict=loss_dict,\n", - ")\n", - "\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "test_loader = df_to_in_mem_dataloader(\n", " test_df,\n", " batch_size=batch_size * 64,\n", From 1dd094bc113a898be2d9beb5b15a40262b8a98a5 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 19 Apr 2025 11:01:14 -0400 Subject: [PATCH 07/12] fea: swap to use hatch --- pyproject.toml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 96a75b8c..f5b86eec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["setuptools>=61.2"] -build-backend = "setuptools.build_meta" +requires = ["hatchling"] +build-backend = "hatchling.build" [project] name = "aviary" @@ -52,14 +52,14 @@ Repo = "https://github.com/CompRhys/aviary" test = ["matminer", "moyopy>=0.3.3", "pytest", "pytest-cov"] moyopy = ["moyopy>=0.3.3"] -[tool.setuptools.packages] -find = { include = ["aviary*"], exclude = ["tests*"] } +[tool.hatch.build.targets.wheel] +packages = ["aviary"] -[tool.setuptools.package-data] -aviary = ["**/**/*.json", "**/*.json"] - -[tool.distutils.bdist_wheel] -universal = true +[tool.hatch.build] +include = [ + "aviary/**/*.py", + "aviary/**/*.json", +] [tool.pytest.ini_options] testpaths = ["tests"] From 606667ea9f8dd0bd44543e7a82acc7bacbec287a Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 19 Apr 2025 11:05:00 -0400 Subject: [PATCH 08/12] fix: uppercase LICENSE --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f5b86eec..f2b8bd45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ version = "1.2.0" description = "A collection of machine learning models for materials discovery" authors = [{ name = "Rhys Goodall", email = "rhys.goodall@outlook.com" }] readme = "README.md" -license = { file = "license" } +license = { file = "LICENSE" } keywords = [ "Graph Neural Network", "Machine Learning", From da125aaa31c6b50628f472755f62ec2819d95800 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 19 Apr 2025 11:10:11 -0400 Subject: [PATCH 09/12] maint: uv install cpu torch --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b0cfc0ec..b9361334 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,7 +33,7 @@ jobs: - name: Install dependencies run: | - pip install torch --index-url https://download.pytorch.org/whl/cpu + uv pip install torch --index-url https://download.pytorch.org/whl/cpu --system uv pip install .[test] --system - name: Run Tests From 09820781c9fcde76e17e110c48b8ef8e11a3b8fe Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 19 Apr 2025 11:11:50 -0400 Subject: [PATCH 10/12] maint: also run tests on windows latest --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b9361334..0aca90cf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,7 +13,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-14] + os: [ubuntu-latest, macos-latest, windows-latest] version: - { python: "3.10", resolution: highest } - { python: "3.12", resolution: lowest-direct } From d6e423c7b5d9bdca170d34f62eddb12d60df0903 
Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Sat, 19 Apr 2025 11:12:59 -0400
Subject: [PATCH 11/12] maint: cancel concurrent tests

---
 .github/workflows/test.yml | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 0aca90cf..6400e11c 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -8,6 +8,11 @@ on:
     paths: ["**/*.py", .github/workflows/test.yml]
     branches: [main]

+concurrency:
+  # Cancel only on same PR number
+  group: ${{ github.workflow }}-pr-${{ github.event.pull_request.number }}
+  cancel-in-progress: true
+
 jobs:
   tests:
     strategy:

From 55543a3b998622d123d71766af3e04c8a2489479 Mon Sep 17 00:00:00 2001
From: Rhys Goodall
Date: Sat, 19 Apr 2025 11:23:57 -0400
Subject: [PATCH 12/12] fix: avoid tensor copy warning

---
 aviary/cgcnn/data.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/aviary/cgcnn/data.py b/aviary/cgcnn/data.py
index 7d731e2e..1cd7958e 100644
--- a/aviary/cgcnn/data.py
+++ b/aviary/cgcnn/data.py
@@ -249,18 +249,16 @@ def __init__(
         self.var = var

-    def expand(self, distances: np.ndarray) -> np.ndarray:
-        """Apply Gaussian distance filter to a numpy distance array.
+    def expand(self, distances: Tensor) -> Tensor:
+        """Apply Gaussian distance filter to a torch distance tensor.

         Args:
-            distances (ArrayLike): A distance matrix of any shape.
+            distances (Tensor): A distance matrix of any shape.

         Returns:
-            np.ndarray: Expanded distance matrix with the last dimension of length
+            Tensor: Expanded distance matrix with the last dimension of length
             len(self.filter)
         """
-        distances = torch.tensor(distances)
-
         return torch.exp(-((distances[..., None] - self.filter) ** 2) / self.var**2)
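
For context on the final patch: expand() now assumes it is handed a torch.Tensor and no longer round-trips the input through torch.tensor(), which previously emitted a copy warning when the argument was already a tensor. Below is a minimal sketch of the expansion it computes; the filter construction and variance value are assumptions for illustration only, since the hunk above shows expand() but not how self.filter and self.var are built in __init__.

    import torch

    # Hypothetical stand-ins for self.filter and self.var -- the real values
    # are constructed in GaussianDistance.__init__, which is not shown above.
    gaussian_filter = torch.arange(0.0, 6.0 + 0.2, 0.2)  # assumed dmin=0, dmax=6, step=0.2
    var = 0.2  # assumed variance, often defaulted to the step size

    # A distance tensor of any shape, e.g. (batch, n_atoms, n_neighbors).
    distances = torch.rand(4, 12, 8)

    # The expansion from the patched expand(): broadcast each distance against
    # the Gaussian centers. A Tensor input passes straight through -- no copy.
    expanded = torch.exp(-((distances[..., None] - gaussian_filter) ** 2) / var**2)

    # The last dimension equals len(self.filter), as the docstring states.
    assert expanded.shape == (*distances.shape, len(gaussian_filter))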