|
3 | 3 |
|
4 | 4 | This module provides a framework for building and managing probabilistic models over concepts. |
5 | 5 | """ |
6 | | -import copy |
7 | 6 | import inspect |
8 | 7 |
|
9 | 8 | from torch import nn |
@@ -152,23 +151,16 @@ def _initialize_model(self, input_parametric_cpds: List[ParametricCPD]): |
152 | 151 |
|
153 | 152 | # ---- ParametricCPD modules: fill only self.parametric_cpds (ModuleDict) ---- |
154 | 153 | for parametric_cpd in input_parametric_cpds: |
155 | | - if len(parametric_cpd.concepts) > 1: |
156 | | - # Multi-concept parametric_cpd: split into individual CPDs |
157 | | - for concept in parametric_cpd.concepts: |
158 | | - new_parametric_cpd = ParametricCPD(concepts=[concept], parametrization=copy.deepcopy(parametric_cpd.parametrization)) |
159 | | - # Link the parametric_cpd to its variable |
160 | | - if concept in self.concept_to_variable: |
161 | | - new_parametric_cpd.variable = self.concept_to_variable[concept] |
162 | | - new_parametric_cpd.parents = self.concept_to_variable[concept].parents |
163 | | - self.parametric_cpds[concept] = new_parametric_cpd |
164 | | - else: |
165 | | - # Single concept parametric_cpd |
166 | | - concept = parametric_cpd.concepts[0] |
| 154 | + for concept in parametric_cpd.concepts: |
167 | 155 | # Link the parametric_cpd to its variable |
168 | 156 | if concept in self.concept_to_variable: |
169 | 157 | parametric_cpd.variable = self.concept_to_variable[concept] |
170 | 158 | parametric_cpd.parents = self.concept_to_variable[concept].parents |
171 | | - self.parametric_cpds[concept] = parametric_cpd |
| 159 | + new_parametrization = _reinitialize_with_new_param(parametric_cpd.parametrization, |
| 160 | + 'out_features', |
| 161 | + self.concept_to_variable[concept].size) |
| 162 | + new_parametric_cpd = ParametricCPD(concepts=[concept], parametrization=new_parametrization) |
| 163 | + self.parametric_cpds[concept] = new_parametric_cpd |
172 | 164 |
|
173 | 165 | # ---- Parent resolution (unchanged) ---- |
174 | 166 | for var in self.variables: |
|
0 commit comments