Skip to content

Commit ce7f6a3

Browse files
committed
fix bug when checking prefix in the exogenous variables in graphModel
1 parent 38e2f01 commit ce7f6a3

File tree

5 files changed

+776
-10
lines changed

5 files changed

+776
-10
lines changed

conceptarium/conf/model/cbm.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ defaults:
22
- _commons
33
- _self_
44

5-
# default is joint training
5+
# wrapper for 'joint' training mode
66
_target_: "torch_concepts.nn.ConceptBottleneckModel"
77

88
task_names: ${dataset.default_task_names}
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
"""
2+
Tests for exogenous variable prefix matching bug fix.
3+
4+
This test module verifies that exogenous variables are correctly matched to their
5+
corresponding concepts using exact prefix matching, avoiding substring matching bugs.
6+
7+
Bug context: Previously, using substring matching like `"OtherCar" in "exog_OtherCarCost_state_0"`
8+
would incorrectly match, causing concepts to receive exogenous variables from other concepts
9+
with similar names.
10+
11+
Fix: Use exact prefix matching with `startswith(f"exog_{label_name}_state_")` to ensure
12+
concepts only receive their own exogenous variables.
13+
"""
14+
import unittest
15+
import torch
16+
from torch_concepts.annotations import Annotations, AxisAnnotation
17+
from torch_concepts.nn import BipartiteModel, LinearCC
18+
from torch_concepts.nn import LazyConstructor
19+
from torch_concepts.nn.modules.low.encoders.exogenous import LinearZU
20+
from torch.distributions import Bernoulli, Categorical
21+
22+
23+
class TestExogenousPrefixMatching(unittest.TestCase):
24+
"""Test exact prefix matching for exogenous variables."""
25+
26+
def test_substring_overlap_concepts(self):
27+
"""Test concepts with substring overlap don't cross-assign exogenous variables.
28+
29+
This is the core bug fix test: concepts like 'Car' and 'CarCost' should not
30+
have their exogenous variables mixed up due to substring matching.
31+
"""
32+
# Create concepts where one name is a substring of another
33+
concept_names = ['Car', 'CarCost', 'Driver', 'Task']
34+
35+
# Create annotations with different cardinalities to make exogenous counts distinct
36+
metadata = {
37+
'Car': {'distribution': Categorical, 'type': 'discrete'},
38+
'CarCost': {'distribution': Categorical, 'type': 'discrete'},
39+
'Driver': {'distribution': Categorical, 'type': 'discrete'},
40+
'Task': {'distribution': Bernoulli, 'type': 'discrete'}
41+
}
42+
cardinalities = (2, 4, 3, 1)
43+
44+
annotations = Annotations({
45+
1: AxisAnnotation(
46+
labels=tuple(concept_names),
47+
cardinalities=cardinalities,
48+
metadata=metadata
49+
)
50+
})
51+
52+
# Create bipartite model with source_exogenous
53+
model = BipartiteModel(
54+
task_names=['Task'],
55+
input_size=100,
56+
annotations=annotations,
57+
encoder=LazyConstructor(torch.nn.Linear),
58+
predictor=LazyConstructor(LinearCC),
59+
source_exogenous=LazyConstructor(LinearZU, exogenous_size=16),
60+
use_source_exogenous=True
61+
)
62+
63+
# Check that variables were created with correct parent counts
64+
car_vars = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'Car']
65+
carcost_vars = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'CarCost']
66+
driver_vars = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'Driver']
67+
68+
self.assertEqual(len(car_vars), 1)
69+
self.assertEqual(len(carcost_vars), 1)
70+
self.assertEqual(len(driver_vars), 1)
71+
72+
car_var = car_vars[0]
73+
carcost_var = carcost_vars[0]
74+
driver_var = driver_vars[0]
75+
76+
# Check that each concept has the correct number of parent variables
77+
# With source_exogenous, each concept should have exogenous variables matching its cardinality
78+
self.assertEqual(len(car_var.parents), 2,
79+
f"Car should have 2 exogenous parent variables, got {len(car_var.parents)}")
80+
self.assertEqual(len(carcost_var.parents), 4,
81+
f"CarCost should have 4 exogenous parent variables, got {len(carcost_var.parents)}")
82+
self.assertEqual(len(driver_var.parents), 3,
83+
f"Driver should have 3 exogenous parent variables, got {len(driver_var.parents)}")
84+
85+
# Verify parent names start with correct prefix (not substrings of other concepts)
86+
car_parent_names = [p if isinstance(p, str) else p.concepts[0] for p in car_var.parents]
87+
for name in car_parent_names:
88+
self.assertTrue(name.startswith('exog_Car_state_'),
89+
f"Car parent {name} should start with 'exog_Car_state_'")
90+
self.assertFalse(name.startswith('exog_CarCost_state_'),
91+
f"Car should not have CarCost exogenous variable: {name}")
92+
93+
carcost_parent_names = [p if isinstance(p, str) else p.concepts[0] for p in carcost_var.parents]
94+
for name in carcost_parent_names:
95+
self.assertTrue(name.startswith('exog_CarCost_state_'),
96+
f"CarCost parent {name} should start with 'exog_CarCost_state_'")
97+
98+
def test_exact_prefix_matching_with_similar_names(self):
99+
"""Test exact prefix matching with highly similar concept names.
100+
101+
Tests edge cases like 'A', 'AB', 'ABC' to ensure no cross-contamination.
102+
"""
103+
concept_names = ['A', 'AB', 'ABC', 'Task']
104+
105+
metadata = {
106+
'A': {'distribution': Categorical, 'type': 'discrete'},
107+
'AB': {'distribution': Categorical, 'type': 'discrete'},
108+
'ABC': {'distribution': Categorical, 'type': 'discrete'},
109+
'Task': {'distribution': Bernoulli, 'type': 'discrete'}
110+
}
111+
cardinalities = (2, 3, 4, 1)
112+
113+
annotations = Annotations({
114+
1: AxisAnnotation(
115+
labels=tuple(concept_names),
116+
cardinalities=cardinalities,
117+
metadata=metadata
118+
)
119+
})
120+
121+
model = BipartiteModel(
122+
task_names=['Task'],
123+
input_size=50,
124+
annotations=annotations,
125+
encoder=LazyConstructor(torch.nn.Linear),
126+
predictor=LazyConstructor(LinearCC),
127+
source_exogenous=LazyConstructor(LinearZU, exogenous_size=16),
128+
use_source_exogenous=True
129+
)
130+
131+
# Check each concept has only its own exogenous variables
132+
a_var = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'A'][0]
133+
ab_var = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'AB'][0]
134+
abc_var = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'ABC'][0]
135+
136+
self.assertEqual(len(a_var.parents), 2, "A should have 2 exogenous variables")
137+
self.assertEqual(len(ab_var.parents), 3, "AB should have 3 exogenous variables")
138+
self.assertEqual(len(abc_var.parents), 4, "ABC should have 4 exogenous variables")
139+
140+
# Verify exact prefix matching - A should not get AB or ABC variables
141+
a_parent_names = [p if isinstance(p, str) else p.concepts[0] for p in a_var.parents]
142+
for name in a_parent_names:
143+
self.assertTrue(name.startswith('exog_A_state_'),
144+
f"A parent should start with 'exog_A_state_', got {name}")
145+
# Make sure it's not 'exog_AB_state_' or 'exog_ABC_state_'
146+
self.assertFalse('exog_AB' in name or 'exog_ABC' in name,
147+
f"A should not have AB/ABC exogenous: {name}")
148+
149+
def test_underscore_in_concept_names(self):
150+
"""Test that underscores in concept names don't cause matching issues.
151+
152+
Ensures that the '_state_' suffix in exogenous variable names is correctly
153+
used as part of the matching logic.
154+
"""
155+
concept_names = ['Age_Group', 'Age_Group_Risk', 'Task']
156+
157+
metadata = {
158+
'Age_Group': {'distribution': Categorical, 'type': 'discrete'},
159+
'Age_Group_Risk': {'distribution': Categorical, 'type': 'discrete'},
160+
'Task': {'distribution': Bernoulli, 'type': 'discrete'}
161+
}
162+
cardinalities = (3, 5, 1)
163+
164+
annotations = Annotations({
165+
1: AxisAnnotation(
166+
labels=tuple(concept_names),
167+
cardinalities=cardinalities,
168+
metadata=metadata
169+
)
170+
})
171+
172+
model = BipartiteModel(
173+
task_names=['Task'],
174+
input_size=60,
175+
annotations=annotations,
176+
encoder=LazyConstructor(torch.nn.Linear),
177+
predictor=LazyConstructor(LinearCC),
178+
source_exogenous=LazyConstructor(LinearZU, exogenous_size=16),
179+
use_source_exogenous=True
180+
)
181+
182+
age_group_var = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'Age_Group'][0]
183+
age_group_risk_var = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'Age_Group_Risk'][0]
184+
185+
self.assertEqual(len(age_group_var.parents), 3,
186+
"Age_Group should have 3 exogenous variables")
187+
self.assertEqual(len(age_group_risk_var.parents), 5,
188+
"Age_Group_Risk should have 5 exogenous variables")
189+
190+
# Verify Age_Group doesn't get Age_Group_Risk's exogenous variables
191+
age_group_parent_names = [p if isinstance(p, str) else p.concepts[0] for p in age_group_var.parents]
192+
for name in age_group_parent_names:
193+
self.assertTrue(name.startswith('exog_Age_Group_state_'),
194+
f"Age_Group parent should start with 'exog_Age_Group_state_', got {name}")
195+
self.assertFalse(name.startswith('exog_Age_Group_Risk_state_'),
196+
f"Age_Group should not have Age_Group_Risk exogenous: {name}")
197+
198+
def test_predictor_exogenous_filtering(self):
199+
"""Test that predictor correctly filters exogenous variables for parent concepts.
200+
201+
The predictor should only receive exogenous variables from its actual parents,
202+
not from concepts with similar names.
203+
"""
204+
concept_names = ['Other', 'OtherCar', 'OtherCarCost', 'Task']
205+
206+
metadata = {
207+
'Other': {'distribution': Categorical, 'type': 'discrete'},
208+
'OtherCar': {'distribution': Categorical, 'type': 'discrete'},
209+
'OtherCarCost': {'distribution': Categorical, 'type': 'discrete'},
210+
'Task': {'distribution': Categorical, 'type': 'discrete'}
211+
}
212+
cardinalities = (2, 3, 4, 2)
213+
214+
annotations = Annotations({
215+
1: AxisAnnotation(
216+
labels=tuple(concept_names),
217+
cardinalities=cardinalities,
218+
metadata=metadata
219+
)
220+
})
221+
222+
model = BipartiteModel(
223+
task_names=['Task'],
224+
input_size=70,
225+
annotations=annotations,
226+
encoder=LazyConstructor(torch.nn.Linear),
227+
predictor=LazyConstructor(LinearCC),
228+
source_exogenous=LazyConstructor(LinearZU, exogenous_size=16),
229+
use_source_exogenous=True
230+
)
231+
232+
# Check that root concepts have correct exogenous parents
233+
other_var = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'Other'][0]
234+
othercar_var = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'OtherCar'][0]
235+
othercarcost_var = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'OtherCarCost'][0]
236+
237+
self.assertEqual(len(other_var.parents), 2,
238+
"Other should have 2 exogenous variables")
239+
self.assertEqual(len(othercar_var.parents), 3,
240+
"OtherCar should have 3 exogenous variables")
241+
self.assertEqual(len(othercarcost_var.parents), 4,
242+
"OtherCarCost should have 4 exogenous variables")
243+
244+
# Verify OtherCar doesn't get OtherCarCost's exogenous (the original bug!)
245+
othercar_parent_names = [p if isinstance(p, str) else p.concepts[0] for p in othercar_var.parents]
246+
for name in othercar_parent_names:
247+
self.assertTrue(name.startswith('exog_OtherCar_state_'),
248+
f"OtherCar parent should start with 'exog_OtherCar_state_', got {name}")
249+
self.assertFalse(name.startswith('exog_OtherCarCost_state_'),
250+
f"OtherCar should NOT have OtherCarCost exogenous: {name}")
251+
252+
def test_no_exogenous_without_source_exogenous_flag(self):
253+
"""Test that exogenous variables are not created when use_source_exogenous=False.
254+
255+
This is a control test to ensure the exogenous feature is opt-in.
256+
"""
257+
concept_names = ['Car', 'CarCost', 'Task']
258+
259+
metadata = {
260+
'Car': {'distribution': Categorical, 'type': 'discrete'},
261+
'CarCost': {'distribution': Categorical, 'type': 'discrete'},
262+
'Task': {'distribution': Bernoulli, 'type': 'discrete'}
263+
}
264+
cardinalities = (2, 4, 1)
265+
266+
annotations = Annotations({
267+
1: AxisAnnotation(
268+
labels=tuple(concept_names),
269+
cardinalities=cardinalities,
270+
metadata=metadata
271+
)
272+
})
273+
274+
# use_source_exogenous=False (default)
275+
model = BipartiteModel(
276+
task_names=['Task'],
277+
input_size=80,
278+
annotations=annotations,
279+
encoder=LazyConstructor(torch.nn.Linear),
280+
predictor=LazyConstructor(LinearCC),
281+
use_source_exogenous=False
282+
)
283+
284+
# Encoders should not have exogenous parents when source_exogenous=False
285+
car_var = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'Car'][0]
286+
carcost_var = [v for v in model.probabilistic_model.variables if v.concepts[0] == 'CarCost'][0]
287+
288+
# Without source exogenous, root concepts should only have 'input' as parent, no exogenous variables
289+
self.assertEqual(len(car_var.parents), 1,
290+
"Car should have 1 parent (input) when use_source_exogenous=False")
291+
self.assertEqual(type(car_var.parents[0]).__name__, 'InputVariable',
292+
"Car's only parent should be InputVariable")
293+
294+
# Verify no exogenous variables exist
295+
exog_vars = [v for v in model.probabilistic_model.variables if hasattr(v, 'name') and v.name.startswith('exog_')]
296+
self.assertEqual(len(exog_vars), 0,
297+
"No exogenous variables should exist when use_source_exogenous=False")
298+
if __name__ == '__main__':
299+
unittest.main()

0 commit comments

Comments
 (0)