1414
1515from abc import ABC , abstractmethod
1616from collections .abc import Callable
17- from dataclasses import dataclass
17+ from dataclasses import dataclass , is_dataclass
1818from functools import wraps
19- from typing import TYPE_CHECKING , Any , ClassVar , Protocol , runtime_checkable
19+ from inspect import isfunction
20+ from typing import (
21+ TYPE_CHECKING ,
22+ Any ,
23+ ClassVar ,
24+ ParamSpec ,
25+ )
2026
2127from pysatl_core .types import ParametrizationName
2228
2329if TYPE_CHECKING :
2430 from pysatl_core .families .parametric_family import ParametricFamily
2531
2632
27- @runtime_checkable
28- class ParametrizationConstraintProtocol (Protocol ):
29- @property
30- def _is_constraint (self ) -> bool : ...
31- @property
32- def _constraint_description (self ) -> str : ...
33-
34- def __call__ (self , ** kwargs : Any ) -> bool : ...
35-
36-
3733@dataclass (slots = True , frozen = True )
3834class ParametrizationConstraint :
3935 """
@@ -64,6 +60,10 @@ class Parametrization(ABC):
6460 Class-level list of constraints that apply to this parametrization.
6561 """
6662
63+ # These class attributes are set by the @parametrization decorator.
64+ __family__ : ClassVar [ParametricFamily ]
65+ __param_name__ : ClassVar [ParametrizationName ]
66+
6767 _constraints : ClassVar [list [ParametrizationConstraint ]] = []
6868
6969 @property
@@ -166,7 +166,7 @@ def base(self) -> type[Parametrization]:
166166 Raises
167167 ------
168168 ValueError
169- If no base parametrization has been defined.
169+ If no base parametrization has been defined or registered .
170170 """
171171 if self .base_parametrization_name is None :
172172 raise ValueError ("No base parametrization defined" )
@@ -209,12 +209,16 @@ def get_base_parameters(self, parameters: Parametrization) -> Parametrization:
209209 return parameters .transform_to_base_parametrization ()
210210
211211
212- # Decorators for declarative syntax
213- def constraint (
214- description : str ,
215- ) -> Callable [[Callable [[ Any ] , bool ]], ParametrizationConstraintProtocol ]:
212+ P = ParamSpec ( "P" )
213+
214+
215+ def constraint ( description : str ) -> Callable [[Callable [P , bool ]], Callable [ P , bool ] ]:
216216 """
217- Decorator to mark a method as a parameter constraint.
217+ Decorator to mark an instance method as a parameter constraint.
218+
219+ The decorated function must be a predicate returning ``bool``. At class
220+ decoration time it will be discovered and attached as a
221+ :class:`ParametrizationConstraint`.
218222
219223 Parameters
220224 ----------
@@ -223,30 +227,44 @@ def constraint(
223227
224228 Returns
225229 -------
226- Callable
227- Decorator function that marks the method as a constraint.
230+ Callable[[Callable[P, bool]], Callable[P, bool]]
231+ A decorator that returns the function wrapper with two marker
232+ attributes set on it.
233+
234+ Notes
235+ -----
236+ The following marker attributes are set on the resulting function object:
237+
238+ * ``__is_constraint`` : ``True``
239+ * ``__constraint_description`` : ``str``
228240
229241 Examples
230242 --------
231- >>> @constraint("sigma > 0")
232- >>> def check_sigma_positive(self):
233- >>> return self.sigma > 0
243+ >>> class MeanStd(Parametrization):
244+ ... mean: float
245+ ... sigma: float
246+ ...
247+ ... @constraint("sigma > 0")
248+ ... def _c_sigma_positive(self) -> bool:
249+ ... return self.sigma > 0
234250 """
235251
236- def decorator (func : Callable [[ Any ] , bool ]) -> ParametrizationConstraintProtocol :
252+ def decorator (func : Callable [P , bool ]) -> Callable [ P , bool ] :
237253 @wraps (func )
238- def wrapper (* args , ** kwargs ): # type: ignore
254+ def wrapper (* args : P . args , ** kwargs : P . kwargs ) -> bool :
239255 return func (* args , ** kwargs )
240256
241- wrapper . _is_constraint = True # type: ignore
242- wrapper . _constraint_description = description # type: ignore
243- return wrapper # type: ignore
257+ setattr ( wrapper , "__is_constraint" , True )
258+ setattr ( wrapper , "__constraint_description" , description )
259+ return wrapper
244260
245261 return decorator
246262
247263
248264def parametrization (
249- family : ParametricFamily , name : str
265+ * ,
266+ family : ParametricFamily ,
267+ name : str ,
250268) -> Callable [[type [Parametrization ]], type [Parametrization ]]:
251269 """
252270 Decorator to register a class as a parametrization for a family.
@@ -257,58 +275,72 @@ def parametrization(
257275 The family to register the parametrization with.
258276 name : str
259277 Name of the parametrization.
260- base : bool, optional
261- Whether this is the base parametrization, by default False.
262278
263279 Returns
264280 -------
265- Callable
266- Decorator function that registers the class as a parametrization .
281+ Callable[[type[Parametrization]], type[Parametrization]]
282+ A class decorator that registers the parametrization and returns the class .
267283
268284 Examples
269285 --------
270- >>> @parametrization(family=NormalFamily , name='meanvar' )
271- >>> class MeanVarParametrization :
272- >>> mean: float
273- >>> var: float
286+ >>> @parametrization(family=normal , name="mean_var" )
287+ ... class MeanVar(Parametrization) :
288+ ... mean: float
289+ ... var: float
274290 """
275291
276- def decorator (cls : type [Parametrization ]) -> type [Parametrization ]:
277- # Convert to dataclass if not already
278- if not hasattr (cls , "__dataclass_fields__" ):
279- cls = dataclass (cls )
280-
281- # Add name property
282- def name_property (self ): # type: ignore
283- return name
284-
285- cls .name = property (name_property ) # type: ignore
286-
287- # Add parameters property
288- def parameters_property (self ): # type: ignore
289- return {
290- field .name : getattr (self , field .name )
291- for field in self .__dataclass_fields__ .values ()
292- }
293-
294- cls .parameters = property (parameters_property ) # type: ignore
295-
296- # Collect constraints
297- constraints = []
298- for attr_name in dir (cls ):
299- attr = getattr (cls , attr_name )
300- if hasattr (attr , "_is_constraint" ) and attr ._is_constraint :
301- constraints .append (
302- ParametrizationConstraint (description = attr ._constraint_description , check = attr )
292+ def _collect_constraints (cls : type [Parametrization ]) -> list [ParametrizationConstraint ]:
293+ """
294+ Collect constraint methods declared on the class.
295+
296+ Parameters
297+ ----------
298+ cls : type[Parametrization]
299+ Class being registered as a parametrization.
300+
301+ Returns
302+ -------
303+ list[ParametrizationConstraint]
304+ Collected constraints in declaration order.
305+
306+ Raises
307+ ------
308+ TypeError
309+ If a constraint is declared as ``@staticmethod`` or ``@classmethod``.
310+ """
311+ constraints : list [ParametrizationConstraint ] = []
312+ for name , attr in cls .__dict__ .items ():
313+ if isinstance (attr , staticmethod ):
314+ raise TypeError (
315+ f"@constraint '{ name } ' must be an instance method, not @staticmethod"
316+ )
317+ if isinstance (attr , classmethod ):
318+ raise TypeError (
319+ f"@constraint '{ name } ' must be an instance method, not @classmethod"
303320 )
304- cls ._constraints = constraints
305321
306- # Add validate method
307- cls .validate = Parametrization .validate # type: ignore
322+ func = attr if callable (attr ) and isfunction (attr ) else None
323+ if not func :
324+ continue
325+ if getattr (func , "__is_constraint" , False ):
326+ desc = getattr (func , "__constraint_description" , func .__name__ )
327+ constraints .append (ParametrizationConstraint (description = desc , check = func ))
328+ return constraints
308329
309- # Register with family
310- family .parametrizations .add_parametrization (name , cls )
330+ def decorator (cls : type [Parametrization ]) -> type [Parametrization ]:
331+ if not is_dataclass (cls ):
332+ cls = dataclass (slots = True , frozen = True )(cls )
333+
334+ # Attach metadata expected by tooling; declared in base class for mypy.
335+ cls .__family__ = family
336+ cls .__param_name__ = name
311337
338+ # Discover and store constraints.
339+ constraints = _collect_constraints (cls )
340+ cls ._constraints = constraints
341+
342+ # Register in the family's spec.
343+ family .parametrizations .add_parametrization (name , cls )
312344 return cls
313345
314346 return decorator
0 commit comments