-
Notifications
You must be signed in to change notification settings - Fork 953
fix(task): model can sometimes be None #1855
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,95 @@ | ||||||
import pytest | ||||||
import torch | ||||||
from unittest.mock import patch, MagicMock | ||||||
|
||||||
from pyannote.audio.core.model import Model | ||||||
from pyannote.audio.pipelines.utils.getter import get_model | ||||||
|
||||||
|
||||||
class BrokenModelWithoutEval(Model): | ||||||
def __init__(self): | ||||||
super().__init__() | ||||||
self.eval = None | ||||||
|
||||||
def forward(self, waveforms): | ||||||
return torch.rand(1) | ||||||
|
||||||
|
||||||
class BrokenModelWithNonCallableEval(Model): | ||||||
def __init__(self): | ||||||
super().__init__() | ||||||
self.eval = "not_callable" | ||||||
|
||||||
def forward(self, waveforms): | ||||||
return torch.rand(1) | ||||||
|
||||||
|
||||||
def test_model_without_eval_attribute(): | ||||||
model = BrokenModelWithoutEval() | ||||||
|
||||||
with patch('pyannote.audio.pipelines.utils.getter.hasattr', return_value=False): | ||||||
with pytest.raises(ValueError) as excinfo: | ||||||
get_model(model) | ||||||
|
||||||
assert "The model could not be loaded" in str(excinfo.value) | ||||||
assert f"Recieved: {model}" in str(excinfo.value) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a spelling error in the test assertion. 'Recieved' should be 'Received' to match the corrected error message.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||
|
||||||
|
||||||
def test_model_with_non_callable_eval(): | ||||||
model = BrokenModelWithNonCallableEval() | ||||||
|
||||||
with pytest.raises(ValueError) as excinfo: | ||||||
get_model(model) | ||||||
|
||||||
assert "The model could not be loaded" in str(excinfo.value) | ||||||
assert f"Recieved: {model}" in str(excinfo.value) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a spelling error in the test assertion. 'Recieved' should be 'Received' to match the corrected error message.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||
|
||||||
|
||||||
@patch('pyannote.audio.core.model.Model.from_pretrained') | ||||||
def test_get_model_with_auth_token(mock_from_pretrained): | ||||||
mock_model = MagicMock() | ||||||
mock_model.eval = MagicMock(return_value=mock_model) | ||||||
mock_from_pretrained.return_value = mock_model | ||||||
|
||||||
model_path = "dummy/model/path" | ||||||
auth_token = "test_token" | ||||||
result = get_model(model_path, use_auth_token=auth_token) | ||||||
|
||||||
mock_from_pretrained.assert_called_once_with( | ||||||
model_path, use_auth_token=auth_token, strict=False | ||||||
) | ||||||
|
||||||
mock_model.eval.assert_called_once() | ||||||
|
||||||
assert result == mock_model | ||||||
|
||||||
|
||||||
@patch('pyannote.audio.core.model.Model.from_pretrained') | ||||||
def test_get_model_with_dict_config(mock_from_pretrained): | ||||||
mock_model = MagicMock() | ||||||
mock_model.eval = MagicMock(return_value=mock_model) | ||||||
mock_from_pretrained.return_value = mock_model | ||||||
|
||||||
model_config = { | ||||||
"checkpoint": "dummy/model/path", | ||||||
"map_location": "cuda:0" | ||||||
} | ||||||
auth_token = "test_token" | ||||||
result = get_model(model_config, use_auth_token=auth_token) | ||||||
|
||||||
expected_config = model_config.copy() | ||||||
expected_config["use_auth_token"] = auth_token | ||||||
|
||||||
mock_from_pretrained.assert_called_once_with(**expected_config) | ||||||
|
||||||
mock_model.eval.assert_called_once() | ||||||
|
||||||
assert result == mock_model | ||||||
|
||||||
|
||||||
def test_get_model_with_invalid_type(): | ||||||
with pytest.raises(TypeError) as excinfo: | ||||||
get_model(42) | ||||||
|
||||||
assert "Unsupported type" in str(excinfo.value) | ||||||
assert "expected `str` or `dict`" in str(excinfo.value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a spelling error in the error message. 'Recieved' should be 'Received'.
Copilot uses AI. Check for mistakes.