
Commit ca4df05

Allow instantiating ParametricCPD with lazy constructors; otherwise, assume the layer is already fully and correctly instantiated.
Parent: c173033

5 files changed: +41 −39 lines

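For orientation, here is a minimal sketch of the two parametrization styles ParametricCPD accepts after this commit, based on the example diffs below (the class names come from the examples; the concrete feature sizes in the eager variant are placeholders chosen for illustration):

from torch_concepts.nn import LinearZC, ParametricCPD, LazyConstructor

# Eager style: the layer is already fully instantiated, so the model uses it as-is.
eager_cpd = ParametricCPD(
    ["c1", "c2"],
    parametrization=LinearZC(in_features=16, out_features=1),  # sizes picked by hand
)

# Lazy style: pass a LazyConstructor; ProbabilisticModel calls .build(...) later and
# fills in the input/output sizes from each variable's parents and its own size.
lazy_cpd = ParametricCPD(["c1", "c2"], parametrization=LazyConstructor(LinearZC))

In the updated examples, only the backbone keeps an eagerly built parametrization (a plain torch.nn module), while the concept encoder and task predictor switch to the lazy form.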

examples/utilization/1_pgm/0_concept_bottleneck_model.py

Lines changed: 4 additions & 4 deletions

@@ -5,7 +5,7 @@
 from torch_concepts import Annotations, AxisAnnotation, Variable, InputVariable, EndogenousVariable
 from torch_concepts.data.datasets import ToyDataset
 from torch_concepts.nn import LinearZC, LinearCC, ParametricCPD, ProbabilisticModel, \
-    RandomPolicy, DoIntervention, intervention, DeterministicInference
+    RandomPolicy, DoIntervention, intervention, DeterministicInference, LazyConstructor


 def main():
@@ -30,9 +30,9 @@ def main():
     tasks = EndogenousVariable("xor", parents=concept_names, distribution=RelaxedOneHotCategorical, size=2)

     # ParametricCPD setup
-    backbone = ParametricCPD("input", parametrization=torch.nn.Identity())
-    c_encoder = ParametricCPD(["c1", "c2"], parametrization=LinearZC(in_features=x_train.shape[1], out_features=concepts[0].size))
-    y_predictor = ParametricCPD("xor", parametrization=LinearCC(in_features_endogenous=sum(c.size for c in concepts), out_features=tasks.size))
+    backbone = ParametricCPD("input", parametrization=torch.nn.Sequential(torch.nn.Linear(x_train.shape[1], latent_dims), torch.nn.LeakyReLU()))
+    c_encoder = ParametricCPD(["c1", "c2"], parametrization=LazyConstructor(LinearZC))
+    y_predictor = ParametricCPD("xor", parametrization=LazyConstructor(LinearCC))

     # ProbabilisticModel Initialization
     concept_model = ProbabilisticModel(variables=[input_var, *concepts, tasks], parametric_cpds=[backbone, *c_encoder, y_predictor])

examples/utilization/1_pgm/1_concept_bottleneck_model_ancestral_sampling.py

Lines changed: 4 additions & 4 deletions

@@ -5,7 +5,7 @@
 from torch_concepts import Annotations, AxisAnnotation, Variable, InputVariable, EndogenousVariable
 from torch_concepts.data.datasets import ToyDataset
 from torch_concepts.nn import LinearZC, LinearCC, ParametricCPD, ProbabilisticModel, \
-    RandomPolicy, DoIntervention, intervention, AncestralSamplingInference
+    RandomPolicy, DoIntervention, intervention, AncestralSamplingInference, LazyConstructor


 def main():
@@ -24,14 +24,14 @@ def main():
     y_train = torch.cat([y_train, 1-y_train], dim=1)

     # Variable setup
-    input_var = InputVariable("input", parents=[], size=latent_dims)
+    input_var = InputVariable("input", parents=[], size=x_train.shape[1])
     concepts = EndogenousVariable(concept_names, parents=["input"], distribution=RelaxedBernoulli)
     tasks = EndogenousVariable("xor", parents=concept_names, distribution=RelaxedOneHotCategorical, size=2)

     # ParametricCPD setup
     backbone = ParametricCPD("input", parametrization=torch.nn.Identity())
-    c_encoder = ParametricCPD(["c1", "c2"], parametrization=LinearZC(in_features=x_train.shape[1], out_features=concepts[0].size))
-    y_predictor = ParametricCPD("xor", parametrization=LinearCC(in_features_endogenous=sum(c.size for c in concepts), out_features=tasks.size))
+    c_encoder = ParametricCPD(["c1", "c2"], parametrization=LazyConstructor(LinearZC))
+    y_predictor = ParametricCPD("xor", parametrization=LazyConstructor(LinearCC))

     # ProbabilisticModel Initialization
     concept_model = ProbabilisticModel(variables=[input_var, *concepts, tasks], parametric_cpds=[backbone, *c_encoder, y_predictor])

examples/utilization/2_model/4_concept_graph_model_learned.py

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ def main():
         source_exogenous=LazyConstructor(LinearZU, exogenous_size=11),
         internal_exogenous=LazyConstructor(LinearZU, exogenous_size=7),
         encoder=LazyConstructor(LinearUC),
-        predictor=LazyConstructor(HyperLinearCUC, embedding_size=20),)
+        predictor=LazyConstructor(HyperLinearCUC, embedding_size=20))

     # graph learning init
     graph_learner = WANDAGraphLearner(concept_names, task_names)

torch_concepts/nn/modules/mid/constructors/graph.py

Lines changed: 3 additions & 22 deletions

@@ -188,14 +188,7 @@ def _init_exog(self, layer: LazyConstructor, label_names, parent_var, cardinalit
                 distribution=Delta,
                 size=layer._module_kwargs['exogenous_size'])

-        lazy_constructor = layer.build(
-            in_features=parent_var.size,
-            in_features_endogenous=None,
-            in_features_exogenous=None,
-            out_features=1,
-        )
-
-        exog_cpds = ParametricCPD(exog_names, parametrization=lazy_constructor)
+        exog_cpds = ParametricCPD(exog_names, parametrization=layer)
         return exog_vars, exog_cpds

     def _init_encoder(self, layer: LazyConstructor, label_names, parent_vars, cardinalities=None) -> Tuple[Variable, ParametricCPD]:
@@ -220,13 +213,7 @@ def _init_encoder(self, layer: LazyConstructor, label_names, parent_vars, cardin
         if not isinstance(encoder_vars, list):
             encoder_vars = [encoder_vars]

-        lazy_constructor = layer.build(
-            in_features=parent_vars[0].size,
-            in_features_endogenous=None,
-            in_features_exogenous=None,
-            out_features=encoder_vars[0].size,
-        )
-        encoder_cpds = ParametricCPD(label_names, parametrization=lazy_constructor)
+        encoder_cpds = ParametricCPD(label_names, parametrization=layer)
         # Ensure encoder_cpds is always a list
         if not isinstance(encoder_cpds, list):
             encoder_cpds = [encoder_cpds]
@@ -241,13 +228,7 @@ def _init_encoder(self, layer: LazyConstructor, label_names, parent_vars, cardin
                 parents=exog_vars_names,
                 distribution=self.annotations[1].metadata[label_name]['distribution'],
                 size=self.annotations[1].cardinalities[self.annotations[1].get_index(label_name)])
-            lazy_constructor = layer.build(
-                in_features=None,
-                in_features_endogenous=None,
-                in_features_exogenous=exog_vars[0].size,
-                out_features=encoder_var.size,
-            )
-            encoder_cpd = ParametricCPD(label_name, parametrization=lazy_constructor)
+            encoder_cpd = ParametricCPD(label_name, parametrization=layer)
             encoder_vars.append(encoder_var)
             encoder_cpds.append(encoder_cpd)
         return encoder_vars, encoder_cpds
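With these deletions, the graph constructors no longer call layer.build(...) themselves; they pass the LazyConstructor straight to ParametricCPD and leave size resolution to ProbabilisticModel (see the next file). As a rough sketch of the interface this relies on, reconstructed from the calls removed above rather than from the library's documented API:

# Assumed shape of LazyConstructor, inferred from the removed layer.build(...) calls.
class LazyConstructorSketch:
    def __init__(self, module_cls, **module_kwargs):
        self.module_cls = module_cls
        self._module_kwargs = module_kwargs  # e.g. exogenous_size=11, read by _init_exog above

    def build(self, in_features=None, in_features_endogenous=None,
              in_features_exogenous=None, out_features=None):
        # The real implementation presumably forwards only the arguments the target
        # layer accepts; this sketch simply passes everything through.
        return self.module_cls(
            in_features=in_features,
            in_features_endogenous=in_features_endogenous,
            in_features_exogenous=in_features_exogenous,
            out_features=out_features,
            **self._module_kwargs,
        )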

torch_concepts/nn/modules/mid/models/probabilistic_model.py

Lines changed: 29 additions & 8 deletions

@@ -9,7 +9,8 @@
 from torch.distributions import Distribution
 from typing import List, Dict, Optional, Type

-from .variable import Variable, ExogenousVariable
+from torch_concepts.nn import LazyConstructor
+from .variable import Variable, ExogenousVariable, EndogenousVariable, InputVariable
 from .cpd import ParametricCPD


@@ -159,14 +160,34 @@ def _initialize_model(self, input_parametric_cpds: List[ParametricCPD]):
            if concept in self.concept_to_variable:
                parametric_cpd.variable = self.concept_to_variable[concept]
                parametric_cpd.parents = self.concept_to_variable[concept].parents
-                if not isinstance(parametric_cpd.variable, ExogenousVariable):
-                    new_parametrization = _reinitialize_with_new_param(parametric_cpd.parametrization,
-                                                                       'out_features',
-                                                                       self.concept_to_variable[concept].size)
-                    new_parametric_cpd = ParametricCPD(concepts=[concept], parametrization=new_parametrization)
-                    self.parametric_cpds[concept] = new_parametric_cpd
+
+                if isinstance(parametric_cpd.parametrization, LazyConstructor):
+                    parent_vars = [self.concept_to_variable[parent_ref] for parent_ref in parametric_cpd.variable.parents]
+                    in_features_endogenous = in_features_exogenous = in_features = 0
+                    for pv in parent_vars:
+                        if isinstance(pv, ExogenousVariable):
+                            in_features_exogenous = pv.size
+                        elif isinstance(pv, EndogenousVariable):
+                            in_features_endogenous += pv.size
+                        else:
+                            in_features += pv.size
+
+                    if isinstance(parametric_cpd.variable, ExogenousVariable):
+                        out_features = 1
+                    else:
+                        out_features = self.concept_to_variable[concept].size
+
+                    initialized_layer = parametric_cpd.parametrization.build(
+                        in_features=in_features,
+                        in_features_endogenous=in_features_endogenous,
+                        in_features_exogenous=in_features_exogenous,
+                        out_features=out_features,
+                    )
+                    new_parametrization = ParametricCPD(concepts=[concept], parametrization=initialized_layer)
                else:
-                    self.parametric_cpds[concept] = parametric_cpd
+                    new_parametrization = parametric_cpd
+
+                self.parametric_cpds[concept] = new_parametrization

        # ---- Parent resolution (unchanged) ----
        for var in self.variables:
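Condensed, the size-resolution rule added here is: exogenous parents set in_features_exogenous, endogenous parents accumulate into in_features_endogenous, any other parent (e.g. an input variable) accumulates into in_features, and out_features is the variable's own size, or 1 for an exogenous variable. A standalone restatement of that rule, for illustration only (the helper name is hypothetical; the import path follows this file's relative imports):

from torch_concepts.nn.modules.mid.models.variable import (
    EndogenousVariable,
    ExogenousVariable,
)

# Hypothetical helper mirroring the logic added to _initialize_model above.
def resolve_lazy_sizes(variable, parent_vars):
    in_features = in_features_endogenous = in_features_exogenous = 0
    for pv in parent_vars:
        if isinstance(pv, ExogenousVariable):
            in_features_exogenous = pv.size      # assigned, so the last exogenous parent wins
        elif isinstance(pv, EndogenousVariable):
            in_features_endogenous += pv.size    # endogenous parents accumulate
        else:
            in_features += pv.size               # input/other parents accumulate
    out_features = 1 if isinstance(variable, ExogenousVariable) else variable.size
    return in_features, in_features_endogenous, in_features_exogenous, out_features

In the toy XOR example above, this would give the "xor" CPD an in_features_endogenous equal to the summed sizes of c1 and c2 and out_features equal to 2, which is what the eagerly constructed LinearCC used to receive by hand.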
