Skip to content

Commit b8a69d3

Browse files
authored
Merge pull request #110 from Exabyte-io/feature/SOF-7570-2
Feature/sof 7570 2 - Pydantic datamodel + related
2 parents 9e262b2 + d189de2 commit b8a69d3

File tree

17 files changed

+1003
-75
lines changed

17 files changed

+1003
-75
lines changed

.github/workflows/cicd.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
runs-on: ubuntu-20.04
1212
strategy:
1313
matrix:
14-
python-version: [3.8.6]
14+
python-version: [3.10.13]
1515

1616
steps:
1717
- name: Checkout this repository
@@ -37,10 +37,9 @@ jobs:
3737
strategy:
3838
matrix:
3939
python-version:
40-
- 3.8.x
41-
- 3.9.x
4240
- 3.10.x
4341
- 3.11.x
42+
- 3.12.x
4443

4544
steps:
4645
- name: Checkout this repository

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ celerybeat.pid
126126

127127
# Environments
128128
.env
129-
.venv
129+
.venv*
130130
env/
131131
venv/
132132
ENV/
@@ -176,3 +176,4 @@ node_modules/
176176
*.DS_Store
177177

178178
tsconfig.tsbuildinfo
179+
.python-version

pyproject.toml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ dynamic = ["version"]
44
description = "COre DEfinitions."
55
readme = "README.md"
66
requires-python = ">=3.8"
7-
license = {file = "LICENSE.md"}
7+
license = { file = "LICENSE.md" }
88
authors = [
9-
{name = "Exabyte Inc.", email = "info@mat3ra.com"}
9+
{ name = "Exabyte Inc.", email = "info@mat3ra.com" }
1010
]
1111
classifiers = [
1212
"Programming Language :: Python",
@@ -18,6 +18,8 @@ dependencies = [
1818
# add requirements here
1919
"numpy",
2020
"jsonschema>=2.6.0",
21+
"pydantic>=2.10.5",
22+
"mat3ra-esse",
2123
"mat3ra-utils>=2024.5.15.post0",
2224
]
2325

@@ -79,3 +81,11 @@ target-version = "py38"
7981
profile = "black"
8082
multi_line_output = 3
8183
include_trailing_comma = true
84+
85+
[tool.pytest.ini_options]
86+
pythonpath = [
87+
"src/py",
88+
]
89+
testpaths = [
90+
"tests/py"
91+
]

src/py/mat3ra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Extend this package's search path via `pkgutil.extend_path` so that `mat3ra` behaves as a
# pkgutil-style namespace package and same-named directories elsewhere on `sys.path` are included.
# Otherwise, the `mat3ra.utils` path leads to the empty __init__.py file of this (code) package.
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import json
2+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3+
4+
from mat3ra.utils.mixins import RoundNumericValuesMixin
5+
from pydantic import BaseModel, model_serializer
6+
7+
from .value_with_id import RoundedValueWithId, ValueWithId
8+
9+
10+
class ArrayWithIds(BaseModel):
    # Array of arbitrary values paired position-by-position with integer ids: `values[i]` is
    # identified by `ids[i]`. The filter/add/remove methods mutate both lists in place.
    # NOTE: documented with comments rather than a class docstring on purpose, since pydantic
    # copies a model's docstring into the generated JSON schema description.
    values: List[Any]
    ids: List[int]

    @classmethod
    def from_values(cls, values: List[Any]) -> "ArrayWithIds":
        """Create an instance from ``values``, assigning consecutive ids starting at 0.

        Raises:
            ValueError: if ``values`` has no length (for example ``None`` or a number).
        """
        try:
            ids = list(range(len(values)))
        except TypeError as err:
            # `len()` signals unsized input with TypeError, never KeyError.
            raise ValueError("Values must be a list") from err
        return cls(values=values, ids=ids)

    @classmethod
    def get_values_and_ids_from_list_of_dicts(cls, list_of_dicts: List[Dict[str, Any]]) -> Tuple[List[Any], List[int]]:
        """Split ``[{"id": ..., "value": ...}, ...]`` into parallel ``(values, ids)`` lists.

        Raises:
            ValueError: if any dictionary lacks the ``"id"`` or ``"value"`` key.
        """
        try:
            values = [item["value"] for item in list_of_dicts]
            ids = [item["id"] for item in list_of_dicts]
        except KeyError as err:
            raise ValueError("List of dictionaries must contain 'id' and 'value' keys") from err
        return values, ids

    @classmethod
    def from_list_of_dicts(cls, list_of_dicts: List[Dict[str, Any]]) -> "ArrayWithIds":
        """Create an instance from a list of ``{"id": ..., "value": ...}`` dictionaries.

        Raises:
            ValueError: if any dictionary lacks the ``"id"`` or ``"value"`` key.
        """
        # The helper already converts a missing key into ValueError, so no extra handling here.
        values, ids = cls.get_values_and_ids_from_list_of_dicts(list_of_dicts)
        return cls(values=values, ids=ids)

    @model_serializer
    def to_dict(self) -> List[Dict[str, Any]]:
        """Serialize to a list with one ``to_dict()`` result per value/id pair."""
        return [item.to_dict() for item in self.to_array_of_values_with_ids()]

    def to_json(self, skip_rounding=True) -> str:
        """Return ``to_dict()`` encoded as a JSON string.

        ``skip_rounding`` is accepted for interface compatibility and is not used here.
        """
        return json.dumps(self.to_dict())

    def to_array_of_values_with_ids(self) -> List[ValueWithId]:
        """Return one ``ValueWithId`` per element, preserving order."""
        return [ValueWithId(id=id_, value=item) for id_, item in zip(self.ids, self.values)]

    def get_element_value_by_index(self, index: int) -> Any:
        """Return the value at ``index`` or ``None`` when the index is out of range.

        Negative indices count from the end; an out-of-range negative index also yields ``None``.
        """
        size = len(self.values)
        return self.values[index] if -size <= index < size else None

    def get_element_id_by_value(self, value: Any) -> Optional[int]:
        """Return the id of the first element equal to ``value``, or ``None`` if absent."""
        try:
            return self.ids[self.values.index(value)]
        except ValueError:
            return None

    def filter_by_values(self, values: Union[List[Any], Any]):
        """Keep, in place, only the elements whose value matches one of ``values``.

        ``values`` may be a single value or a list of values. List values are compared through
        their tuple form, so list-typed elements (including nested lists) are supported.
        """

        def make_hashable(value):
            # Recurse so nested lists such as [[0, 0], [1, 1]] become hashable as well.
            return tuple(make_hashable(v) for v in value) if isinstance(value, list) else value

        candidates = values if isinstance(values, list) else [values]
        values_to_keep = {make_hashable(v) for v in candidates}
        kept = [(v, i) for v, i in zip(self.values, self.ids) if make_hashable(v) in values_to_keep]
        self.values = [v for v, _ in kept]
        self.ids = [i for _, i in kept]

    def filter_by_indices(self, indices: Union[List[int], int]):
        """Keep, in place, only the elements at the given position(s)."""
        index_set = set(indices) if isinstance(indices, list) else {indices}
        self.values = [value for position, value in enumerate(self.values) if position in index_set]
        self.ids = [id_ for position, id_ in enumerate(self.ids) if position in index_set]

    def filter_by_ids(self, ids: Union[List[int], int], invert: bool = False):
        """Keep, in place, only the elements whose id is in ``ids``.

        With ``invert=True`` the selection is reversed: elements whose id is in ``ids`` are dropped.
        """
        if isinstance(ids, int):
            ids = [ids]
        requested = set(ids)
        # `!= invert` flips the membership test when the selection is inverted.
        keep_indices = [index for index, id_ in enumerate(self.ids) if (id_ in requested) != invert]
        self.values = [self.values[index] for index in keep_indices]
        self.ids = [self.ids[index] for index in keep_indices]

    def __eq__(self, other: object) -> bool:
        """Two arrays are equal when both their values and their ids are equal."""
        return isinstance(other, ArrayWithIds) and self.values == other.values and self.ids == other.ids

    def map_array_in_place(self, func: Callable):
        """Replace every value with ``func(value)``; ids are left untouched."""
        self.values = [func(value) for value in self.values]

    def add_item(self, element: Any, id: Optional[int] = None):
        """Append ``element``; when ``id`` is omitted the next id is ``max(ids) + 1`` (0 if empty)."""
        new_id = max(self.ids, default=-1) + 1 if id is None else id
        self.values.append(element)
        self.ids.append(new_id)

    def remove_item(self, index: int, id: Optional[int] = None):
        """Remove one element by position, or by ``id`` when ``id`` is given (``index`` is then ignored).

        Raises:
            ValueError: if ``id`` is given but not present.
            IndexError: if the resolved position is out of range.
        """
        if id is not None:
            try:
                index = self.ids.index(id)
            except ValueError as err:
                raise ValueError("ID not found in the list") from err
        # Check both bounds up front so neither list is modified when the index is invalid.
        if not -len(self.values) <= index < len(self.values):
            raise IndexError("Index out of range")
        del self.values[index]
        del self.ids[index]
113+
114+
115+
class RoundedArrayWithIds(RoundNumericValuesMixin, ArrayWithIds):
    # Variant of ArrayWithIds whose elements are wrapped in RoundedValueWithId.
    def to_array_of_values_with_ids(self) -> List[ValueWithId]:
        """Return one ``RoundedValueWithId`` per element, using this array's rounding precision."""
        # NOTE(review): this assigns the precision on the RoundedValueWithId class itself, so it
        # appears to be shared by every user of that class - confirm this is intended.
        RoundedValueWithId.__round_precision__ = self.__round_precision__
        pairs = zip(self.ids, self.values)
        return [RoundedValueWithId(id=item_id, value=item) for item_id, item in pairs]

src/py/mat3ra/code/constants.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from math import pi
22

3+
from mat3ra.esse.models.definitions.constants import FundamentalConstants
4+
5+
CONSTANTS = FundamentalConstants()
6+
37

48
class Coefficients:
59
# Same as used in: JS/TS
@@ -13,18 +17,19 @@ class Coefficients:
1317
# and originally taken from https://github.com/hplgit/physical-quantities/blob/master/PhysicalQuantities.py
1418

1519
# Internal, for convenience purposes
16-
_c = 299792458.0 # speed of light, m/s
17-
_mu0 = 4.0e-7 * pi # permeability of vacuum
18-
_eps0 = 1 / _mu0 / _c**2 # permittivity of vacuum
19-
_Grav = 6.67259e-11 # gravitational constant
20-
_hplanck = 6.6260755e-34 # Planck constant, J s
21-
_hbar = _hplanck / (2 * pi) # Planck constant / 2pi, J s
22-
_e = 1.60217733e-19 # elementary charge
23-
_me = 9.1093897e-31 # electron mass
20+
_c = CONSTANTS.c # speed of light, m/s
21+
_Grav = CONSTANTS.G # gravitational constant
22+
_hplanck = CONSTANTS.h # Planck constant, J s
23+
_e = CONSTANTS.e # elementary charge
24+
_me = CONSTANTS.me # electron mass
25+
_mu0 = 4.0e-7 * pi # permeability of vacuum, atomic units
26+
2427
_mp = 1.6726231e-27 # proton mass
2528
_Nav = 6.0221367e23 # Avogadro number
2629
_k = 1.380658e-23 # Boltzmann constant, J/K
2730
_amu = 1.6605402e-27 # atomic mass unit, kg
31+
_eps0 = 1 / _mu0 / _c**2 # permittivity of vacuum
32+
_hbar = _hplanck / (2 * pi) # Planck constant / 2pi, J s
2833

2934
# External
3035
BOHR = 4e10 * pi * _eps0 * _hbar**2 / _me / _e**2 # Bohr radius in angstrom

src/py/mat3ra/code/entity.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,90 @@
1-
from typing import Any, Dict, List, Optional
1+
from typing import Any, Dict, List, Optional, Type, TypeVar
22

33
import jsonschema
44
from mat3ra.utils import object as object_utils
5+
from pydantic import BaseModel
6+
from typing_extensions import Self
57

68
from . import BaseUnderscoreJsonPropsHandler
79
from .mixins import DefaultableMixin, HasDescriptionMixin, HasMetadataMixin, NamedMixin
810

11+
T = TypeVar("T", bound="InMemoryEntityPydantic")
12+
B = TypeVar("B", bound="BaseModel")
913

14+
15+
# TODO: remove in the next PR
1016
class ValidationErrorCode:
    """Namespace of string codes identifying validation failures."""

    # Code passed to EntityError when in-memory entity data is invalid.
    IN_MEMORY_ENTITY_DATA_INVALID = "IN_MEMORY_ENTITY_DATA_INVALID"
1218

1319

20+
# TODO: remove in the next PR
1421
class ErrorDetails:
    """Bundle of an error payload together with the JSON data and the schema it relates to."""

    def __init__(self, error: Optional[Dict[str, Any]], json: Dict[str, Any], schema: Dict):
        # Plain attribute storage; nothing is copied or validated.
        self.error, self.json, self.schema = error, json, schema
1926

2027

28+
# TODO: remove in the next PR
2129
class EntityError(Exception):
    """Exception carrying an error ``code`` and optional ``details``."""

    def __init__(self, code: ValidationErrorCode, details: Optional[ErrorDetails] = None):
        # The code doubles as the exception message.
        super().__init__(code)
        self.details = details
        self.code = code
2634

2735

36+
class InMemoryEntityPydantic(BaseModel):
    # Pydantic-based entity base class bundling construction, validation and serialization helpers.
    # Documented with comments rather than a class docstring so the generated JSON schema
    # (which picks up model docstrings as its description) is unchanged.
    model_config = {"arbitrary_types_allowed": True}

    @classmethod
    def create(cls: Type[T], config: Dict[str, Any]) -> T:
        """Build an instance from ``config`` by passing it through ``validate``."""
        instance = cls.validate(config)
        return instance

    @classmethod
    def validate(cls, value: Any) -> Self:
        """Validate (and thereby clean) ``value``, returning the resulting instance."""
        instance = cls.model_validate(value)
        return instance

    @classmethod
    def is_valid(cls, value: Any) -> bool:
        """Return ``True`` when ``validate(value)`` succeeds and ``False`` when it raises."""
        try:
            cls.validate(value)
        except Exception:
            return False
        return True

    @classmethod
    def from_json(cls: Type[T], json_str: str) -> T:
        """Build an instance from a JSON string."""
        instance = cls.model_validate_json(json_str)
        return instance

    @classmethod
    def clean(cls: Type[T], config: Dict[str, Any]) -> Dict[str, Any]:
        """Validate ``config`` and return the validated model dumped back to a dictionary."""
        return cls.model_validate(config).model_dump()

    def get_schema(self) -> Dict[str, Any]:
        """Return the JSON schema of this model."""
        schema = self.model_json_schema()
        return schema

    def get_data_model(self) -> Type[B]:
        """Return the first direct base class that is a pydantic ``BaseModel``.

        Raises:
            ValueError: if none of the direct bases qualifies.
        """
        own_class = self.__class__
        for candidate in own_class.__bases__:
            if candidate is own_class or not issubclass(candidate, BaseModel):
                continue
            return candidate
        raise ValueError(f"No schema base model found for {own_class.__name__}")

    def get_cls_name(self) -> str:
        """Return the name of this instance's class."""
        return type(self).__name__

    def to_dict(self, exclude: Optional[List[str]] = None) -> Dict[str, Any]:
        """Dump the model to a dictionary, omitting the field names listed in ``exclude``."""
        excluded = set(exclude) if exclude else None
        return self.model_dump(exclude=excluded)

    def to_json(self, exclude: Optional[List[str]] = None) -> str:
        """Dump the model to a JSON string, omitting the field names listed in ``exclude``."""
        excluded = set(exclude) if exclude else None
        return self.model_dump_json(exclude=excluded)

    def clone(self: T, extra_context: Optional[Dict[str, Any]] = None, deep=True) -> T:
        """Return a copy (deep by default) with ``extra_context`` applied as field updates."""
        updates = extra_context or {}
        return self.model_copy(update=updates, deep=deep)
85+
86+
87+
# TODO: remove in the next PR
2888
class InMemoryEntity(BaseUnderscoreJsonPropsHandler):
2989
jsonSchema: Optional[Dict] = None
3090

@@ -97,7 +157,7 @@ def get_as_entity_reference(self, by_id_only: bool = False) -> Dict[str, str]:
97157
return {"_id": self.id, "slug": self.slug, "cls": self.get_cls_name()}
98158

99159

100-
class HasDescriptionHasMetadataNamedDefaultableInMemoryEntity(
101-
InMemoryEntity, DefaultableMixin, NamedMixin, HasMetadataMixin, HasDescriptionMixin
160+
class HasDescriptionHasMetadataNamedDefaultableInMemoryEntityPydantic(
    InMemoryEntityPydantic, DefaultableMixin, NamedMixin, HasMetadataMixin, HasDescriptionMixin
):
    # Pure composition: the pydantic entity base combined with the defaultable, named, metadata
    # and description mixins. It adds no members of its own.
    pass

0 commit comments

Comments
 (0)