Skip to content

Commit 2ddf2a7

Browse files
ax3lEZoni
andauthored
Refactor MagneticMultipoleParameters (#44)
In preparation of #38, refactor `MagneticMultipoleParameters`. - remove code duplication and hard-coded string accesses - more detailed error messages - add support for `Kn` and `Ks` and length-integrated `...L` parameters --------- Co-authored-by: Edoardo Zoni <ezoni@lbl.gov>
1 parent 834f322 commit 2ddf2a7

File tree

2 files changed

+69
-51
lines changed

2 files changed

+69
-51
lines changed
Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,67 @@
11
from pydantic import BaseModel, ConfigDict, model_validator
2-
from typing import Any, Dict
2+
from typing import Any
3+
4+
# Valid parameter prefixes, their expected format and description
5+
_PARAMETER_PREFIXES = {
6+
"tilt": ("tiltN", "Tilt"),
7+
"Bn": ("BnN", "Normal component"),
8+
"Bs": ("BsN", "Skew component"),
9+
"Kn": ("KnN", "Normalized normal component"),
10+
"Ks": ("KsN", "Normalized skew component"),
11+
}
12+
13+
14+
def _validate_order(
15+
key_num: str, parameter_name: str, prefix: str, expected_format: str
16+
) -> None:
17+
"""Validate that the order number is a non-negative integer without leading zeros."""
18+
error_msg = (
19+
f"Invalid {parameter_name}: '{prefix}{key_num}'. "
20+
f"Parameter must be of the form '{expected_format}', where 'N' is a non-negative integer without leading zeros."
21+
)
22+
if not key_num.isdigit() or (key_num.startswith("0") and key_num != "0"):
23+
raise ValueError(error_msg)
324

425

526
class MagneticMultipoleParameters(BaseModel):
6-
"""Magnetic multipole parameters"""
27+
"""Magnetic multipole parameters
728
8-
# Allow arbitrary fields
9-
model_config = ConfigDict(extra="allow")
29+
Valid parameter formats:
30+
- tiltN: Tilt of Nth order multipole
31+
- BnN: Normal component of Nth order multipole
32+
- BsN: Skew component of Nth order multipole
33+
- KnN: Normalized normal component of Nth order multipole
34+
- KsN: Normalized skew component of Nth order multipole
35+
- *NL: Length-integrated versions of components (e.g., Bn3L, KsNL)
36+
37+
Where N is a positive integer without leading zeros (except "0" itself).
38+
"""
1039

11-
# Custom validation of magnetic multipole order
12-
def _validate_order(key_num, msg):
13-
if key_num.isdigit():
14-
if key_num.startswith("0") and key_num != "0":
15-
raise ValueError(msg)
16-
else:
17-
raise ValueError(msg)
40+
model_config = ConfigDict(extra="allow")
1841

19-
# Custom validation to be applied before standard validation
2042
@model_validator(mode="before")
21-
def validate(cls, values: Dict[str, Any]) -> Dict[str, Any]:
22-
# loop over all attributes
43+
@classmethod
44+
def validate(cls, values: dict[str, Any]) -> dict[str, Any]:
45+
"""Validate all parameter names match the expected multipole format."""
2346
for key in values:
24-
# validate tilt parameters 'tiltN'
25-
if key.startswith("tilt"):
26-
key_num = key[4:]
27-
msg = " ".join(
28-
[
29-
f"Invalid tilt parameter: '{key}'.",
30-
"Tilt parameter must be of the form 'tiltN', where 'N' is an integer.",
31-
]
32-
)
33-
cls._validate_order(key_num, msg)
34-
# validate normal component parameters 'BnN'
35-
elif key.startswith("Bn"):
36-
key_num = key[2:]
37-
msg = " ".join(
38-
[
39-
f"Invalid normal component parameter: '{key}'.",
40-
"Normal component parameter must be of the form 'BnN', where 'N' is an integer.",
41-
]
42-
)
43-
cls._validate_order(key_num, msg)
44-
# validate skew component parameters 'BsN'
45-
elif key.startswith("Bs"):
46-
key_num = key[2:]
47-
msg = " ".join(
48-
[
49-
f"Invalid skew component parameter: '{key}'.",
50-
"Skew component parameter must be of the form 'BsN', where 'N' is an integer.",
51-
]
52-
)
53-
cls._validate_order(key_num, msg)
47+
# Check if key ends with 'L' for length-integrated values
48+
is_length_integrated = key.endswith("L")
49+
base_key = key[:-1] if is_length_integrated else key
50+
51+
# No length-integrated values allowed for tilt parameter
52+
if is_length_integrated and base_key.startswith("tilt"):
53+
raise ValueError(f"Invalid magnetic multipole parameter: '{key}'. ")
54+
55+
# Find matching prefix
56+
for prefix, (expected_format, description) in _PARAMETER_PREFIXES.items():
57+
if base_key.startswith(prefix):
58+
key_num = base_key[len(prefix) :]
59+
_validate_order(key_num, description, prefix, expected_format)
60+
break
5461
else:
55-
msg = " ".join(
56-
[
57-
f"Invalid magnetic multipole parameter: '{key}'.",
58-
"Magnetic multipole parameters must be of the form 'tiltN', 'BnN', or 'BsN', where 'N' is an integer.",
59-
]
62+
raise ValueError(
63+
f"Invalid magnetic multipole parameter: '{key}'. "
64+
f"Parameters must be of the form 'tiltN', 'BnN', 'BsN', 'KnN', or 'KsN' "
65+
f"(with optional 'L' suffix for length-integrated), where 'N' is a non-negative integer."
6066
)
61-
raise ValueError(msg)
6267
return values

tests/test_parameters.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,30 @@ def test_ParameterClasses():
4343
# assert emp.En1 == 1.0
4444

4545
# Test MagneticMultipoleParameters
46-
mmp = MagneticMultipoleParameters(Bn1=1.0, Bs1=0.5)
46+
mmp = MagneticMultipoleParameters(tilt1=1.2, Bn1=1.0, Bs1=0.5)
47+
assert mmp.tilt1 == 1.2
4748
assert mmp.Bn1 == 1.0
4849
assert mmp.Bs1 == 0.5
4950

51+
mmp2 = MagneticMultipoleParameters(Kn0=1.0, Ks1=0.5)
52+
assert mmp2.Kn0 == 1.0
53+
assert mmp2.Ks1 == 0.5
54+
55+
mmp3 = MagneticMultipoleParameters(Bn1L=1.0, Bs1L=0.5)
56+
assert mmp3.Bn1L == 1.0
57+
assert mmp3.Bs1L == 0.5
58+
5059
# catch typos
5160
with pytest.raises(ValidationError):
5261
_ = MagneticMultipoleParameters(Bm1=1.0, Bs1=0.5)
5362
with pytest.raises(ValidationError):
5463
_ = MagneticMultipoleParameters(Bn1=1.0, Bv1=0.5)
5564
with pytest.raises(ValidationError):
5665
_ = MagneticMultipoleParameters(Bn01=1.0, Bs01=0.5)
66+
with pytest.raises(ValidationError):
67+
_ = MagneticMultipoleParameters(Bn1v=1.0, Bs1l=0.5)
68+
with pytest.raises(ValidationError):
69+
_ = MagneticMultipoleParameters(tilt1L=1.2)
5770

5871
# Test SolenoidParameters
5972
sol = SolenoidParameters(Ksol=0.1, Bsol=0.2)

0 commit comments

Comments
 (0)