|
15 | 15 |
|
16 | 16 | from collections.abc import Mapping |
17 | 17 | from dataclasses import dataclass |
18 | | -from functools import partial |
19 | 18 | from typing import TYPE_CHECKING, Any |
20 | 19 |
|
21 | 20 | from pysatl_core.distributions import ( |
@@ -73,37 +72,26 @@ def family(self) -> ParametricFamily: |
73 | 72 | def analytical_computations( |
74 | 73 | self, |
75 | 74 | ) -> Mapping[GenericCharacteristicName, AnalyticalComputation[Any, Any]]: |
76 | | - """ |
77 | | - Get analytical computation functions for this distribution. |
| 75 | + """Lazily computed analytical computations for this distribution instance. |
78 | 76 |
|
79 | | - Returns |
80 | | - ------- |
81 | | - Mapping[GenericCharacteristicName, AnalyticalComputation] |
82 | | - Mapping from characteristic names to computation functions. |
| 77 | + Delegates construction to the parent family (precomputed plan) and |
| 78 | + caches the result per-instance. The cache auto-invalidates when either |
| 79 | + the **parametrization object** changes (by identity) or the |
| 80 | + **parametrization name** changes. |
| 81 | +
|
| 82 | + *If you mutate numeric fields of the same parametrization object*, |
| 83 | + the callables see fresh values because they close over that object. |
83 | 84 | """ |
84 | | - analytical_computations = {} |
85 | | - |
86 | | - # First form list of all characteristics, available from current parametrization |
87 | | - for characteristic, forms in self.family.distr_characteristics.items(): |
88 | | - if self.parameters.name in forms: |
89 | | - analytical_computations[characteristic] = AnalyticalComputation( |
90 | | - target=characteristic, |
91 | | - func=partial(forms[self.parameters.name], self.parameters), |
92 | | - ) |
93 | | - # TODO: Second, apply rule set, for, e.g. approximations |
94 | | - |
95 | | - # Finally, fill other chacteristics |
96 | | - base_name = self.family.parametrizations.base_parametrization_name |
97 | | - base_parameters = self.family.parametrizations.get_base_parameters(self.parameters) |
98 | | - for characteristic, forms in self.family.distr_characteristics.items(): |
99 | | - if characteristic in analytical_computations: |
100 | | - continue |
101 | | - if base_name in forms: |
102 | | - analytical_computations[characteristic] = AnalyticalComputation( |
103 | | - target=characteristic, func=partial(forms[base_name], base_parameters) |
104 | | - ) |
105 | | - |
106 | | - return analytical_computations |
| 85 | + key = (id(self.parameters), self.parameters.name) |
| 86 | + cache_key = getattr(self, "_analytical_cache_key", None) |
| 87 | + cache_val = getattr(self, "_analytical_cache_val", None) |
| 88 | + |
| 89 | + if cache_key != key or cache_val is None: |
| 90 | + cache_val = self.family._build_analytical_computations(self.parameters) |
| 91 | + self._analytical_cache_key = key |
| 92 | + self._analytical_cache_val = cache_val |
| 93 | + |
| 94 | + return cache_val |
107 | 95 |
|
108 | 96 | @property |
109 | 97 | def sampling_strategy(self) -> SamplingStrategy: |
|
0 commit comments