From 1efd94258316d81900132ca17fefb65461f856c1 Mon Sep 17 00:00:00 2001 From: SBNovaScript Date: Tue, 1 Apr 2025 20:38:47 -0400 Subject: [PATCH 1/2] fix(task): model can sometimes be None --- src/pyannote/audio/pipelines/utils/getter.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/pyannote/audio/pipelines/utils/getter.py b/src/pyannote/audio/pipelines/utils/getter.py index 9e69a321c..12f8ac37a 100644 --- a/src/pyannote/audio/pipelines/utils/getter.py +++ b/src/pyannote/audio/pipelines/utils/getter.py @@ -91,6 +91,13 @@ def get_model( f"Unsupported type ({type(model)}) for loading model: " f"expected `str` or `dict`." ) + + if not hasattr(model, 'eval') or not callable(model.eval): + raise ValueError( + "The model could not be loaded. " + "Please check the checkpoint path or the model name. " + f"Recieved: {model}" + ) model.eval() return model From e69ffd20520ba0cd0946b5a9dca38e8c3e24de99 Mon Sep 17 00:00:00 2001 From: SBNovaScript Date: Wed, 2 Apr 2025 22:36:02 -0400 Subject: [PATCH 2/2] Add test for getter functionality nit --- tests/utils/test_getter.py | 95 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/utils/test_getter.py diff --git a/tests/utils/test_getter.py b/tests/utils/test_getter.py new file mode 100644 index 000000000..7d7b45322 --- /dev/null +++ b/tests/utils/test_getter.py @@ -0,0 +1,95 @@ +import pytest +import torch +from unittest.mock import patch, MagicMock + +from pyannote.audio.core.model import Model +from pyannote.audio.pipelines.utils.getter import get_model + + +class BrokenModelWithoutEval(Model): + def __init__(self): + super().__init__() + self.eval = None + + def forward(self, waveforms): + return torch.rand(1) + + +class BrokenModelWithNonCallableEval(Model): + def __init__(self): + super().__init__() + self.eval = "not_callable" + + def forward(self, waveforms): + return torch.rand(1) + + +def test_model_without_eval_attribute(): + model = BrokenModelWithoutEval() + + with patch('pyannote.audio.pipelines.utils.getter.hasattr', return_value=False): + with pytest.raises(ValueError) as excinfo: + get_model(model) + + assert "The model could not be loaded" in str(excinfo.value) + assert f"Recieved: {model}" in str(excinfo.value) + + +def test_model_with_non_callable_eval(): + model = BrokenModelWithNonCallableEval() + + with pytest.raises(ValueError) as excinfo: + get_model(model) + + assert "The model could not be loaded" in str(excinfo.value) + assert f"Recieved: {model}" in str(excinfo.value) + + +@patch('pyannote.audio.core.model.Model.from_pretrained') +def test_get_model_with_auth_token(mock_from_pretrained): + mock_model = MagicMock() + mock_model.eval = MagicMock(return_value=mock_model) + mock_from_pretrained.return_value = mock_model + + model_path = "dummy/model/path" + auth_token = "test_token" + result = get_model(model_path, use_auth_token=auth_token) + + mock_from_pretrained.assert_called_once_with( + model_path, use_auth_token=auth_token, strict=False + ) + + mock_model.eval.assert_called_once() + + assert result == mock_model + + +@patch('pyannote.audio.core.model.Model.from_pretrained') +def test_get_model_with_dict_config(mock_from_pretrained): + mock_model = MagicMock() + mock_model.eval = MagicMock(return_value=mock_model) + mock_from_pretrained.return_value = mock_model + + model_config = { + "checkpoint": "dummy/model/path", + "map_location": "cuda:0" + } + auth_token = "test_token" + result = get_model(model_config, use_auth_token=auth_token) + + expected_config = model_config.copy() + expected_config["use_auth_token"] = auth_token + + mock_from_pretrained.assert_called_once_with(**expected_config) + + mock_model.eval.assert_called_once() + + assert result == mock_model + + +def test_get_model_with_invalid_type(): + with pytest.raises(TypeError) as excinfo: + get_model(42) + + assert "Unsupported type" in str(excinfo.value) + assert "expected `str` or `dict`" in str(excinfo.value)