Skip to content

Commit d3634ae

Browse files
Remove lazy build from graph model
1 parent 268599f commit d3634ae

File tree

4 files changed

+17
-26
lines changed

4 files changed

+17
-26
lines changed

examples/utilization/1_pgm/0_concept_bottleneck_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def main():
3232
# ParametricCPD setup
3333
backbone = ParametricCPD("input", parametrization=torch.nn.Sequential(torch.nn.Linear(x_train.shape[1], latent_dims), torch.nn.LeakyReLU()))
3434
c_encoder = ParametricCPD(["c1", "c2"], parametrization=LazyConstructor(LinearZC))
35-
y_predictor = ParametricCPD("xor", parametrization=LazyConstructor(LinearCC))
35+
y_predictor = ParametricCPD("xor", parametrization=LinearCC(in_features_endogenous=2, out_features=2))
3636

3737
# ProbabilisticModel Initialization
3838
concept_model = ProbabilisticModel(variables=[input_var, *concepts, tasks], parametric_cpds=[backbone, *c_encoder, y_predictor])

examples/utilization/2_model/0_concept_bottleneck_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def main():
3838
concept_model = BipartiteModel(task_names,
3939
latent_dims,
4040
annotations,
41-
LazyConstructor(LinearZC),
42-
LazyConstructor(LinearCC))
41+
LinearZC(10, 1),
42+
LinearCC(2, 2))
4343

4444
# Inference Initialization
4545
inference_engine = DeterministicInference(concept_model.probabilistic_model)

torch_concepts/nn/modules/mid/constructors/bipartite.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pandas as pd
44
import torch
5+
from torch.nn import Module
56

67
from .....annotations import Annotations
78
from .concept_graph import ConceptGraph
@@ -35,7 +36,7 @@ class BipartiteModel(GraphModel):
3536
Example:
3637
>>> import torch
3738
>>> from torch_concepts import Annotations, AxisAnnotation
38-
>>> from torch_concepts.nn import BipartiteModel, LazyConstructor
39+
>>> from torch_concepts.nn import BipartiteModel, LazyConstructor, LinearCC
3940
>>> from torch.distributions import Bernoulli
4041
>>>
4142
>>> # Define concepts and tasks
@@ -78,11 +79,11 @@ def __init__(
7879
task_names: Union[List[str], str],
7980
input_size: int,
8081
annotations: Annotations,
81-
encoder: LazyConstructor,
82-
predictor: LazyConstructor,
82+
encoder: Union[LazyConstructor, Module],
83+
predictor: Union[LazyConstructor, Module],
8384
use_source_exogenous: bool = None,
84-
source_exogenous: Optional[LazyConstructor] = None,
85-
internal_exogenous: Optional[LazyConstructor] = None,
85+
source_exogenous: Optional[Union[LazyConstructor, Module]] = None,
86+
internal_exogenous: Optional[Union[LazyConstructor, Module]] = None,
8687
):
8788
task_names = ensure_list(task_names)
8889
# get label names

torch_concepts/nn/modules/mid/constructors/graph.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from typing import List, Tuple, Optional
2-
from torch.nn import Identity
1+
from typing import List, Tuple, Optional, Union
2+
from torch.nn import Identity, Module
33

44
from .....annotations import Annotations
55
from ..models.variable import Variable, InputVariable, ExogenousVariable, EndogenousVariable
@@ -50,7 +50,7 @@ class GraphModel(BaseConstructor):
5050
>>> import torch
5151
>>> import pandas as pd
5252
>>> from torch_concepts import Annotations, AxisAnnotation, ConceptGraph
53-
>>> from torch_concepts.nn import GraphModel, LazyConstructor
53+
>>> from torch_concepts.nn import GraphModel, LazyConstructor, LinearCC
5454
>>> from torch.distributions import Bernoulli
5555
>>>
5656
>>> # Define concepts and their structure
@@ -111,11 +111,11 @@ def __init__(self,
111111
model_graph: ConceptGraph,
112112
input_size: int,
113113
annotations: Annotations,
114-
encoder: LazyConstructor,
115-
predictor: LazyConstructor,
114+
encoder: Union[LazyConstructor, Module],
115+
predictor: Union[LazyConstructor, Module],
116116
use_source_exogenous: bool = None,
117-
source_exogenous: Optional[LazyConstructor] = None,
118-
internal_exogenous: Optional[LazyConstructor] = None
117+
source_exogenous: Optional[Union[LazyConstructor, Module]] = None,
118+
internal_exogenous: Optional[Union[LazyConstructor, Module]] = None
119119
):
120120
super(GraphModel, self).__init__(
121121
input_size=input_size,
@@ -281,17 +281,7 @@ def _init_predictors(self,
281281
parents=endogenous_parents_names+exog_vars_names,
282282
distribution=self.annotations[1].metadata[c_name]['distribution'],
283283
size=self.annotations[1].cardinalities[self.annotations[1].get_index(c_name)])
284-
285-
# TODO: we currently assume predictors can use exogenous vars if any, but not latent
286-
lazy_constructor = layer.build(
287-
in_features_endogenous=in_features_endogenous,
288-
in_features_exogenous=in_features_exogenous,
289-
in_features=None,
290-
out_features=predictor_var.size,
291-
cardinalities=[predictor_var.size]
292-
)
293-
294-
predictor_cpd = ParametricCPD(c_name, parametrization=lazy_constructor)
284+
predictor_cpd = ParametricCPD(c_name, parametrization=layer)
295285

296286
predictor_vars.append(predictor_var)
297287
predictor_cpds.append(predictor_cpd)

0 commit comments

Comments
 (0)