Skip to content

Commit e69ffd2

Browse files
committed
Add test for getter functionality
nit
1 parent 1efd942 commit e69ffd2

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed

tests/utils/test_getter.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import pytest
2+
import torch
3+
from unittest.mock import patch, MagicMock
4+
5+
from pyannote.audio.core.model import Model
6+
from pyannote.audio.pipelines.utils.getter import get_model
7+
8+
9+
class BrokenModelWithoutEval(Model):
10+
def __init__(self):
11+
super().__init__()
12+
self.eval = None
13+
14+
def forward(self, waveforms):
15+
return torch.rand(1)
16+
17+
18+
class BrokenModelWithNonCallableEval(Model):
19+
def __init__(self):
20+
super().__init__()
21+
self.eval = "not_callable"
22+
23+
def forward(self, waveforms):
24+
return torch.rand(1)
25+
26+
27+
def test_model_without_eval_attribute():
28+
model = BrokenModelWithoutEval()
29+
30+
with patch('pyannote.audio.pipelines.utils.getter.hasattr', return_value=False):
31+
with pytest.raises(ValueError) as excinfo:
32+
get_model(model)
33+
34+
assert "The model could not be loaded" in str(excinfo.value)
35+
assert f"Recieved: {model}" in str(excinfo.value)
36+
37+
38+
def test_model_with_non_callable_eval():
39+
model = BrokenModelWithNonCallableEval()
40+
41+
with pytest.raises(ValueError) as excinfo:
42+
get_model(model)
43+
44+
assert "The model could not be loaded" in str(excinfo.value)
45+
assert f"Recieved: {model}" in str(excinfo.value)
46+
47+
48+
@patch('pyannote.audio.core.model.Model.from_pretrained')
49+
def test_get_model_with_auth_token(mock_from_pretrained):
50+
mock_model = MagicMock()
51+
mock_model.eval = MagicMock(return_value=mock_model)
52+
mock_from_pretrained.return_value = mock_model
53+
54+
model_path = "dummy/model/path"
55+
auth_token = "test_token"
56+
result = get_model(model_path, use_auth_token=auth_token)
57+
58+
mock_from_pretrained.assert_called_once_with(
59+
model_path, use_auth_token=auth_token, strict=False
60+
)
61+
62+
mock_model.eval.assert_called_once()
63+
64+
assert result == mock_model
65+
66+
67+
@patch('pyannote.audio.core.model.Model.from_pretrained')
68+
def test_get_model_with_dict_config(mock_from_pretrained):
69+
mock_model = MagicMock()
70+
mock_model.eval = MagicMock(return_value=mock_model)
71+
mock_from_pretrained.return_value = mock_model
72+
73+
model_config = {
74+
"checkpoint": "dummy/model/path",
75+
"map_location": "cuda:0"
76+
}
77+
auth_token = "test_token"
78+
result = get_model(model_config, use_auth_token=auth_token)
79+
80+
expected_config = model_config.copy()
81+
expected_config["use_auth_token"] = auth_token
82+
83+
mock_from_pretrained.assert_called_once_with(**expected_config)
84+
85+
mock_model.eval.assert_called_once()
86+
87+
assert result == mock_model
88+
89+
90+
def test_get_model_with_invalid_type():
91+
with pytest.raises(TypeError) as excinfo:
92+
get_model(42)
93+
94+
assert "Unsupported type" in str(excinfo.value)
95+
assert "expected `str` or `dict`" in str(excinfo.value)

0 commit comments

Comments
 (0)