Skip to content

Commit bcc4214

Browse files
Fix parametric cpd re-initialization in probabilistic model
1 parent aa62476 commit bcc4214

File tree

1 file changed

+6
-14
lines changed

1 file changed

+6
-14
lines changed

torch_concepts/nn/modules/mid/models/probabilistic_model.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
44
This module provides a framework for building and managing probabilistic models over concepts.
55
"""
6-
import copy
76
import inspect
87

98
from torch import nn
@@ -152,23 +151,16 @@ def _initialize_model(self, input_parametric_cpds: List[ParametricCPD]):
152151

153152
# ---- ParametricCPD modules: fill only self.parametric_cpds (ModuleDict) ----
154153
for parametric_cpd in input_parametric_cpds:
155-
if len(parametric_cpd.concepts) > 1:
156-
# Multi-concept parametric_cpd: split into individual CPDs
157-
for concept in parametric_cpd.concepts:
158-
new_parametric_cpd = ParametricCPD(concepts=[concept], parametrization=copy.deepcopy(parametric_cpd.parametrization))
159-
# Link the parametric_cpd to its variable
160-
if concept in self.concept_to_variable:
161-
new_parametric_cpd.variable = self.concept_to_variable[concept]
162-
new_parametric_cpd.parents = self.concept_to_variable[concept].parents
163-
self.parametric_cpds[concept] = new_parametric_cpd
164-
else:
165-
# Single concept parametric_cpd
166-
concept = parametric_cpd.concepts[0]
154+
for concept in parametric_cpd.concepts:
167155
# Link the parametric_cpd to its variable
168156
if concept in self.concept_to_variable:
169157
parametric_cpd.variable = self.concept_to_variable[concept]
170158
parametric_cpd.parents = self.concept_to_variable[concept].parents
171-
self.parametric_cpds[concept] = parametric_cpd
159+
new_parametrization = _reinitialize_with_new_param(parametric_cpd.parametrization,
160+
'out_features',
161+
self.concept_to_variable[concept].size)
162+
new_parametric_cpd = ParametricCPD(concepts=[concept], parametrization=new_parametrization)
163+
self.parametric_cpds[concept] = new_parametric_cpd
172164

173165
# ---- Parent resolution (unchanged) ----
174166
for var in self.variables:

0 commit comments

Comments
 (0)