diff --git a/tests/test_torch_device_utils.py b/tests/test_torch_device_utils.py
new file mode 100644
index 0000000..4224280
--- /dev/null
+++ b/tests/test_torch_device_utils.py
@@ -0,0 +1,57 @@
+import sys
+import types
+
+
+def _make_fake_torch(*, cuda_available: bool):
+    torch = types.ModuleType("torch")
+    torch.__path__ = []  # allow importing torch.backends
+
+    class _FakeCUDA:
+        @staticmethod
+        def is_available():
+            return cuda_available
+
+    torch.cuda = _FakeCUDA()
+    torch.cpu = object()
+
+    backends = types.ModuleType("torch.backends")
+    backends.__path__ = []
+    torch.backends = backends
+
+    return torch, backends
+
+
+def test_get_installed_torch_platform_prefers_cuda_over_directml(monkeypatch):
+    # Regression test for Windows environments where users have both a working CUDA/ROCm torch
+    # backend AND torch-directml installed. In that scenario we should prefer torch.cuda over
+    # DirectML to avoid mis-detecting the active backend.
+    fake_torch, fake_backends = _make_fake_torch(cuda_available=True)
+
+    fake_torch_directml = types.ModuleType("torch_directml")
+    fake_torch_directml.is_available = lambda: True
+
+    monkeypatch.setitem(sys.modules, "torch", fake_torch)
+    monkeypatch.setitem(sys.modules, "torch.backends", fake_backends)
+    monkeypatch.setitem(sys.modules, "torch_directml", fake_torch_directml)
+
+    from torchruntime.utils.torch_device_utils import get_installed_torch_platform
+
+    torch_platform_name, torch_platform = get_installed_torch_platform()
+    assert torch_platform_name == "cuda"
+    assert torch_platform is fake_torch.cuda
+
+
+def test_get_installed_torch_platform_uses_directml_when_cuda_unavailable(monkeypatch):
+    fake_torch, fake_backends = _make_fake_torch(cuda_available=False)
+
+    fake_torch_directml = types.ModuleType("torch_directml")
+    fake_torch_directml.is_available = lambda: True
+
+    monkeypatch.setitem(sys.modules, "torch", fake_torch)
+    monkeypatch.setitem(sys.modules, "torch.backends", fake_backends)
+    monkeypatch.setitem(sys.modules, "torch_directml", fake_torch_directml)
+
+    from torchruntime.utils.torch_device_utils import get_installed_torch_platform
+
+    torch_platform_name, _ = get_installed_torch_platform()
+    assert torch_platform_name == "directml"
diff --git a/torchruntime/utils/torch_device_utils.py b/torchruntime/utils/torch_device_utils.py
index 25f8c25..64c14d9 100644
--- a/torchruntime/utils/torch_device_utils.py
+++ b/torchruntime/utils/torch_device_utils.py
@@ -42,13 +42,15 @@ def get_installed_torch_platform():
     import torch.backends
     from platform import system as os_name
 
-    if _is_directml_platform_available():
-        return DIRECTML, torch.directml
-
     if torch.cuda.is_available():
         return CUDA, torch.cuda
     if hasattr(torch, XPU) and torch.xpu.is_available():
         return XPU, torch.xpu
+
+    # DirectML is a useful fallback on Windows, but users can have torch-directml installed
+    # alongside a working CUDA/ROCm torch build. Prefer the native torch backend when available.
+    if _is_directml_platform_available():
+        return DIRECTML, torch.directml
     if os_name() == "Darwin":
         if hasattr(torch, MPS):
             return MPS, torch.mps
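
Note on `_is_directml_platform_available()`: the helper is called in the hunk above but its body lies outside this diff. A minimal sketch of what such a check could look like, assuming it only needs to confirm that `torch_directml` imports and reports a device, and that it exposes the module as `torch.directml` so the `return DIRECTML, torch.directml` line resolves (the real torchruntime helper may differ):

    def _is_directml_platform_available():
        # Hypothetical sketch -- the actual torchruntime helper is not shown in this diff.
        try:
            import torch_directml  # Windows-only package; ImportError elsewhere
        except ImportError:
            return False

        if not torch_directml.is_available():
            return False

        import torch

        # get_installed_torch_platform() returns torch.directml, so we assume the
        # helper (or an earlier import hook) attaches the module under that attribute.
        torch.directml = torch_directml
        return True

With the checks reordered, a Windows machine that has both a working CUDA torch build and torch-directml installed now resolves to ("cuda", torch.cuda) rather than ("directml", torch.directml), which is exactly what the first regression test asserts.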