diff --git a/tests/test_platform_detection.py b/tests/test_platform_detection.py index d26c107..11de967 100644 --- a/tests/test_platform_detection.py +++ b/tests/test_platform_detection.py @@ -121,6 +121,23 @@ def test_nvidia_gpu_linux(monkeypatch): assert get_torch_platform(gpu_infos) == expected +def test_nvidia_gpu_demotes_to_cu124_for_pinned_torch_below_2_7(monkeypatch): + monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows") + monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64") + monkeypatch.setattr("torchruntime.platform_detection.py_version", (3, 11)) + monkeypatch.setattr("torchruntime.platform_detection.get_nvidia_arch", lambda device_names: 8.6) + + gpu_infos = [GPU(NVIDIA, "NVIDIA", 0x1234, "GeForce", True)] + + assert get_torch_platform(gpu_infos) == "cu128" + assert get_torch_platform(gpu_infos, packages=["torch==2.6.0"]) == "cu124" + assert get_torch_platform(gpu_infos, packages=["torch<2.7.0"]) == "cu124" + assert get_torch_platform(gpu_infos, packages=["torch<=2.7.0"]) == "cu128" + assert get_torch_platform(gpu_infos, packages=["torch!=2.7.0"]) == "cu128" + assert get_torch_platform(gpu_infos, packages=["torch>=2.7.0,!=2.7.0,!=2.7.1,<2.8.0"]) == "cu128" + assert get_torch_platform(gpu_infos, packages=["torchvision==0.21.0"]) == "cu124" + + def test_nvidia_gpu_mac(monkeypatch): monkeypatch.setattr("torchruntime.platform_detection.os_name", "Darwin") monkeypatch.setattr("torchruntime.platform_detection.arch", "arm64") diff --git a/torchruntime/installer.py b/torchruntime/installer.py index 475244f..d7f17d3 100644 --- a/torchruntime/installer.py +++ b/torchruntime/installer.py @@ -98,7 +98,7 @@ def install(packages=[], use_uv=False): """ gpu_infos = get_gpus() - torch_platform = get_torch_platform(gpu_infos) + torch_platform = get_torch_platform(gpu_infos, packages=packages) cmds = get_install_commands(torch_platform, packages) cmds = get_pip_commands(cmds, use_uv=use_uv) run_commands(cmds) diff --git a/torchruntime/platform_detection.py b/torchruntime/platform_detection.py index 209e0e2..f5b785f 100644 --- a/torchruntime/platform_detection.py +++ b/torchruntime/platform_detection.py @@ -1,7 +1,9 @@ -import re import sys import platform +from packaging.requirements import Requirement +from packaging.version import Version + from .gpu_db import get_nvidia_arch, get_amd_gfx_info from .consts import AMD, INTEL, NVIDIA, CONTACT_LINK @@ -9,13 +11,53 @@ arch = platform.machine().lower() py_version = sys.version_info +_CUDA_12_8_MIN_VERSIONS = { + "torch": Version("2.7.0"), + "torchaudio": Version("2.7.0"), + "torchvision": Version("0.22.0"), +} + + +def _packages_require_cuda_12_4(packages): + if not packages: + return False + + for package in packages: + try: + requirement = Requirement(package) + except Exception: + continue + + name = requirement.name.lower().replace("_", "-") + threshold = _CUDA_12_8_MIN_VERSIONS.get(name) + if not threshold or not requirement.specifier: + continue + + test_versions = [ + threshold, + Version(f"{threshold.major}.{threshold.minor}.{threshold.micro + 1}"), + Version(f"{threshold.major}.{threshold.minor}.{threshold.micro + 2}"), + Version(f"{threshold.major}.{threshold.minor}.{threshold.micro + 3}"), + Version(f"{threshold.major}.{threshold.minor + 1}.0"), + Version(f"{threshold.major + 1}.0.0"), + ] + + allows_threshold_or_higher = any( + requirement.specifier.contains(str(version), prereleases=True) for version in test_versions + ) + if not allows_threshold_or_higher: + return True + + return False -def get_torch_platform(gpu_infos): + +def get_torch_platform(gpu_infos, packages=[]): """ Determine the appropriate PyTorch platform to use based on the system architecture, OS, and GPU information. Args: gpu_infos (list of `torchruntime.device_db.GPU` instances) + packages (list of str): Optional list of torch/torchvision/torchaudio requirement strings. Returns: str: A string representing the platform to use. Possible values: @@ -53,12 +95,12 @@ def get_torch_platform(gpu_infos): integrated_devices.append(device) if discrete_devices: - return _get_platform_for_discrete(discrete_devices) + return _get_platform_for_discrete(discrete_devices, packages=packages) return _get_platform_for_integrated(integrated_devices) -def _get_platform_for_discrete(gpu_infos): +def _get_platform_for_discrete(gpu_infos, packages=None): vendor_ids = set(gpu.vendor_id for gpu in gpu_infos) if len(vendor_ids) > 1: @@ -126,6 +168,9 @@ def _get_platform_for_discrete(gpu_infos): if (arch_version > 3.7 and arch_version < 7.5) or py_version < (3, 9): return "cu124" + if _packages_require_cuda_12_4(packages): + return "cu124" + return "cu128" elif os_name == "Darwin": raise NotImplementedError(