|
1 | | -from typing import List, Tuple, Optional |
2 | | -from torch.nn import Identity |
| 1 | +from typing import List, Tuple, Optional, Union |
| 2 | +from torch.nn import Identity, Module |
3 | 3 |
|
4 | 4 | from .....annotations import Annotations |
5 | 5 | from ..models.variable import Variable, InputVariable, ExogenousVariable, EndogenousVariable |
@@ -50,7 +50,7 @@ class GraphModel(BaseConstructor): |
50 | 50 | >>> import torch |
51 | 51 | >>> import pandas as pd |
52 | 52 | >>> from torch_concepts import Annotations, AxisAnnotation, ConceptGraph |
53 | | - >>> from torch_concepts.nn import GraphModel, LazyConstructor |
| 53 | + >>> from torch_concepts.nn import GraphModel, LazyConstructor, LinearCC |
54 | 54 | >>> from torch.distributions import Bernoulli |
55 | 55 | >>> |
56 | 56 | >>> # Define concepts and their structure |
@@ -111,11 +111,11 @@ def __init__(self, |
111 | 111 | model_graph: ConceptGraph, |
112 | 112 | input_size: int, |
113 | 113 | annotations: Annotations, |
114 | | - encoder: LazyConstructor, |
115 | | - predictor: LazyConstructor, |
| 114 | + encoder: Union[LazyConstructor, Module], |
| 115 | + predictor: Union[LazyConstructor, Module], |
116 | 116 | use_source_exogenous: bool = None, |
117 | | - source_exogenous: Optional[LazyConstructor] = None, |
118 | | - internal_exogenous: Optional[LazyConstructor] = None |
| 117 | + source_exogenous: Optional[Union[LazyConstructor, Module]] = None, |
| 118 | + internal_exogenous: Optional[Union[LazyConstructor, Module]] = None |
119 | 119 | ): |
120 | 120 | super(GraphModel, self).__init__( |
121 | 121 | input_size=input_size, |
@@ -281,17 +281,7 @@ def _init_predictors(self, |
281 | 281 | parents=endogenous_parents_names+exog_vars_names, |
282 | 282 | distribution=self.annotations[1].metadata[c_name]['distribution'], |
283 | 283 | size=self.annotations[1].cardinalities[self.annotations[1].get_index(c_name)]) |
284 | | - |
285 | | - # TODO: we currently assume predictors can use exogenous vars if any, but not latent |
286 | | - lazy_constructor = layer.build( |
287 | | - in_features_endogenous=in_features_endogenous, |
288 | | - in_features_exogenous=in_features_exogenous, |
289 | | - in_features=None, |
290 | | - out_features=predictor_var.size, |
291 | | - cardinalities=[predictor_var.size] |
292 | | - ) |
293 | | - |
294 | | - predictor_cpd = ParametricCPD(c_name, parametrization=lazy_constructor) |
| 284 | + predictor_cpd = ParametricCPD(c_name, parametrization=layer) |
295 | 285 |
|
296 | 286 | predictor_vars.append(predictor_var) |
297 | 287 | predictor_cpds.append(predictor_cpd) |
|
0 commit comments