Skip to content

Commit 6c72299

Browse files
authored
Swap to using nn.Embedding (#103)
* fea: swap to use nn.embedding for roost.
* fea: swap wren to use nn.embedding
* clean: refactor into utils functions
* test: add basic tests to the embedding functions
* fea: swap cgcnn to do the embedding inside the forward pass
* clean: clear notebook outputs
* fea: swap to use hatch
* fix: uppercase LICENSE
* maint: uv install cpu torch
* maint: also run tests on windows latest
* maint: cancel concurrent tests
* fix: avoid tensor copy warning
1 parent 451f573 commit 6c72299

21 files changed

+429
-346
lines changed

.github/workflows/test.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,17 @@ on:
88
paths: ["**/*.py", .github/workflows/test.yml]
99
branches: [main]
1010

11+
concurrency:
12+
# Cancel only on same PR number
13+
group: ${{ github.workflow }}-pr-${{ github.event.pull_request.number }}
14+
cancel-in-progress: true
15+
1116
jobs:
1217
tests:
1318
strategy:
1419
fail-fast: false
1520
matrix:
16-
os: [ubuntu-latest, macos-14]
21+
os: [ubuntu-latest, macos-latest, windows-latest]
1722
version:
1823
- { python: "3.10", resolution: highest }
1924
- { python: "3.12", resolution: lowest-direct }
@@ -33,7 +38,7 @@ jobs:
3338

3439
- name: Install dependencies
3540
run: |
36-
pip install torch --index-url https://download.pytorch.org/whl/cpu
41+
uv pip install torch --index-url https://download.pytorch.org/whl/cpu --system
3742
uv pip install .[test] --system
3843
3944
- name: Run Tests

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,4 @@ repos:
4545
rev: 0.8.1
4646
hooks:
4747
- id: nbstripout
48-
args: [--drop-empty-cells, --keep-output]
48+
args: [--drop-empty-cells]

aviary/cgcnn/data.py

Lines changed: 12 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import itertools
2-
import json
32
from collections.abc import Sequence
43
from functools import cache
54
from typing import Any
@@ -12,8 +11,6 @@
1211
from torch.utils.data import Dataset
1312
from tqdm import tqdm
1413

15-
from aviary import PKG_DIR
16-
1714

1815
class CrystalGraphData(Dataset):
1916
"""Dataset class for the CGCNN structure model."""
@@ -22,58 +19,32 @@ def __init__(
2219
self,
2320
df: pd.DataFrame,
2421
task_dict: dict[str, str],
25-
elem_embedding: str = "cgcnn92",
2622
structure_col: str = "structure",
2723
identifiers: Sequence[str] = (),
28-
radius: float = 5,
24+
radius_cutoff: float = 5,
2925
max_num_nbr: int = 12,
30-
dmin: float = 0,
31-
step: float = 0.2,
3226
):
3327
"""Featurize crystal structures into neighborhood graphs with this data class
3428
for CGCNN.
3529
3630
Args:
3731
df (pd.Dataframe): Pandas dataframe holding input and target values.
3832
task_dict ({target: task}): task dict for multi-task learning
39-
elem_embedding (str, optional): One of matscholar200, cgcnn92, megnet16,
40-
onehot112 or path to a file with custom element embeddings.
41-
Defaults to matscholar200.
4233
structure_col (str, optional): df column holding pymatgen Structure objects
4334
as input.
4435
identifiers (list[str], optional): df columns for distinguishing data
4536
points. Will be copied over into the model's output CSV. Defaults to ().
46-
radius (float, optional): Cut-off radius for neighborhood. Defaults to 5.
37+
radius_cutoff (float, optional): Cut-off radius for neighborhood.
38+
Defaults to 5.
4739
max_num_nbr (int, optional): maximum number of neighbors to consider.
4840
Defaults to 12.
49-
dmin (float, optional): minimum distance in Gaussian basis. Defaults to 0.
50-
step (float, optional): increment size of Gaussian basis. Defaults to 0.2.
5141
"""
5242
self.task_dict = task_dict
5343
self.identifiers = list(identifiers)
5444

55-
self.radius = radius
45+
self.radius_cutoff = radius_cutoff
5646
self.max_num_nbr = max_num_nbr
5747

58-
if elem_embedding in ("matscholar200", "cgcnn92", "megnet16", "onehot112"):
59-
elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json"
60-
61-
with open(elem_embedding) as file:
62-
self.elem_features = json.load(file)
63-
64-
for key, value in self.elem_features.items():
65-
self.elem_features[key] = np.array(value, dtype=float)
66-
if not hasattr(self, "elem_emb_len"):
67-
self.elem_emb_len = len(value)
68-
elif self.elem_emb_len != len(value):
69-
raise ValueError(
70-
f"Element embedding length mismatch: len({key})="
71-
f"{len(value)}, expected {self.elem_emb_len}"
72-
)
73-
74-
self.gaussian_dist_func = GaussianDistance(dmin=dmin, dmax=radius, step=step)
75-
self.nbr_fea_dim = self.gaussian_dist_func.embedding_size
76-
7748
self.df = df
7849
self.structure_col = structure_col
7950

@@ -84,7 +55,7 @@ def __init__(
8455
self.df[structure_col].items(), total=len(df), desc=desc, disable=None
8556
):
8657
self_idx, nbr_idx, _ = get_structure_neighbor_info(
87-
struct, radius, max_num_nbr
58+
struct, self.radius_cutoff, self.max_num_nbr
8859
)
8960
material_ids = [idx, *self.df.loc[idx][self.identifiers]]
9061
if 0 in (len(self_idx), len(nbr_idx)):
@@ -140,16 +111,10 @@ def __getitem__(self, idx: int):
140111
material_ids = [self.df.index[idx], *row[self.identifiers]]
141112

142113
# atom features for disordered sites
143-
site_atoms = [atom.species.as_dict() for atom in struct]
144-
atom_features = np.vstack(
145-
[
146-
np.sum([self.elem_features[el] * amt for el, amt in site.items()], axis=0)
147-
for site in site_atoms
148-
]
149-
)
114+
atom_features = [atom.specie.Z for atom in struct]
150115

151116
self_idx, nbr_idx, nbr_dist = get_structure_neighbor_info(
152-
struct, self.radius, self.max_num_nbr
117+
struct, self.radius_cutoff, self.max_num_nbr
153118
)
154119

155120
if len(self_idx) == 0:
@@ -161,9 +126,7 @@ def __getitem__(self, idx: int):
161126
if set(self_idx) != set(range(len(struct))):
162127
raise ValueError(f"At least one atom in {material_ids} is isolated")
163128

164-
nbr_dist = self.gaussian_dist_func.expand(nbr_dist)
165-
166-
atom_fea_t = Tensor(atom_features)
129+
atom_fea_t = LongTensor(atom_features)
167130
nbr_dist_t = Tensor(nbr_dist)
168131
self_idx_t = LongTensor(self_idx)
169132
nbr_idx_t = LongTensor(nbr_idx)
@@ -278,27 +241,25 @@ def __init__(
278241
"Max radii below minimum radii + step size - please increase dmax."
279242
)
280243

281-
self.filter = np.arange(dmin, dmax + step, step)
244+
self.filter = torch.arange(dmin, dmax + step, step)
282245
self.embedding_size = len(self.filter)
283246

284247
if var is None:
285248
var = step
286249

287250
self.var = var
288251

289-
def expand(self, distances: np.ndarray) -> np.ndarray:
252+
def expand(self, distances: Tensor) -> Tensor:
290253
"""Apply Gaussian distance filter to a numpy distance array.
291254
292255
Args:
293256
distances (Tensor): A distance matrix of any shape.
294257
295258
Returns:
296-
np.ndarray: Expanded distance matrix with the last dimension of length
259+
Tensor: Expanded distance matrix with the last dimension of length
297260
len(self.filter)
298261
"""
299-
distances = np.array(distances)
300-
301-
return np.exp(-((distances[..., None] - self.filter) ** 2) / self.var**2)
262+
return torch.exp(-((distances[..., None] - self.filter) ** 2) / self.var**2)
302263

303264

304265
def get_structure_neighbor_info(

aviary/cgcnn/model.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from pymatgen.util.due import Doi, due
66
from torch import LongTensor, Tensor, nn
77

8+
from aviary.cgcnn.data import GaussianDistance
89
from aviary.core import BaseModelClass
910
from aviary.networks import SimpleNetwork
1011
from aviary.scatter import scatter_reduce
12+
from aviary.utils import get_element_embedding
1113

1214

1315
@due.dcite(Doi("10.1103/PhysRevLett.120.145301"), description="CGCNN model")
@@ -25,8 +27,10 @@ def __init__(
2527
self,
2628
robust: bool,
2729
n_targets: Sequence[int],
28-
elem_emb_len: int,
29-
nbr_fea_len: int,
30+
elem_embedding: str = "cgcnn92",
31+
radius_cutoff: float = 5.0,
32+
radius_min: float = 0.0,
33+
radius_step: float = 0.2,
3034
elem_fea_len: int = 64,
3135
n_graph: int = 4,
3236
h_fea_len: int = 128,
@@ -42,8 +46,15 @@ def __init__(
4246
(uncertainty inherent to the sample) which can be used with a robust
4347
loss function to attenuate the weighting of uncertain samples.
4448
n_targets (list[int]): Number of targets to train on
45-
elem_emb_len (int): Number of atom features in the input.
46-
nbr_fea_len (int): Number of bond features.
49+
elem_embedding (str, optional): One of matscholar200, cgcnn92, megnet16,
50+
onehot112 or path to a file with custom element embeddings.
51+
Defaults to cgcnn92.
52+
radius_cutoff (float, optional): Cut-off radius for neighborhood.
53+
Defaults to 5.
54+
radius_min (float, optional): minimum distance in Gaussian basis.
55+
Defaults to 0.
56+
radius_step (float, optional): increment size of Gaussian basis.
57+
Defaults to 0.2.
4758
elem_fea_len (int, optional): Number of hidden atom features in the
4859
convolutional layers. Defaults to 64.
4960
n_graph (int, optional): Number of convolutional layers. Defaults to 4.
@@ -57,6 +68,14 @@ def __init__(
5768
"""
5869
super().__init__(robust=robust, **kwargs)
5970

71+
self.elem_embedding = get_element_embedding(elem_embedding)
72+
elem_emb_len = self.elem_embedding.weight.shape[1]
73+
74+
self.gaussian_dist_func = GaussianDistance(
75+
dmin=radius_min, dmax=radius_cutoff, step=radius_step
76+
)
77+
nbr_fea_len = self.gaussian_dist_func.embedding_size
78+
6079
desc_dict = {
6180
"elem_emb_len": elem_emb_len,
6281
"nbr_fea_len": nbr_fea_len,
@@ -107,6 +126,9 @@ def forward(
107126
Returns:
108127
tuple[Tensor, ...]: tuple of predictions for all targets
109128
"""
129+
nbr_fea = self.gaussian_dist_func.expand(nbr_fea)
130+
atom_fea = self.elem_embedding(atom_fea)
131+
110132
atom_fea = self.node_nn(atom_fea, nbr_fea, self_idx, nbr_idx)
111133

112134
crys_fea = scatter_reduce(atom_fea, crystal_atom_idx, dim=0, reduce="mean")

aviary/core.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
epoch: int = 0,
3333
device: str | None = None,
3434
best_val_scores: dict[str, float] | None = None,
35+
**kwargs,
3536
) -> None:
3637
"""Store core model parameters.
3738
@@ -47,6 +48,7 @@ def __init__(
4748
device (str, optional): Device to store the model parameters on.
4849
best_val_scores (dict[str, float], optional): Validation score to use for
4950
early stopping. Defaults to None.
51+
**kwargs: Additional keyword arguments.
5052
"""
5153
super().__init__()
5254
self.task_dict = task_dict
@@ -299,8 +301,9 @@ def evaluate(
299301
preds = output.squeeze(1)
300302
loss = loss_func(preds, targets)
301303

302-
z_scored_error = preds - targets
303-
error = normalizer.std * z_scored_error.data.cpu()
304+
denormed_preds = normalizer.denorm(preds)
305+
denormed_targets = normalizer.denorm(targets)
306+
error = denormed_preds - denormed_targets
304307
target_metrics["MAE"].append(float(error.abs().mean()))
305308
target_metrics["MSE"].append(float(error.pow(2).mean()))
306309

aviary/roost/data.py

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
from collections.abc import Sequence
32
from functools import cache
43
from typing import Any
@@ -10,8 +9,6 @@
109
from torch import LongTensor, Tensor
1110
from torch.utils.data import Dataset
1211

13-
from aviary import PKG_DIR
14-
1512

1613
class CompositionData(Dataset):
1714
"""Dataset class for the Roost composition model."""
@@ -20,7 +17,6 @@ def __init__(
2017
self,
2118
df: pd.DataFrame,
2219
task_dict: dict[str, str],
23-
elem_embedding: str = "matscholar200",
2420
inputs: str = "composition",
2521
identifiers: Sequence[str] = ("material_id", "composition"),
2622
):
@@ -47,14 +43,6 @@ def __init__(
4743
self.identifiers = list(identifiers)
4844
self.df = df
4945

50-
if elem_embedding in ["matscholar200", "cgcnn92", "megnet16", "onehot112"]:
51-
elem_embedding = f"{PKG_DIR}/embeddings/element/{elem_embedding}.json"
52-
53-
with open(elem_embedding) as file:
54-
self.elem_features = json.load(file)
55-
56-
self.elem_emb_len = len(next(iter(self.elem_features.values())))
57-
5846
self.n_targets = []
5947
for target, task in self.task_dict.items():
6048
if task == "regression":
@@ -88,24 +76,12 @@ def __getitem__(self, idx: int):
8876
composition = row[self.inputs]
8977
material_ids = row[self.identifiers].to_list()
9078

91-
comp_dict = Composition(composition).get_el_amt_dict()
92-
elements = list(comp_dict)
93-
79+
comp_dict = Composition(composition).fractional_composition
9480
weights = list(comp_dict.values())
9581
weights = np.atleast_2d(weights).T / np.sum(weights)
82+
elem_fea = [elem.Z for elem in comp_dict]
9683

97-
try:
98-
elem_fea = np.vstack([self.elem_features[element] for element in elements])
99-
except AssertionError as exc:
100-
raise AssertionError(
101-
f"{material_ids} contains element types not in embedding"
102-
) from exc
103-
except ValueError as exc:
104-
raise ValueError(
105-
f"{material_ids} composition cannot be parsed into elements"
106-
) from exc
107-
108-
n_elems = len(elements)
84+
n_elems = len(comp_dict)
10985
self_idx = []
11086
nbr_idx = []
11187
for elem_idx in range(n_elems):
@@ -114,7 +90,7 @@ def __getitem__(self, idx: int):
11490

11591
# convert all data to tensors
11692
elem_weights = Tensor(weights)
117-
elem_fea = Tensor(elem_fea)
93+
elem_fea = LongTensor(elem_fea)
11894
self_idx = LongTensor(self_idx)
11995
nbr_idx = LongTensor(nbr_idx)
12096

aviary/roost/model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aviary.core import BaseModelClass
99
from aviary.networks import ResidualNetwork, SimpleNetwork
1010
from aviary.segments import MessageLayer, WeightedAttentionPooling
11+
from aviary.utils import get_element_embedding
1112

1213

1314
@due.dcite(Doi("10.1038/s41467-020-19964-7"), description="Roost model")
@@ -25,7 +26,7 @@ def __init__(
2526
self,
2627
robust: bool,
2728
n_targets: Sequence[int],
28-
elem_emb_len: int,
29+
elem_embedding: str = "matscholar200",
2930
elem_fea_len: int = 64,
3031
n_graph: int = 3,
3132
elem_heads: int = 3,
@@ -41,6 +42,8 @@ def __init__(
4142
"""Composition-only model."""
4243
super().__init__(robust=robust, **kwargs)
4344

45+
self.elem_embedding = get_element_embedding(elem_embedding)
46+
elem_emb_len = self.elem_embedding.weight.shape[1]
4447
desc_dict = {
4548
"elem_emb_len": elem_emb_len,
4649
"elem_fea_len": elem_fea_len,
@@ -60,6 +63,7 @@ def __init__(
6063
"n_targets": n_targets,
6164
"out_hidden": out_hidden,
6265
"trunk_hidden": trunk_hidden,
66+
"elem_embedding": elem_embedding,
6367
**desc_dict,
6468
}
6569
self.model_params.update(model_params)
@@ -83,6 +87,8 @@ def forward(
8387
cry_elem_idx: LongTensor,
8488
) -> tuple[Tensor, ...]:
8589
"""Forward pass through the material_nn and output_nn."""
90+
elem_fea = self.elem_embedding(elem_fea)
91+
8692
crys_fea = self.material_nn(
8793
elem_weights, elem_fea, self_idx, nbr_idx, cry_elem_idx
8894
)

0 commit comments

Comments
 (0)