Skip to content

Swap to using nn.Embedding #103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@ on:
paths: ["**/*.py", .github/workflows/test.yml]
branches: [main]

concurrency:
# Cancel only on same PR number
group: ${{ github.workflow }}-pr-${{ github.event.pull_request.number }}
cancel-in-progress: true

jobs:
tests:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-14]
os: [ubuntu-latest, macos-latest, windows-latest]
version:
- { python: "3.10", resolution: highest }
- { python: "3.12", resolution: lowest-direct }
Expand All @@ -33,7 +38,7 @@ jobs:

- name: Install dependencies
run: |
pip install torch --index-url https://download.pytorch.org/whl/cpu
uv pip install torch --index-url https://download.pytorch.org/whl/cpu --system
uv pip install .[test] --system

- name: Run Tests
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ repos:
rev: 0.8.1
hooks:
- id: nbstripout
args: [--drop-empty-cells, --keep-output]
args: [--drop-empty-cells]
63 changes: 12 additions & 51 deletions aviary/cgcnn/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import itertools
import json
from collections.abc import Sequence
from functools import cache
from typing import Any
Expand All @@ -12,8 +11,6 @@
from torch.utils.data import Dataset
from tqdm import tqdm

from aviary import PKG_DIR


class CrystalGraphData(Dataset):
"""Dataset class for the CGCNN structure model."""
Expand All @@ -22,58 +19,32 @@ def __init__(
self,
df: pd.DataFrame,
task_dict: dict[str, str],
elem_embedding: str = "cgcnn92",
structure_col: str = "structure",
identifiers: Sequence[str] = (),
radius: float = 5,
radius_cutoff: float = 5,
max_num_nbr: int = 12,
dmin: float = 0,
step: float = 0.2,
):
"""Featurize crystal structures into neighborhood graphs with this data class
for CGCNN.

Args:
df (pd.DataFrame): Pandas dataframe holding input and target values.
task_dict ({target: task}): task dict for multi-task learning
elem_embedding (str, optional): One of matscholar200, cgcnn92, megnet16,
onehot112 or path to a file with custom element embeddings.
Defaults to matscholar200.
structure_col (str, optional): df column holding pymatgen Structure objects
as input.
identifiers (list[str], optional): df columns for distinguishing data
points. Will be copied over into the model's output CSV. Defaults to ().
radius (float, optional): Cut-off radius for neighborhood. Defaults to 5.
radius_cutoff (float, optional): Cut-off radius for neighborhood.
Defaults to 5.
max_num_nbr (int, optional): maximum number of neighbors to consider.
Defaults to 12.
dmin (float, optional): minimum distance in Gaussian basis. Defaults to 0.
step (float, optional): increment size of Gaussian basis. Defaults to 0.2.
"""
self.task_dict = task_dict
self.identifiers = list(identifiers)

self.radius = radius
self.radius_cutoff = radius_cutoff
self.max_num_nbr = max_num_nbr

if elem_embedding in ("matscholar200", "cgcnn92", "megnet16", "onehot112"):
elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json"

with open(elem_embedding) as file:
self.elem_features = json.load(file)

for key, value in self.elem_features.items():
self.elem_features[key] = np.array(value, dtype=float)
if not hasattr(self, "elem_emb_len"):
self.elem_emb_len = len(value)
elif self.elem_emb_len != len(value):
raise ValueError(
f"Element embedding length mismatch: len({key})="
f"{len(value)}, expected {self.elem_emb_len}"
)

self.gaussian_dist_func = GaussianDistance(dmin=dmin, dmax=radius, step=step)
self.nbr_fea_dim = self.gaussian_dist_func.embedding_size

self.df = df
self.structure_col = structure_col

Expand All @@ -84,7 +55,7 @@ def __init__(
self.df[structure_col].items(), total=len(df), desc=desc, disable=None
):
self_idx, nbr_idx, _ = get_structure_neighbor_info(
struct, radius, max_num_nbr
struct, self.radius_cutoff, self.max_num_nbr
)
material_ids = [idx, *self.df.loc[idx][self.identifiers]]
if 0 in (len(self_idx), len(nbr_idx)):
Expand Down Expand Up @@ -140,16 +111,10 @@ def __getitem__(self, idx: int):
material_ids = [self.df.index[idx], *row[self.identifiers]]

# atom features for disordered sites
site_atoms = [atom.species.as_dict() for atom in struct]
atom_features = np.vstack(
[
np.sum([self.elem_features[el] * amt for el, amt in site.items()], axis=0)
for site in site_atoms
]
)
atom_features = [atom.specie.Z for atom in struct]

self_idx, nbr_idx, nbr_dist = get_structure_neighbor_info(
struct, self.radius, self.max_num_nbr
struct, self.radius_cutoff, self.max_num_nbr
)

if len(self_idx) == 0:
Expand All @@ -161,9 +126,7 @@ def __getitem__(self, idx: int):
if set(self_idx) != set(range(len(struct))):
raise ValueError(f"At least one atom in {material_ids} is isolated")

nbr_dist = self.gaussian_dist_func.expand(nbr_dist)

atom_fea_t = Tensor(atom_features)
atom_fea_t = LongTensor(atom_features)
nbr_dist_t = Tensor(nbr_dist)
self_idx_t = LongTensor(self_idx)
nbr_idx_t = LongTensor(nbr_idx)
Expand Down Expand Up @@ -278,27 +241,25 @@ def __init__(
"Max radii below minimum radii + step size - please increase dmax."
)

self.filter = np.arange(dmin, dmax + step, step)
self.filter = torch.arange(dmin, dmax + step, step)
self.embedding_size = len(self.filter)

if var is None:
var = step

self.var = var

def expand(self, distances: np.ndarray) -> np.ndarray:
def expand(self, distances: Tensor) -> Tensor:
"""Apply Gaussian distance filter to a numpy distance array.

Args:
distances (Tensor): A distance matrix of any shape.

Returns:
np.ndarray: Expanded distance matrix with the last dimension of length
Tensor: Expanded distance matrix with the last dimension of length
len(self.filter)
"""
distances = np.array(distances)

return np.exp(-((distances[..., None] - self.filter) ** 2) / self.var**2)
return torch.exp(-((distances[..., None] - self.filter) ** 2) / self.var**2)


def get_structure_neighbor_info(
Expand Down
30 changes: 26 additions & 4 deletions aviary/cgcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from pymatgen.util.due import Doi, due
from torch import LongTensor, Tensor, nn

from aviary.cgcnn.data import GaussianDistance
from aviary.core import BaseModelClass
from aviary.networks import SimpleNetwork
from aviary.scatter import scatter_reduce
from aviary.utils import get_element_embedding


@due.dcite(Doi("10.1103/PhysRevLett.120.145301"), description="CGCNN model")
Expand All @@ -25,8 +27,10 @@ def __init__(
self,
robust: bool,
n_targets: Sequence[int],
elem_emb_len: int,
nbr_fea_len: int,
elem_embedding: str = "cgcnn92",
radius_cutoff: float = 5.0,
radius_min: float = 0.0,
radius_step: float = 0.2,
elem_fea_len: int = 64,
n_graph: int = 4,
h_fea_len: int = 128,
Expand All @@ -42,8 +46,15 @@ def __init__(
(uncertainty inherent to the sample) which can be used with a robust
loss function to attenuate the weighting of uncertain samples.
n_targets (list[int]): Number of targets to train on
elem_emb_len (int): Number of atom features in the input.
nbr_fea_len (int): Number of bond features.
elem_embedding (str, optional): One of matscholar200, cgcnn92, megnet16,
onehot112 or path to a file with custom element embeddings.
Defaults to cgcnn92.
radius_cutoff (float, optional): Cut-off radius for neighborhood.
Defaults to 5.
radius_min (float, optional): minimum distance in Gaussian basis.
Defaults to 0.
radius_step (float, optional): increment size of Gaussian basis.
Defaults to 0.2.
elem_fea_len (int, optional): Number of hidden atom features in the
convolutional layers. Defaults to 64.
n_graph (int, optional): Number of convolutional layers. Defaults to 4.
Expand All @@ -57,6 +68,14 @@ def __init__(
"""
super().__init__(robust=robust, **kwargs)

self.elem_embedding = get_element_embedding(elem_embedding)
elem_emb_len = self.elem_embedding.weight.shape[1]

self.gaussian_dist_func = GaussianDistance(
dmin=radius_min, dmax=radius_cutoff, step=radius_step
)
nbr_fea_len = self.gaussian_dist_func.embedding_size

desc_dict = {
"elem_emb_len": elem_emb_len,
"nbr_fea_len": nbr_fea_len,
Expand Down Expand Up @@ -107,6 +126,9 @@ def forward(
Returns:
tuple[Tensor, ...]: tuple of predictions for all targets
"""
nbr_fea = self.gaussian_dist_func.expand(nbr_fea)
atom_fea = self.elem_embedding(atom_fea)

atom_fea = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx)

crys_fea = scatter_reduce(atom_fea, crystal_atom_idx, dim=0, reduce="mean")
Expand Down
7 changes: 5 additions & 2 deletions aviary/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
epoch: int = 0,
device: str | None = None,
best_val_scores: dict[str, float] | None = None,
**kwargs,
) -> None:
"""Store core model parameters.

Expand All @@ -47,6 +48,7 @@ def __init__(
device (str, optional): Device to store the model parameters on.
best_val_scores (dict[str, float], optional): Validation score to use for
early stopping. Defaults to None.
**kwargs: Additional keyword arguments.
"""
super().__init__()
self.task_dict = task_dict
Expand Down Expand Up @@ -299,8 +301,9 @@ def evaluate(
preds = output.squeeze(1)
loss = loss_func(preds, targets)

z_scored_error = preds - targets
error = normalizer.std * z_scored_error.data.cpu()
denormed_preds = normalizer.denorm(preds)
denormed_targets = normalizer.denorm(targets)
error = denormed_preds - denormed_targets
target_metrics["MAE"].append(float(error.abs().mean()))
target_metrics["MSE"].append(float(error.pow(2).mean()))

Expand Down
32 changes: 4 additions & 28 deletions aviary/roost/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from collections.abc import Sequence
from functools import cache
from typing import Any
Expand All @@ -10,8 +9,6 @@
from torch import LongTensor, Tensor
from torch.utils.data import Dataset

from aviary import PKG_DIR


class CompositionData(Dataset):
"""Dataset class for the Roost composition model."""
Expand All @@ -20,7 +17,6 @@ def __init__(
self,
df: pd.DataFrame,
task_dict: dict[str, str],
elem_embedding: str = "matscholar200",
inputs: str = "composition",
identifiers: Sequence[str] = ("material_id", "composition"),
):
Expand All @@ -47,14 +43,6 @@ def __init__(
self.identifiers = list(identifiers)
self.df = df

if elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json"

with open(elem_embedding) as file:
self.elem_features = json.load(file)

self.elem_emb_len = len(next(iter(self.elem_features.values())))

self.n_targets = []
for target, task in self.task_dict.items():
if task == "regression":
Expand Down Expand Up @@ -88,24 +76,12 @@ def __getitem__(self, idx: int):
composition = row[self.inputs]
material_ids = row[self.identifiers].to_list()

comp_dict = Composition(composition).get_el_amt_dict()
elements = list(comp_dict)

comp_dict = Composition(composition).fractional_composition
weights = list(comp_dict.values())
weights = np.atleast_2d(weights).T / np.sum(weights)
elem_fea = [elem.Z for elem in comp_dict]

try:
elem_fea = np.vstack([self.elem_features[element] for element in elements])
except AssertionError as exc:
raise AssertionError(
f"{material_ids} contains element types not in embedding"
) from exc
except ValueError as exc:
raise ValueError(
f"{material_ids} composition cannot be parsed into elements"
) from exc

n_elems = len(elements)
n_elems = len(comp_dict)
self_idx = []
nbr_idx = []
for elem_idx in range(n_elems):
Expand All @@ -114,7 +90,7 @@ def __getitem__(self, idx: int):

# convert all data to tensors
elem_weights = Tensor(weights)
elem_fea = Tensor(elem_fea)
elem_fea = LongTensor(elem_fea)
self_idx = LongTensor(self_idx)
nbr_idx = LongTensor(nbr_idx)

Expand Down
8 changes: 7 additions & 1 deletion aviary/roost/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aviary.core import BaseModelClass
from aviary.networks import ResidualNetwork, SimpleNetwork
from aviary.segments import MessageLayer, WeightedAttentionPooling
from aviary.utils import get_element_embedding


@due.dcite(Doi("10.1038/s41467-020-19964-7"), description="Roost model")
Expand All @@ -25,7 +26,7 @@ def __init__(
self,
robust: bool,
n_targets: Sequence[int],
elem_emb_len: int,
elem_embedding: str = "matscholar200",
elem_fea_len: int = 64,
n_graph: int = 3,
elem_heads: int = 3,
Expand All @@ -41,6 +42,8 @@ def __init__(
"""Composition-only model."""
super().__init__(robust=robust, **kwargs)

self.elem_embedding = get_element_embedding(elem_embedding)
elem_emb_len = self.elem_embedding.weight.shape[1]
desc_dict = {
"elem_emb_len": elem_emb_len,
"elem_fea_len": elem_fea_len,
Expand All @@ -60,6 +63,7 @@ def __init__(
"n_targets": n_targets,
"out_hidden": out_hidden,
"trunk_hidden": trunk_hidden,
"elem_embedding": elem_embedding,
**desc_dict,
}
self.model_params.update(model_params)
Expand All @@ -83,6 +87,8 @@ def forward(
cry_elem_idx: LongTensor,
) -> tuple[Tensor, ...]:
"""Forward pass through the material_nn and output_nn."""
elem_fea = self.elem_embedding(elem_fea)

crys_fea = self.material_nn(
elem_weights, elem_fea, self_idx, nbr_idx, cry_elem_idx
)
Expand Down
Loading