Skip to content

Commit dde887c

Browse files
authored
Merge pull request #112 from Exabyte-io/feature/SOF-7570-fix
fix: clean cleans config + tests
2 parents a1c2c66 + a761d16 commit dde887c

File tree

5 files changed

+100
-3
lines changed

5 files changed

+100
-3
lines changed

src/py/mat3ra/code/entity.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ class InMemoryEntityPydantic(BaseModel):
3838

3939
@classmethod
4040
def create(cls: Type[T], config: Dict[str, Any]) -> T:
41-
return cls.validate(config)
41+
cleaned_data = cls.clean(config)
42+
return cls.validate(cleaned_data)
4243

4344
@classmethod
4445
def validate(cls, value: Any) -> Self:
@@ -59,8 +60,9 @@ def from_json(cls: Type[T], json_str: str) -> T:
5960

6061
@classmethod
6162
def clean(cls: Type[T], config: Dict[str, Any]) -> Dict[str, Any]:
62-
validated_model = cls.model_validate(config)
63-
return validated_model.model_dump()
63+
# Validate the config; extra keys are dropped and defaults are substituted.
64+
validated = cls.model_validate(config, strict=False)
65+
return validated.model_dump(exclude_unset=False)
6466

6567
def get_schema(self) -> Dict[str, Any]:
6668
return self.model_json_schema()

src/py/mat3ra/code/vector.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ def __eq__(self, other):
3333
other = Vector3D(other)
3434
return np.allclose(self.root, other.root, atol=self.__atol__, rtol=0)
3535

36+
@property
37+
def norm(self):
38+
return np.linalg.norm(self.value)
39+
3640

3741
class RoundedVector3D(RoundNumericValuesMixin, Vector3D):
3842
def __init__(self, root: List[float]):
@@ -64,3 +68,7 @@ def __eq__(self, other):
6468
other = RoundedVector3D(other)
6569
atol = self.__atol__ or 10 ** (-self.__round_precision__)
6670
return np.allclose(self.value_rounded, other.value_rounded, atol=atol, rtol=0)
71+
72+
@property
73+
def norm_rounded(self):
74+
return self.round_array_or_number(self.norm)

tests/py/unit/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from pydantic import BaseModel
55

66
REFERENCE_OBJECT_VALID = {"key1": "value1", "key2": 1}
7+
REFERENCE_OBJECT_VALID_WITH_EXTRA_KEY = {"key1": "value1", "key2": 1, "key-to-clean": "will-be-removed"}
8+
REFERENCE_OBJECT_VALID_WITH_MISSING_KEY = {"key2": 1}
79
REFERENCE_OBJECT_VALID_UPDATED = {"key1": "value1-updated", "key2": 2}
810
REFERENCE_OBJECT_INVALID = {"key1": "value1", "key2": "value2"}
911
REFERENCE_OBJECT_VALID_JSON = json.dumps(REFERENCE_OBJECT_VALID)
@@ -18,6 +20,11 @@ class ExampleSchema(BaseModel):
1820
key2: int
1921

2022

23+
class ExampleDefaultableSchema(BaseModel):
24+
key1: str = "value1"
25+
key2: int
26+
27+
2128
class ExampleNestedSchema(BaseModel):
2229
nested_key1: ExampleSchema
2330

@@ -30,6 +37,10 @@ class ExampleClass(ExampleSchema, InMemoryEntityPydantic):
3037
pass
3138

3239

40+
class ExampleDefaultableClass(ExampleDefaultableSchema, InMemoryEntityPydantic):
41+
pass
42+
43+
3344
class ExampleNestedClass(ExampleNestedSchema, InMemoryEntityPydantic):
3445
@property
3546
def nested_key1_instance(self) -> ExampleClass:

tests/py/unit/test_entity.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
REFERENCE_OBJECT_VALID,
99
REFERENCE_OBJECT_VALID_JSON,
1010
REFERENCE_OBJECT_VALID_UPDATED,
11+
REFERENCE_OBJECT_VALID_WITH_EXTRA_KEY,
12+
REFERENCE_OBJECT_VALID_WITH_MISSING_KEY,
1113
ExampleClass,
14+
ExampleDefaultableClass,
1215
ExampleDoubleNestedKeyAsClassInstancesClass,
1316
ExampleDoubleNestedSchema,
1417
ExampleNestedClass,
@@ -65,6 +68,15 @@ def test_update_nested_as_class_instance():
6568
assert isinstance(entity.nested_key1, ExampleClass)
6669

6770

71+
def test_create_with_default():
72+
entity = ExampleDefaultableClass.create(REFERENCE_OBJECT_VALID_WITH_MISSING_KEY)
73+
assert isinstance(entity, ExampleDefaultableClass)
74+
# Default value for key1 -- "value1" is used
75+
assert entity.key1 == "value1"
76+
assert entity.key2 == 1
77+
assert isinstance(entity, ExampleDefaultableClass)
78+
79+
6880
def test_validate():
6981
# Test valid case
7082
entity = ExampleClass.create(REFERENCE_OBJECT_VALID)
@@ -111,6 +123,28 @@ def test_clean():
111123
assert True # Expecting an exception for invalid input
112124

113125

126+
def test_clean_extra_keys():
127+
# Test clean method with valid input with extra keys
128+
cleaned_data_with_extra = ExampleClass.clean(REFERENCE_OBJECT_VALID_WITH_EXTRA_KEY)
129+
assert isinstance(cleaned_data_with_extra, dict)
130+
assert cleaned_data_with_extra == REFERENCE_OBJECT_VALID
131+
assert "key-to-clean" not in cleaned_data_with_extra
132+
133+
134+
def test_clean_default_substitution():
135+
# Test case with default substitution (should add pass and add default values)
136+
cleaned_data_with_default = ExampleDefaultableClass.clean(REFERENCE_OBJECT_VALID_WITH_MISSING_KEY)
137+
assert isinstance(cleaned_data_with_default, dict)
138+
assert cleaned_data_with_default == REFERENCE_OBJECT_VALID
139+
140+
# Test case with invalid input with missing keys (should raise an error)
141+
try:
142+
_ = ExampleDefaultableClass.clean(REFERENCE_OBJECT_INVALID)
143+
assert False, "Invalid input did not raise an exception"
144+
except Exception:
145+
assert True
146+
147+
114148
def test_get_cls_name():
115149
# Test get_cls_name method
116150
entity = ExampleClass.create(REFERENCE_OBJECT_VALID)

tests/py/unit/test_vector.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import numpy as np
12
from mat3ra.code.vector import RoundedVector3D, Vector3D
23

34
VECTOR_FLOAT = [1.234567890, 2.345678901, 3.456789012]
5+
VECTOR_FLOAT_NORM = 4.3561172682906
6+
FLOAT_PRECISION = 1e-8
7+
48
VECTOR_FLOAT_DIFFERENT_WITHIN_TOL = [1.23456789999, 2.345678901, 3.456789012]
59
VECTOR_FLOAT_DIFFERENT_OUTSIDE_TOL = [1.2345699999, 2.345678901, 3.456789012]
610
VECTOR_FLOAT_ROUNDED_4 = [1.2346, 2.3457, 3.4568]
@@ -38,6 +42,28 @@ def test_vector_equality():
3842
assert vector != Vector3D(VECTOR_FLOAT_DIFFERENT_OUTSIDE_TOL)
3943

4044

45+
def test_vector_equality_with_list():
46+
vector = Vector3D(VECTOR_FLOAT)
47+
assert vector == VECTOR_FLOAT
48+
assert vector == VECTOR_FLOAT_DIFFERENT_WITHIN_TOL
49+
assert vector != VECTOR_FLOAT_DIFFERENT_OUTSIDE_TOL
50+
51+
52+
def test_vector_norm():
53+
vector = Vector3D(VECTOR_FLOAT)
54+
# Check if the norm is close to the expected value to avoid architecture-specific issues
55+
np.isclose(vector.norm, VECTOR_FLOAT_NORM, atol=FLOAT_PRECISION, rtol=0)
56+
assert vector.value == VECTOR_FLOAT
57+
assert vector.x == 1.234567890
58+
assert vector.y == 2.345678901
59+
assert vector.z == 3.456789012
60+
61+
62+
#####################################################################
63+
## RoundedVector3D tests
64+
#####################################################################
65+
66+
4167
def test_rounded_vector_init():
4268
vector = RoundedVector3D(VECTOR_FLOAT)
4369
assert vector.model_dump() == VECTOR_FLOAT
@@ -70,3 +96,19 @@ def test_rounded_vector_serialization():
7096
assert vector.x_rounded == VECTOR_FLOAT_ROUNDED_3[0]
7197
assert vector.y_rounded == VECTOR_FLOAT_ROUNDED_3[1]
7298
assert vector.z_rounded == VECTOR_FLOAT_ROUNDED_3[2]
99+
100+
101+
def test_rounded_vector_equality():
102+
class_reference = RoundedVector3D
103+
# Higher precision yields inequality
104+
class_reference.__round_precision__ = 4
105+
vector = class_reference(VECTOR_FLOAT)
106+
assert vector == VECTOR_FLOAT
107+
assert vector == VECTOR_FLOAT_ROUNDED_4
108+
assert vector != VECTOR_FLOAT_ROUNDED_3
109+
# Lower precision yields equality
110+
class_reference.__round_precision__ = 3
111+
vector = class_reference(VECTOR_FLOAT)
112+
assert vector == VECTOR_FLOAT
113+
assert vector == VECTOR_FLOAT_ROUNDED_4
114+
assert vector == VECTOR_FLOAT_ROUNDED_3

0 commit comments

Comments
 (0)