
Commit 451f573 (1 parent: e680e17)

Spring cleaning (#101)

* fea: remove protostructure code that was upstreamed to pmg
* clean: mat_bench -> matbench
* wip: bump to py310, remove the wrenformer trainer from the package, clean up tests to have one file per model, fix matbench_example for wrenformer with the custom trainer
* clean: remove _description_ auto docstrings
* fea: pass data loaders to the trainer to make it easier to apply to the InMemoryDataLoader of the wrenformer
* fea: working wrenformer notebook example
* wip: fix wrenformer tests
* fea: wrenformer working with the aviary trainer and examples
* fea: don't use dropout in wrenformer by default
* tests: fix CI python version
* tests: fix python version for CI
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* test: add typing-extensions to allow the Self type to be used
* fix: Self type hint

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>


54 files changed (+2731 −2018 lines)

.gitignore

Lines changed: 3 additions & 0 deletions

```diff
@@ -36,3 +36,6 @@ examples/**/job-logs/
 examples/**/artifacts/
 examples/**/*.csv
 wandb/
+
+# profiling
+*.prof
```

README.md

Lines changed: 5 additions & 1 deletion

````diff
@@ -7,7 +7,7 @@
 [![GitHub last commit](https://img.shields.io/github/last-commit/comprhys/aviary?label=Last+Commit)](https://github.com/comprhys/aviary/commits)
 [![Tests](https://github.com/CompRhys/aviary/actions/workflows/test.yml/badge.svg)](https://github.com/CompRhys/aviary/actions/workflows/test.yml)
 [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/CompRhys/aviary/main.svg)](https://results.pre-commit.ci/latest/github/CompRhys/aviary/main)
-[![This project supports Python 3.9+](https://img.shields.io/badge/Python-3.9+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)
+[![This project supports Python 3.10+](https://img.shields.io/badge/Python-3.10+-blue.svg?logo=python&logoColor=white)](https://python.org/downloads)

 </h4>

@@ -50,6 +50,10 @@ python examples/roost-example.py --train --evaluate --data-path examples/inputs/
 python examples/wren-example.py --train --evaluate --data-path examples/inputs/examples.csv --targets E_f --tasks regression --losses L1 --robust --epoch 10
 ```

+```sh
+python examples/wrenformer-example.py --train --evaluate --data-path examples/inputs/examples.csv --targets E_f --tasks regression --losses L1 --robust --epoch 10
+```
+
 ```sh
 python examples/cgcnn-example.py --train --evaluate --data-path examples/inputs/examples.json --targets E_f --tasks regression --losses L1 --robust --epoch 10
 ```
````

aviary/cgcnn/data.py

Lines changed: 11 additions & 13 deletions

```diff
@@ -1,24 +1,19 @@
-from __future__ import annotations
-
 import itertools
 import json
+from collections.abc import Sequence
 from functools import cache
-from typing import TYPE_CHECKING, Any
+from typing import Any

 import numpy as np
+import pandas as pd
 import torch
+from pymatgen.core import Structure
 from torch import LongTensor, Tensor
 from torch.utils.data import Dataset
 from tqdm import tqdm

 from aviary import PKG_DIR

-if TYPE_CHECKING:
-    from collections.abc import Sequence
-
-    import pandas as pd
-    from pymatgen.core import Structure
-

 class CrystalGraphData(Dataset):
     """Dataset class for the CGCNN structure model."""
@@ -253,9 +248,10 @@ def collate_batch(
     return (
         (atom_fea, nbr_dist, self_idx, nbr_idx, cry_idx),
         tuple(
-            torch.stack(b_target, dim=0).to(device) for b_target in zip(*batch_targets)
+            torch.stack(b_target, dim=0).to(device)
+            for b_target in zip(*batch_targets, strict=False)
         ),
-        *zip(*batch_identifiers),
+        *zip(*batch_identifiers, strict=False),
     )


@@ -332,10 +328,12 @@ def get_structure_neighbor_info(
     _neighbor_dists: list[float] = []

     for _, idx_group in itertools.groupby(  # group by site index
-        zip(site_indices, neighbor_indices, neighbor_dists), key=lambda x: x[0]
+        zip(site_indices, neighbor_indices, neighbor_dists, strict=False),
+        key=lambda x: x[0],
     ):
         site_indices, neighbor_idx, neighbor_dist = zip(
-            *sorted(idx_group, key=lambda x: x[2])  # sort by distance
+            *sorted(idx_group, key=lambda x: x[2]),
+            strict=False,  # sort by distance
         )
         _center_indices.extend(site_indices[:max_num_nbr])
         _neighbor_indices.extend(neighbor_idx[:max_num_nbr])
```
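A note on the `strict=False` arguments threaded through this diff: `zip()` gained a `strict` keyword in Python 3.10 (PEP 618), which the py310 bump makes available; passing it explicitly documents whether silently truncating unequal-length iterables is acceptable. A minimal sketch of the two behaviors:

```python
# Python >= 3.10: zip() accepts a `strict` keyword (PEP 618).
a = [1, 2, 3]
b = ["x", "y"]

# strict=False keeps the classic behavior: stop at the shortest iterable.
print(list(zip(a, b, strict=False)))  # [(1, 'x'), (2, 'y')]

# strict=True raises instead, surfacing silent length-mismatch bugs.
try:
    list(zip(a, b, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
```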

aviary/cgcnn/model.py

Lines changed: 2 additions & 7 deletions

```diff
@@ -1,6 +1,4 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
+from collections.abc import Sequence

 import torch
 import torch.nn.functional as F
@@ -11,9 +9,6 @@
 from aviary.networks import SimpleNetwork
 from aviary.scatter import scatter_reduce

-if TYPE_CHECKING:
-    from collections.abc import Sequence
-

 @due.dcite(Doi("10.1103/PhysRevLett.120.145301"), description="CGCNN model")
 class CrystalGraphConvNet(BaseModelClass):
@@ -215,7 +210,7 @@ def forward(
         Args:
             atom_in_fea (Tensor): Atom hidden features before convolution
             nbr_fea (Tensor): Bond features of each atom's neighbors
-            self_idx (LongTensor): _description_
+            self_idx (LongTensor): Indices of the atom's self
             nbr_idx (LongTensor): Indices of M neighbors of each atom

         Returns:
```
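The import shuffle above follows the same logic as in data.py: on Python 3.10+, annotations like `Sequence[int] | None` evaluate at runtime without `from __future__ import annotations`, so the `if TYPE_CHECKING:` guard is unnecessary and type-only imports can move to module level. A before/after sketch (the `head` function is a made-up illustration, not aviary code):

```python
# Before: deferred annotations plus imports visible only to the type checker.
#
#   from __future__ import annotations
#   from typing import TYPE_CHECKING
#
#   if TYPE_CHECKING:
#       from collections.abc import Sequence

# After (Python >= 3.10): PEP 585 generics and PEP 604 unions work directly
# in annotations, so the import simply lives at module level.
from collections.abc import Sequence


def head(items: Sequence[int] | None) -> int | None:
    """Return the first element, or None for an empty or missing sequence."""
    return items[0] if items else None


print(head([3, 1, 4]))  # 3
print(head(None))  # None
```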

aviary/core.py

Lines changed: 40 additions & 95 deletions

```diff
@@ -1,29 +1,23 @@
-from __future__ import annotations
-
 import gc
 import os
 import shutil
 from abc import ABC
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Literal
+from collections.abc import Callable, Mapping
+from typing import Any, Literal

 import numpy as np
 import torch
 import wandb
 from sklearn.metrics import f1_score
 from torch import BoolTensor, Tensor, nn
 from torch.nn.functional import softmax
+from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm

 from aviary import ROOT
-
-if TYPE_CHECKING:
-    from collections.abc import Mapping
-
-    from torch.utils.data import DataLoader
-
-    from aviary.data import InMemoryDataLoader
+from aviary.data import InMemoryDataLoader, Normalizer

 TaskType = Literal["regression", "classification"]
```
```diff
@@ -129,6 +123,14 @@ def fit(
                     for metric, val in metrics.items():
                         writer.add_scalar(f"{task}/train/{metric}", val, epoch)

+                if writer == "wandb":
+                    flat_train_metrics = {}
+                    for task, metrics in train_metrics.items():
+                        for metric, val in metrics.items():
+                            flat_train_metrics[f"train_{task}_{metric.lower()}"] = val
+                    flat_train_metrics["epoch"] = epoch
+                    wandb.log(flat_train_metrics)
+
                 # Validation
                 if val_loader is not None:
                     with torch.no_grad():
@@ -149,6 +151,14 @@ def fit(
                                     f"{task}/validation/{metric}", val, epoch
                                 )

+                    if writer == "wandb":
+                        flat_val_metrics = {}
+                        for task, metrics in val_metrics.items():
+                            for metric, val in metrics.items():
+                                flat_val_metrics[f"val_{task}_{metric.lower()}"] = val
+                        flat_val_metrics["epoch"] = epoch
+                        wandb.log(flat_val_metrics)
+
                     # TODO test all tasks to see if they are best,
                     # save a best model if any is best.
                     # TODO what are the costs of this approach.
@@ -207,9 +217,6 @@ def fit(
                 # catch memory leak
                 gc.collect()

-                if writer == "wandb":
-                    wandb.log({"train": train_metrics, "validation": val_metrics})
-
         except KeyboardInterrupt:
             pass
```

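The two added `wandb` blocks replace the single `wandb.log({"train": ..., "validation": ...})` call deleted in the last hunk above: metrics are now flattened to scalar keys and logged once per epoch, which is what lets wandb chart each metric against the `epoch` axis. A standalone sketch of the flattening, with made-up metric names and values:

```python
# Nested metrics shaped like train_metrics inside BaseModelClass.fit().
train_metrics = {"E_f": {"Loss": 0.12, "MAE": 0.34}}
epoch = 7

# Flatten to one scalar per key, e.g. "train_E_f_mae", and tag the epoch.
flat_train_metrics = {
    f"train_{task}_{metric.lower()}": val
    for task, metrics in train_metrics.items()
    for metric, val in metrics.items()
}
flat_train_metrics["epoch"] = epoch

print(flat_train_metrics)
# {'train_E_f_loss': 0.12, 'train_E_f_mae': 0.34, 'epoch': 7}
# wandb.log(flat_train_metrics)  # called once per epoch inside fit()
```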
```diff
@@ -271,7 +278,11 @@ def evaluate(
             mixed_loss: Tensor = 0  # type: ignore[assignment]

             for target_name, targets, output, normalizer in zip(
-                self.target_names, targets_list, outputs, normalizer_dict.values()
+                self.target_names,
+                targets_list,
+                outputs,
+                normalizer_dict.values(),
+                strict=False,
             ):
                 task, loss_func = loss_dict[target_name]
                 target_metrics = epoch_metrics[target_name]
@@ -318,7 +329,7 @@ def evaluate(
                 else:
                     raise ValueError(f"invalid task: {task}")

-                epoch_metrics[target_name]["Loss"].append(loss.cpu().item())
+                target_metrics["Loss"].append(loss.cpu().item())

             # NOTE multitasking currently just uses a direct sum of individual
             # target losses this should be okay but is perhaps sub-optimal
@@ -396,11 +407,13 @@ def predict(
         # for multitask learning
         targets = tuple(
             torch.cat(targets, dim=0).view(-1).cpu().numpy()
-            for targets in zip(*test_targets)
+            for targets in zip(*test_targets, strict=False)
+        )
+        predictions = tuple(
+            torch.cat(preds, dim=0) for preds in zip(*test_preds, strict=False)
         )
-        predictions = tuple(torch.cat(preds, dim=0) for preds in zip(*test_preds))
         # identifier columns
-        ids = tuple(np.concatenate(x) for x in zip(*test_ids))
+        ids = tuple(np.concatenate(x) for x in zip(*test_ids, strict=False))
         return targets, predictions, ids

     @torch.no_grad()
```
```diff
@@ -445,83 +458,6 @@ def __repr__(self) -> str:
         return f"{cls_name} with {n_params:,} trainable params at {n_epochs:,} epochs"


-class Normalizer:
-    """Normalize a Tensor and restore it later."""
-
-    def __init__(self) -> None:
-        """Initialize Normalizer with mean 0 and std 1."""
-        self.mean = torch.tensor(0)
-        self.std = torch.tensor(1)
-
-    def fit(self, tensor: Tensor, dim: int = 0, keepdim: bool = False) -> None:
-        """Compute the mean and standard deviation of the given tensor.
-
-        Args:
-            tensor (Tensor): Tensor to determine the mean and standard deviation over.
-            dim (int, optional): Which dimension to take mean and standard deviation
-                over. Defaults to 0.
-            keepdim (bool, optional): Whether to keep the reduced dimension in Tensor.
-                Defaults to False.
-        """
-        self.mean = torch.mean(tensor, dim, keepdim)
-        self.std = torch.std(tensor, dim, keepdim)
-
-    def norm(self, tensor: Tensor) -> Tensor:
-        """Normalize a Tensor.
-
-        Args:
-            tensor (Tensor): Tensor to be normalized
-
-        Returns:
-            Tensor: Normalized Tensor
-        """
-        return (tensor - self.mean) / self.std
-
-    def denorm(self, normed_tensor: Tensor) -> Tensor:
-        """Restore normalized Tensor to original.
-
-        Args:
-            normed_tensor (Tensor): Tensor to be restored
-
-        Returns:
-            Tensor: Restored Tensor
-        """
-        return normed_tensor * self.std + self.mean
-
-    def state_dict(self) -> dict[str, Tensor]:
-        """Get Normalizer parameters mean and std.
-
-        Returns:
-            dict[str, Tensor]: Dictionary storing Normalizer parameters.
-        """
-        return {"mean": self.mean, "std": self.std}
-
-    def load_state_dict(self, state_dict: dict[str, Tensor]) -> None:
-        """Overwrite Normalizer parameters given a new state_dict.
-
-        Args:
-            state_dict (dict[str, Tensor]): Dictionary storing Normalizer parameters.
-        """
-        self.mean = state_dict["mean"].cpu()
-        self.std = state_dict["std"].cpu()
-
-    @classmethod
-    def from_state_dict(cls, state_dict: dict[str, Tensor]) -> Normalizer:
-        """Create a new Normalizer given a state_dict.
-
-        Args:
-            state_dict (dict[str, Tensor]): Dictionary storing Normalizer parameters.
-
-        Returns:
-            Normalizer
-        """
-        instance = cls()
-        instance.mean = state_dict["mean"].cpu()
-        instance.std = state_dict["std"].cpu()
-
-        return instance
-
-
 def save_checkpoint(
     state: dict[str, Any], is_best: bool, model_name: str, run_id: int
 ) -> None:
```
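The Normalizer class deleted here isn't gone from the package: judging by the new `from aviary.data import InMemoryDataLoader, Normalizer` import at the top of this file, it now lives in `aviary.data`. Per the docstrings in the removed code, its contract is a simple round trip: `denorm(norm(x))` recovers `x`. A short usage sketch against that API, assuming the moved class is unchanged:

```python
import torch

from aviary.data import Normalizer  # new home after this commit

targets = torch.tensor([1.0, 2.0, 3.0, 4.0])

normalizer = Normalizer()
normalizer.fit(targets)  # stores the tensor's mean and std

normed = normalizer.norm(targets)  # (x - mean) / std
restored = normalizer.denorm(normed)  # x * std + mean

assert torch.allclose(restored, targets)
print(normalizer.state_dict())  # {'mean': tensor(2.5000), 'std': tensor(1.2910)}
```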
```diff
@@ -662,3 +598,12 @@ def masked_min(x: Tensor, mask: BoolTensor, dim: int = 0) -> Tensor:
     x_inf = x.float().masked_fill(~mask, float("inf"))
     x_min, _ = x_inf.min(dim=dim)
     return x_min
+
+
+AGGREGATORS: dict[str, Callable[[Tensor, BoolTensor, int], Tensor]] = {
+    "mean": masked_mean,
+    "std": masked_std,
+    "max": masked_max,
+    "min": masked_min,
+    "sum": lambda x, mask, dim: (x * mask).sum(dim=dim),
+}
```
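The new AGGREGATORS mapping gives callers a string-keyed way to pick a masked pooling function, e.g. from a model config. A small sketch of the call shape; the feature tensor and padding mask below are invented for illustration, and the mask layout (one boolean per row, broadcast over features) is an assumption about intended use:

```python
import torch

from aviary.core import AGGREGATORS

# Three "token" rows of 2-dim features; only the first two rows are real.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]])
mask = torch.tensor([[True], [True], [False]])  # broadcasts over features

# Masked sum over the token dimension: the padded row contributes nothing.
pooled = AGGREGATORS["sum"](x, mask, 0)
print(pooled)  # tensor([4., 6.])
```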
