Skip to content

Commit e680e17

Browse files
authored
Remove protostructure code that was upstreamed to pymatgen. (#100)
* fea: remove protostructure code that was upstreamed to pmg. * clean: mat_bench -> matbench * tests: fix python version for ci
1 parent 44ec886 commit e680e17

23 files changed

+78
-17926
lines changed

.github/workflows/test.yml

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,26 @@ on:
1010

1111
jobs:
1212
tests:
13-
runs-on: ubuntu-latest
13+
strategy:
14+
fail-fast: false
15+
matrix:
16+
os: [ubuntu-latest, macos-14]
17+
version:
18+
- { python: "3.10", resolution: highest }
19+
- { python: "3.12", resolution: lowest-direct }
20+
runs-on: ${{ matrix.os }}
21+
1422
steps:
1523
- name: Check out repo
1624
uses: actions/checkout@v4
1725

1826
- name: Set up Python
1927
uses: actions/setup-python@v5
2028
with:
21-
python-version: 3.9
22-
cache: pip
23-
cache-dependency-path: pyproject.toml
29+
python-version: ${{ matrix.version.python }}
2430

25-
- name: Install uv
26-
run: pip install uv
31+
- name: Set up uv
32+
uses: astral-sh/setup-uv@v2
2733

2834
- name: Install dependencies
2935
run: |

aviary/wren/data.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77

88
import numpy as np
99
import torch
10+
from pymatgen.analysis.prototypes import (
11+
RE_SUBST_ONE_PREFIX,
12+
RE_WYCKOFF_NO_PREFIX,
13+
WYCKOFF_MULTIPLICITY_DICT,
14+
WYCKOFF_POSITION_RELAB_DICT,
15+
)
1016
from torch import LongTensor, Tensor
1117
from torch.utils.data import Dataset
1218

1319
from aviary import PKG_DIR
14-
from aviary.wren.utils import (
15-
RE_SUBST_ONE_PREFIX,
16-
RE_WYCKOFF_NO_PREFIX,
17-
relab_dict,
18-
wyckoff_multiplicity_dict,
19-
)
2020

2121
if TYPE_CHECKING:
2222
from collections.abc import Sequence
@@ -300,13 +300,13 @@ def parse_protostructure_label(
300300
elements.extend([el] * mult)
301301
wyckoff_set.extend([letter] * mult)
302302
wyckoff_site_multiplicities.extend(
303-
[float(wyckoff_multiplicity_dict[spg_num][letter])] * mult
303+
[float(WYCKOFF_MULTIPLICITY_DICT[spg_num][letter])] * mult
304304
)
305305

306306
# Create augmented Wyckoff set
307307
augmented_wyckoff_set = {
308308
tuple(",".join(wyckoff_set).translate(str.maketrans(trans)).split(","))
309-
for trans in relab_dict[spg_num]
309+
for trans in WYCKOFF_POSITION_RELAB_DICT[spg_num]
310310
}
311311

312312
return spg_num, wyckoff_site_multiplicities, elements, list(augmented_wyckoff_set)

0 commit comments

Comments
 (0)