Skip to content

Commit e774718

Browse files
authored
Fix test_trompt.py (#373)
`stype_encoder_dicts` needs to be initialized inside `test_trompt()`. otherwise, it will be persisted across different parametrization, causing an unexpected error.
1 parent 56f687d commit e774718

File tree

1 file changed

+21
-17
lines changed

1 file changed

+21
-17
lines changed

test/nn/models/test_trompt.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,11 @@
77

88

99
@pytest.mark.parametrize('batch_size', [0, 5])
10-
@pytest.mark.parametrize('stype_encoder_dicts', [
11-
[
12-
{
13-
stype.numerical: LinearEncoder(),
14-
stype.categorical: EmbeddingEncoder(),
15-
},
16-
{
17-
stype.numerical: LinearEncoder(),
18-
stype.categorical: EmbeddingEncoder(),
19-
},
20-
{
21-
stype.numerical: LinearEncoder(),
22-
stype.categorical: EmbeddingEncoder(),
23-
},
24-
],
25-
None,
10+
@pytest.mark.parametrize('use_stype_encoder_dicts', [
11+
True,
12+
False,
2613
])
27-
def test_trompt(batch_size, stype_encoder_dicts):
14+
def test_trompt(batch_size, use_stype_encoder_dicts):
2815
batch_size = 10
2916
channels = 8
3017
out_channels = 1
@@ -33,6 +20,23 @@ def test_trompt(batch_size, stype_encoder_dicts):
3320
dataset: Dataset = FakeDataset(num_rows=10, with_nan=False)
3421
dataset.materialize()
3522
tensor_frame = dataset.tensor_frame[:batch_size]
23+
if use_stype_encoder_dicts:
24+
stype_encoder_dicts = [
25+
{
26+
stype.numerical: LinearEncoder(),
27+
stype.categorical: EmbeddingEncoder(),
28+
},
29+
{
30+
stype.numerical: LinearEncoder(),
31+
stype.categorical: EmbeddingEncoder(),
32+
},
33+
{
34+
stype.numerical: LinearEncoder(),
35+
stype.categorical: EmbeddingEncoder(),
36+
},
37+
]
38+
else:
39+
stype_encoder_dicts = None
3640
model = Trompt(
3741
channels=channels,
3842
out_channels=out_channels,

0 commit comments

Comments
 (0)