Skip to content

Commit 3711af2

Browse files
committed
refactor:refactor parametrizations decorators, add mypy checks
1 parent 6b3a3d3 commit 3711af2

File tree

2 files changed

+122
-69
lines changed

2 files changed

+122
-69
lines changed

src/pysatl_core/families/parametric_family.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,25 @@ def distribution(
154154
distribution_type = self._distr_type(base_parameters)
155155
return ParametricFamilyDistribution(self.name, distribution_type, parameters)
156156

157+
def parametrization(
158+
self, name: str
159+
) -> Callable[[type[Parametrization]], type[Parametrization]]:
160+
"""
161+
Create a class decorator that registers a parametrization in this family.
162+
163+
Parameters
164+
----------
165+
name : str
166+
Name of the parametrization.
167+
168+
Returns
169+
-------
170+
Callable[[type[TParam]], type[TParam]]
171+
Class decorator that registers the parametrization and returns it.
172+
"""
173+
# local import to avoid import cycle at module import time
174+
from pysatl_core.families.parametrizations import parametrization as _param_deco
175+
176+
return _param_deco(family=self, name=name)
177+
157178
__call__ = distribution

src/pysatl_core/families/parametrizations.py

Lines changed: 101 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,22 @@
1414

1515
from abc import ABC, abstractmethod
1616
from collections.abc import Callable
17-
from dataclasses import dataclass
17+
from dataclasses import dataclass, is_dataclass
1818
from functools import wraps
19-
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, runtime_checkable
19+
from inspect import isfunction
20+
from typing import (
21+
TYPE_CHECKING,
22+
Any,
23+
ClassVar,
24+
ParamSpec,
25+
)
2026

2127
from pysatl_core.types import ParametrizationName
2228

2329
if TYPE_CHECKING:
2430
from pysatl_core.families.parametric_family import ParametricFamily
2531

2632

27-
@runtime_checkable
28-
class ParametrizationConstraintProtocol(Protocol):
29-
@property
30-
def _is_constraint(self) -> bool: ...
31-
@property
32-
def _constraint_description(self) -> str: ...
33-
34-
def __call__(self, **kwargs: Any) -> bool: ...
35-
36-
3733
@dataclass(slots=True, frozen=True)
3834
class ParametrizationConstraint:
3935
"""
@@ -64,6 +60,10 @@ class Parametrization(ABC):
6460
Class-level list of constraints that apply to this parametrization.
6561
"""
6662

63+
# These class attributes are set by the @parametrization decorator.
64+
__family__: ClassVar[ParametricFamily]
65+
__param_name__: ClassVar[ParametrizationName]
66+
6767
_constraints: ClassVar[list[ParametrizationConstraint]] = []
6868

6969
@property
@@ -166,7 +166,7 @@ def base(self) -> type[Parametrization]:
166166
Raises
167167
------
168168
ValueError
169-
If no base parametrization has been defined.
169+
If no base parametrization has been defined or registered.
170170
"""
171171
if self.base_parametrization_name is None:
172172
raise ValueError("No base parametrization defined")
@@ -209,12 +209,16 @@ def get_base_parameters(self, parameters: Parametrization) -> Parametrization:
209209
return parameters.transform_to_base_parametrization()
210210

211211

212-
# Decorators for declarative syntax
213-
def constraint(
214-
description: str,
215-
) -> Callable[[Callable[[Any], bool]], ParametrizationConstraintProtocol]:
212+
P = ParamSpec("P")
213+
214+
215+
def constraint(description: str) -> Callable[[Callable[P, bool]], Callable[P, bool]]:
216216
"""
217-
Decorator to mark a method as a parameter constraint.
217+
Decorator to mark an instance method as a parameter constraint.
218+
219+
The decorated function must be a predicate returning ``bool``. At class
220+
decoration time it will be discovered and attached as a
221+
:class:`ParametrizationConstraint`.
218222
219223
Parameters
220224
----------
@@ -223,30 +227,44 @@ def constraint(
223227
224228
Returns
225229
-------
226-
Callable
227-
Decorator function that marks the method as a constraint.
230+
Callable[[Callable[P, bool]], Callable[P, bool]]
231+
A decorator that returns the function wrapper with two marker
232+
attributes set on it.
233+
234+
Notes
235+
-----
236+
The following marker attributes are set on the resulting function object:
237+
238+
* ``__is_constraint`` : ``True``
239+
* ``__constraint_description`` : ``str``
228240
229241
Examples
230242
--------
231-
>>> @constraint("sigma > 0")
232-
>>> def check_sigma_positive(self):
233-
>>> return self.sigma > 0
243+
>>> class MeanStd(Parametrization):
244+
... mean: float
245+
... sigma: float
246+
...
247+
... @constraint("sigma > 0")
248+
... def _c_sigma_positive(self) -> bool:
249+
... return self.sigma > 0
234250
"""
235251

236-
def decorator(func: Callable[[Any], bool]) -> ParametrizationConstraintProtocol:
252+
def decorator(func: Callable[P, bool]) -> Callable[P, bool]:
237253
@wraps(func)
238-
def wrapper(*args, **kwargs): # type: ignore
254+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> bool:
239255
return func(*args, **kwargs)
240256

241-
wrapper._is_constraint = True # type: ignore
242-
wrapper._constraint_description = description # type: ignore
243-
return wrapper # type: ignore
257+
setattr(wrapper, "__is_constraint", True)
258+
setattr(wrapper, "__constraint_description", description)
259+
return wrapper
244260

245261
return decorator
246262

247263

248264
def parametrization(
249-
family: ParametricFamily, name: str
265+
*,
266+
family: ParametricFamily,
267+
name: str,
250268
) -> Callable[[type[Parametrization]], type[Parametrization]]:
251269
"""
252270
Decorator to register a class as a parametrization for a family.
@@ -257,58 +275,72 @@ def parametrization(
257275
The family to register the parametrization with.
258276
name : str
259277
Name of the parametrization.
260-
base : bool, optional
261-
Whether this is the base parametrization, by default False.
262278
263279
Returns
264280
-------
265-
Callable
266-
Decorator function that registers the class as a parametrization.
281+
Callable[[type[Parametrization]], type[Parametrization]]
282+
A class decorator that registers the parametrization and returns the class.
267283
268284
Examples
269285
--------
270-
>>> @parametrization(family=NormalFamily, name='meanvar')
271-
>>> class MeanVarParametrization:
272-
>>> mean: float
273-
>>> var: float
286+
>>> @parametrization(family=normal, name="mean_var")
287+
... class MeanVar(Parametrization):
288+
... mean: float
289+
... var: float
274290
"""
275291

276-
def decorator(cls: type[Parametrization]) -> type[Parametrization]:
277-
# Convert to dataclass if not already
278-
if not hasattr(cls, "__dataclass_fields__"):
279-
cls = dataclass(cls)
280-
281-
# Add name property
282-
def name_property(self): # type: ignore
283-
return name
284-
285-
cls.name = property(name_property) # type: ignore
286-
287-
# Add parameters property
288-
def parameters_property(self): # type: ignore
289-
return {
290-
field.name: getattr(self, field.name)
291-
for field in self.__dataclass_fields__.values()
292-
}
293-
294-
cls.parameters = property(parameters_property) # type: ignore
295-
296-
# Collect constraints
297-
constraints = []
298-
for attr_name in dir(cls):
299-
attr = getattr(cls, attr_name)
300-
if hasattr(attr, "_is_constraint") and attr._is_constraint:
301-
constraints.append(
302-
ParametrizationConstraint(description=attr._constraint_description, check=attr)
292+
def _collect_constraints(cls: type[Parametrization]) -> list[ParametrizationConstraint]:
293+
"""
294+
Collect constraint methods declared on the class.
295+
296+
Parameters
297+
----------
298+
cls : type[Parametrization]
299+
Class being registered as a parametrization.
300+
301+
Returns
302+
-------
303+
list[ParametrizationConstraint]
304+
Collected constraints in declaration order.
305+
306+
Raises
307+
------
308+
TypeError
309+
If a constraint is declared as ``@staticmethod`` or ``@classmethod``.
310+
"""
311+
constraints: list[ParametrizationConstraint] = []
312+
for name, attr in cls.__dict__.items():
313+
if isinstance(attr, staticmethod):
314+
raise TypeError(
315+
f"@constraint '{name}' must be an instance method, not @staticmethod"
316+
)
317+
if isinstance(attr, classmethod):
318+
raise TypeError(
319+
f"@constraint '{name}' must be an instance method, not @classmethod"
303320
)
304-
cls._constraints = constraints
305321

306-
# Add validate method
307-
cls.validate = Parametrization.validate # type: ignore
322+
func = attr if callable(attr) and isfunction(attr) else None
323+
if not func:
324+
continue
325+
if getattr(func, "__is_constraint", False):
326+
desc = getattr(func, "__constraint_description", func.__name__)
327+
constraints.append(ParametrizationConstraint(description=desc, check=func))
328+
return constraints
308329

309-
# Register with family
310-
family.parametrizations.add_parametrization(name, cls)
330+
def decorator(cls: type[Parametrization]) -> type[Parametrization]:
331+
if not is_dataclass(cls):
332+
cls = dataclass(slots=True, frozen=True)(cls)
333+
334+
# Attach metadata expected by tooling; declared in base class for mypy.
335+
cls.__family__ = family
336+
cls.__param_name__ = name
311337

338+
# Discover and store constraints.
339+
constraints = _collect_constraints(cls)
340+
cls._constraints = constraints
341+
342+
# Register in the family's spec.
343+
family.parametrizations.add_parametrization(name, cls)
312344
return cls
313345

314346
return decorator

0 commit comments

Comments
 (0)