Skip to content

Commit bc851e1

Browse files
committed
fix: update example
1 parent 4e0f373 commit bc851e1

File tree

4 files changed

+64
-71
lines changed

4 files changed

+64
-71
lines changed

examples/example-parametric.ipynb

Lines changed: 58 additions & 23 deletions
Large diffs are not rendered by default.

src/pysatl_core/families/distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def analytical_computations(
8787
cache_val = getattr(self, "_analytical_cache_val", None)
8888

8989
if cache_key != key or cache_val is None:
90-
cache_val = self.family.build_analytical_computations(self.parameters)
90+
cache_val = self.family._build_analytical_computations(self.parameters)
9191
self._analytical_cache_key = key
9292
self._analytical_cache_val = cache_val
9393

src/pysatl_core/families/parametric_family.py

Lines changed: 3 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -231,12 +231,10 @@ def get_parametrization(self, name: ParametrizationName) -> type[Parametrization
231231
"""
232232
return self._parametrizations[name]
233233

234-
def get_base_parameters(self, parameters: Parametrization) -> Parametrization:
234+
def to_base(self, parameters: Parametrization) -> Parametrization:
235235
"""
236236
Convert parameters to the base parametrization.
237237
238-
This method mirrors the former ``ParametrizationSpec.get_base_parameters``.
239-
240238
Parameters
241239
----------
242240
parameters : Parametrization
@@ -280,47 +278,7 @@ def _build_analytical_computations(
280278
params_obj = parameters
281279
else:
282280
if base_params is None:
283-
base_params = self.get_base_parameters(parameters)
284-
params_obj = base_params
285-
286-
func_factory = self.distr_characteristics[characteristic][provider_name]
287-
result[characteristic] = AnalyticalComputation(
288-
target=characteristic,
289-
func=partial(func_factory, params_obj),
290-
)
291-
292-
return result
293-
294-
def build_analytical_computations(
295-
self, parameters: Parametrization
296-
) -> dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]]:
297-
"""
298-
Build analytical computations mapping for the given parameter instance.
299-
300-
This uses a precomputed provider plan so runtime work is reduced to:
301-
- (Optionally) converting parameters to base once,
302-
- binding callables with :func:`functools.partial`.
303-
304-
Parameters
305-
----------
306-
parameters : Parametrization
307-
Parameters in any registered parametrization.
308-
309-
Returns
310-
-------
311-
dict[GenericCharacteristicName, AnalyticalComputation]
312-
Mapping from characteristic name to analytical computation callable.
313-
"""
314-
plan = self._analytical_plan.get(parameters.name, {})
315-
result: dict[GenericCharacteristicName, AnalyticalComputation[Any, Any]] = {}
316-
base_params: Parametrization | None = None
317-
318-
for characteristic, provider_name in plan.items():
319-
if provider_name == parameters.name:
320-
params_obj = parameters
321-
else:
322-
if base_params is None:
323-
base_params = self.get_base_parameters(parameters)
281+
base_params = self.to_base(parameters)
324282
params_obj = base_params
325283

326284
func_factory = self.distr_characteristics[characteristic][provider_name]
@@ -364,7 +322,7 @@ def distribution(
364322
parametrization_class = self._parametrizations[parametrization_name]
365323

366324
parameters = parametrization_class(**parameters_values)
367-
base_parameters = self.get_base_parameters(parameters)
325+
base_parameters = self.to_base(parameters)
368326
parameters.validate()
369327
distribution_type = self._distr_type(base_parameters)
370328
return ParametricFamilyDistribution(self.name, distribution_type, parameters)

tests/unit/families/test_parameters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ def test_get_base_parameters_uses_family_logic(self) -> None:
7777
AltCls = family.parametrizations["alt"]
7878

7979
base_params = BaseCls(value=5.0) # type: ignore[call-arg]
80-
assert family.get_base_parameters(base_params) is base_params
80+
assert family.to_base(base_params) is base_params
8181

8282
alt_params = AltCls(value=3.0) # type: ignore[call-arg]
83-
base_from_alt = family.get_base_parameters(alt_params)
83+
base_from_alt = family.to_base(alt_params)
8484
assert isinstance(base_from_alt, BaseCls)
8585
# Our default factory maps Alt(value=v) → Base(value=v)
8686
assert base_from_alt.value == 3.0 # type: ignore[attr-defined]

0 commit comments

Comments
 (0)