diff --git a/src/pals/kinds/Quadrupole.py b/src/pals/kinds/Quadrupole.py index ac63326..d5da87c 100644 --- a/src/pals/kinds/Quadrupole.py +++ b/src/pals/kinds/Quadrupole.py @@ -1,5 +1,7 @@ from typing import Literal, Optional +from pydantic import model_validator + from .mixin import ThickElement from ..parameters import MagneticMultipoleParameters, ElectricMultipoleParameters @@ -11,5 +13,14 @@ class Quadrupole(ThickElement): kind: Literal["Quadrupole"] = "Quadrupole" # Quadrupole-specific parameters + MagneticMultipoleP: Optional[MagneticMultipoleParameters] = None ElectricMultipoleP: Optional[ElectricMultipoleParameters] = None - MagneticMultipoleP: MagneticMultipoleParameters + + @model_validator(mode="after") + def validate_at_least_one_multipole(self) -> "Quadrupole": + """Ensure at least one multipole parameter is specified.""" + if self.MagneticMultipoleP is None and self.ElectricMultipoleP is None: + raise ValueError( + "At least one of 'MagneticMultipoleP' or 'ElectricMultipoleP' must be specified" + ) + return self diff --git a/src/pals/parameters/ElectricMultipoleParameters.py b/src/pals/parameters/ElectricMultipoleParameters.py index 52fe74f..af16624 100644 --- a/src/pals/parameters/ElectricMultipoleParameters.py +++ b/src/pals/parameters/ElectricMultipoleParameters.py @@ -1,11 +1,63 @@ -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator +from typing import Any + +# Valid parameter prefixes, their expected format and description +_PARAMETER_PREFIXES = { + "tilt": ("tiltN", "Tilt"), + "En": ("EnN", "Normal component"), + "Es": ("EsN", "Skew component"), +} + + +def _validate_order( + key_num: str, parameter_name: str, prefix: str, expected_format: str +) -> None: + """Validate that the order number is a non-negative integer without leading zeros.""" + error_msg = ( + f"Invalid {parameter_name}: '{prefix}{key_num}'. " + f"Parameter must be of the form '{expected_format}', where 'N' is a non-negative integer without leading zeros." + ) + if not key_num.isdigit() or (key_num.startswith("0") and key_num != "0"): + raise ValueError(error_msg) class ElectricMultipoleParameters(BaseModel): - """Electric multipole parameters""" + """Electric multipole parameters + + Valid parameter formats: + - tiltN: Tilt of Nth order multipole + - EnN: Normal component of Nth order multipole + - EsN: Skew component of Nth order multipole + - *NL: Length-integrated versions of components (e.g., En3L, EsNL) + + Where N is a positive integer without leading zeros (except "0" itself). + """ - # Allow arbitrary fields (TODO: remove this) model_config = ConfigDict(extra="allow") - # TODO: add ElectricMultipoleParameters in a follow-up RP - # https://pals-project.readthedocs.io/en/latest/element-parameters.html#electricmultipolep-electric-multipole-parameters + @model_validator(mode="before") + @classmethod + def validate(cls, values: dict[str, Any]) -> dict[str, Any]: + """Validate all parameter names match the expected multipole format.""" + for key in values: + # Check if key ends with 'L' for length-integrated values + is_length_integrated = key.endswith("L") + base_key = key[:-1] if is_length_integrated else key + + # No length-integrated values allowed for tilt parameter + if is_length_integrated and base_key.startswith("tilt"): + raise ValueError(f"Invalid electric multipole parameter: '{key}'. ") + + # Find matching prefix + for prefix, (expected_format, description) in _PARAMETER_PREFIXES.items(): + if base_key.startswith(prefix): + key_num = base_key[len(prefix) :] + _validate_order(key_num, description, prefix, expected_format) + break + else: + raise ValueError( + f"Invalid electric multipole parameter: '{key}'. " + f"Parameters must be of the form 'tiltN', 'EnN', or 'EsN' " + f"(with optional 'L' suffix for length-integrated), where 'N' is a non-negative integer." + ) + return values diff --git a/tests/test_elements.py b/tests/test_elements.py index f8ebb93..b2ee125 100644 --- a/tests/test_elements.py +++ b/tests/test_elements.py @@ -55,6 +55,7 @@ def test_Quadrupole(): # Create one drift element with custom name and length element_name = "quadrupole_element" element_length = 1.0 + # Magnetic multipole parameters element_magnetic_multipole_Bn1 = 1.1 element_magnetic_multipole_Bn2 = 1.2 element_magnetic_multipole_Bs1 = 2.1 @@ -69,10 +70,26 @@ def test_Quadrupole(): Bs2=element_magnetic_multipole_Bs2, tilt2=element_magnetic_multipole_tilt2, ) + # Electric multipole parameters + element_electric_multipole_En1 = 1.1 + element_electric_multipole_En2 = 1.2 + element_electric_multipole_Es1 = 2.1 + element_electric_multipole_Es2 = 2.2 + element_electric_multipole_tilt1 = 3.1 + element_electric_multipole_tilt2 = 3.2 + element_electric_multipole = pals.ElectricMultipoleParameters( + En1=element_electric_multipole_En1, + Es1=element_electric_multipole_Es1, + tilt1=element_electric_multipole_tilt1, + En2=element_electric_multipole_En2, + Es2=element_electric_multipole_Es2, + tilt2=element_electric_multipole_tilt2, + ) element = pals.Quadrupole( name=element_name, length=element_length, MagneticMultipoleP=element_magnetic_multipole, + ElectricMultipoleP=element_electric_multipole, ) assert element.name == element_name assert element.length == element_length @@ -82,6 +99,12 @@ def test_Quadrupole(): assert element.MagneticMultipoleP.Bn2 == element_magnetic_multipole_Bn2 assert element.MagneticMultipoleP.Bs2 == element_magnetic_multipole_Bs2 assert element.MagneticMultipoleP.tilt2 == element_magnetic_multipole_tilt2 + assert element.ElectricMultipoleP.En1 == element_electric_multipole_En1 + assert element.ElectricMultipoleP.Es1 == element_electric_multipole_Es1 + assert element.ElectricMultipoleP.tilt1 == element_electric_multipole_tilt1 + assert element.ElectricMultipoleP.En2 == element_electric_multipole_En2 + assert element.ElectricMultipoleP.Es2 == element_electric_multipole_Es2 + assert element.ElectricMultipoleP.tilt2 == element_electric_multipole_tilt2 # Serialize the BeamLine object to YAML yaml_data = yaml.dump(element.model_dump(), default_flow_style=False) print(f"\n{yaml_data}") @@ -117,12 +140,14 @@ def test_Sextupole(): name="sext1", length=0.5, MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn2=1.0), + ElectricMultipoleP=pals.ElectricMultipoleParameters(En2=1.0), ApertureP=pals.ApertureParameters(x_limits=[-0.1, 0.1]), ) assert element.name == "sext1" assert element.length == 0.5 assert element.kind == "Sextupole" assert element.MagneticMultipoleP.Bn2 == 1.0 + assert element.ElectricMultipoleP.En2 == 1.0 assert element.ApertureP.x_limits == [-0.1, 0.1] @@ -131,12 +156,14 @@ def test_Octupole(): element = pals.Octupole( name="oct1", length=0.3, + MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn3=0.5), ElectricMultipoleP=pals.ElectricMultipoleParameters(En3=0.5), MetaP=pals.MetaParameters(alias="octupole_test"), ) assert element.name == "oct1" assert element.length == 0.3 assert element.kind == "Octupole" + assert element.MagneticMultipoleP.Bn3 == 0.5 assert element.ElectricMultipoleP.En3 == 0.5 assert element.MetaP.alias == "octupole_test" @@ -147,12 +174,16 @@ def test_Multipole(): name="mult1", length=0.4, MagneticMultipoleP=pals.MagneticMultipoleParameters(Bn1=2.0, Bn2=1.5), + ElectricMultipoleP=pals.ElectricMultipoleParameters(En1=2.0, En2=1.5), BodyShiftP=pals.BodyShiftParameters(x_offset=0.01), ) assert element.name == "mult1" assert element.length == 0.4 assert element.kind == "Multipole" assert element.MagneticMultipoleP.Bn1 == 2.0 + assert element.MagneticMultipoleP.Bn2 == 1.5 + assert element.ElectricMultipoleP.En1 == 2.0 + assert element.ElectricMultipoleP.En2 == 1.5 assert element.BodyShiftP.x_offset == 0.01 diff --git a/tests/test_parameters.py b/tests/test_parameters.py index ba904e9..09416bc 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -6,6 +6,7 @@ BeamBeamParameters, BendParameters, BodyShiftParameters, + ElectricMultipoleParameters, FloorShiftParameters, ForkParameters, MagneticMultipoleParameters, @@ -38,9 +39,27 @@ def test_ParameterClasses(): meta = MetaParameters(alias="test", description="test element") assert meta.alias == "test" - # Test ElectricMultipoleParameters (TODO) - # emp = ElectricMultipoleParameters(En1=1.0, Es1=0.5) - # assert emp.En1 == 1.0 + # Test ElectricMultipoleParameters + emp = ElectricMultipoleParameters(tilt1=1.2, En1=1.0, Es1=0.5) + assert emp.tilt1 == 1.2 + assert emp.En1 == 1.0 + assert emp.Es1 == 0.5 + + emp2 = ElectricMultipoleParameters(En1L=1.0, Es1L=0.5) + assert emp2.En1L == 1.0 + assert emp2.Es1L == 0.5 + + # catch typos + with pytest.raises(ValidationError): + _ = ElectricMultipoleParameters(Em1=1.0, Es1=0.5) + with pytest.raises(ValidationError): + _ = ElectricMultipoleParameters(En1=1.0, Ev1=0.5) + with pytest.raises(ValidationError): + _ = ElectricMultipoleParameters(En01=1.0, Es01=0.5) + with pytest.raises(ValidationError): + _ = ElectricMultipoleParameters(En1v=1.0, Es1l=0.5) + with pytest.raises(ValidationError): + _ = ElectricMultipoleParameters(tilt1L=1.2) # Test MagneticMultipoleParameters mmp = MagneticMultipoleParameters(tilt1=1.2, Bn1=1.0, Bs1=0.5)