diff --git a/tests/test_installer.py b/tests/test_installer.py
index ade0358..578f854 100644
--- a/tests/test_installer.py
+++ b/tests/test_installer.py
@@ -1,7 +1,7 @@
 import sys
 import pytest
 from unittest.mock import patch
-from torchruntime.installer import get_install_commands, get_pip_commands, run_commands
+from torchruntime.installer import get_install_commands, get_pip_commands, run_commands, install
 
 
 def test_empty_args():
@@ -125,3 +125,57 @@ def test_run_commands():
         # Check that subprocess.run was called with the correct arguments
         mock_run.assert_any_call(cmds[0])
         mock_run.assert_any_call(cmds[1])
+
+
+def test_install_promotes_cuda_platform_for_torch_27(monkeypatch):
+    captured = {}
+
+    def fake_get_install_commands(torch_platform, packages):
+        captured["platform"] = torch_platform
+        return [packages]
+
+    monkeypatch.setattr("torchruntime.installer.get_gpus", lambda: [])
+    monkeypatch.setattr("torchruntime.installer.get_torch_platform", lambda gpu_infos: "cu124")
+    monkeypatch.setattr("torchruntime.installer.get_install_commands", fake_get_install_commands)
+    monkeypatch.setattr("torchruntime.installer.get_pip_commands", lambda cmds, use_uv=False: cmds)
+    monkeypatch.setattr("torchruntime.installer.run_commands", lambda cmds: None)
+
+    install(["torch==2.7.1"])
+
+    assert captured["platform"] == "cu128"
+
+
+def test_install_demotes_cuda_platform_for_torch_26(monkeypatch):
+    captured = {}
+
+    def fake_get_install_commands(torch_platform, packages):
+        captured["platform"] = torch_platform
+        return [packages]
+
+    monkeypatch.setattr("torchruntime.installer.get_gpus", lambda: [])
+    monkeypatch.setattr("torchruntime.installer.get_torch_platform", lambda gpu_infos: "cu128")
+    monkeypatch.setattr("torchruntime.installer.get_install_commands", fake_get_install_commands)
+    monkeypatch.setattr("torchruntime.installer.get_pip_commands", lambda cmds, use_uv=False: cmds)
+    monkeypatch.setattr("torchruntime.installer.run_commands", lambda cmds: None)
+
+    install(["torch==2.6.0"])
+
+    assert captured["platform"] == "cu124"
+
+
+def test_install_promotes_cuda_platform_for_torchvision_022(monkeypatch):
+    captured = {}
+
+    def fake_get_install_commands(torch_platform, packages):
+        captured["platform"] = torch_platform
+        return [packages]
+
+    monkeypatch.setattr("torchruntime.installer.get_gpus", lambda: [])
+    monkeypatch.setattr("torchruntime.installer.get_torch_platform", lambda gpu_infos: "cu124")
+    monkeypatch.setattr("torchruntime.installer.get_install_commands", fake_get_install_commands)
+    monkeypatch.setattr("torchruntime.installer.get_pip_commands", lambda cmds, use_uv=False: cmds)
+    monkeypatch.setattr("torchruntime.installer.run_commands", lambda cmds: None)
+
+    install(["torchvision==0.22.0"])
+
+    assert captured["platform"] == "cu128"
diff --git a/torchruntime/installer.py b/torchruntime/installer.py
index 475244f..0b2ac4e 100644
--- a/torchruntime/installer.py
+++ b/torchruntime/installer.py
@@ -5,6 +5,7 @@
 
 from .consts import CONTACT_LINK
 from .device_db import get_gpus
+from .gpu_db import get_nvidia_arch
 from .platform_detection import get_torch_platform
 
 os_name = platform.system()
@@ -12,6 +13,11 @@
 PIP_PREFIX = [sys.executable, "-m", "pip", "install"]
 CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$")
 ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$")
+REQ_SPEC_REGEX = re.compile(
+    r"^\s*(?P<name>[A-Za-z0-9_.-]+)(?:\[[^\]]+\])?\s*(?P<op>==|>=|<=|~=|!=|<|>)\s*(?P<version>[^,;\s]+)"
+)
+MAJOR_MINOR_REGEX = re.compile(r"^(?P<major>\d+)\.(?P<minor>\d+)")
+TORCH_2_7 = (2, 7)
 
 
 def get_install_commands(torch_platform, packages):
@@ -91,6 +97,127 @@ def run_commands(cmds):
         subprocess.run(cmd)
 
 
+def _parse_major_minor(version: str):
+    match = MAJOR_MINOR_REGEX.match(version)
+    if not match:
+        return None
+    return int(match.group("major")), int(match.group("minor"))
+
+
+def _is_major_minor_gte(left, right):
+    return left[0] > right[0] or (left[0] == right[0] and left[1] >= right[1])
+
+
+def _is_major_minor_lt(left, right):
+    return left[0] < right[0] or (left[0] == right[0] and left[1] < right[1])
+
+
+def _cuda_platform_has_prefix(torch_platform: str):
+    return torch_platform.startswith("nightly/")
+
+
+def _cuda_platform_with_prefix(torch_platform: str, cuda_platform: str):
+    if _cuda_platform_has_prefix(torch_platform):
+        return f"nightly/{cuda_platform}"
+    return cuda_platform
+
+
+def _get_cuda_platform_for_pytorch_packages(packages):
+    """
+    Infer a CUDA platform (cu124 vs cu128) from user-specified PyTorch package versions.
+
+    This is needed because PyTorch 2.7.x is published under cu128 wheels, and older
+    releases (<=2.6) are published under cu124 wheels. When the requested versions
+    are pinned, the installer must select the matching index URL or pip will fail
+    with "No matching distribution found".
+
+    Returns:
+        "cu124" | "cu128" | None
+    """
+
+    if not packages:
+        return None
+
+    desired_cuda = None
+
+    for raw_req in packages:
+        if not raw_req:
+            continue
+
+        req = str(raw_req).strip()
+        if not req or req.startswith("-"):
+            continue
+
+        match = REQ_SPEC_REGEX.match(req)
+        if not match:
+            continue
+
+        name = match.group("name").lower()
+        op = match.group("op")
+        version = match.group("version")
+
+        major_minor = _parse_major_minor(version)
+        if not major_minor:
+            continue
+
+        # Map torchvision's versioning scheme to the matching torch major/minor.
+        if name == "torchvision":
+            tv_major, tv_minor = major_minor
+            if tv_major != 0:
+                continue
+            torch_major_minor = (2, max(0, tv_minor - 15))
+        elif name in ("torch", "torchaudio"):
+            torch_major_minor = major_minor
+        else:
+            continue
+
+        required_cuda = None
+        if op == "==":
+            required_cuda = "cu128" if _is_major_minor_gte(torch_major_minor, TORCH_2_7) else "cu124"
+        elif op in (">=", ">", "~="):
+            if _is_major_minor_gte(torch_major_minor, TORCH_2_7):
+                required_cuda = "cu128"
+        elif op in ("<", "<="):
+            if _is_major_minor_lt(torch_major_minor, TORCH_2_7):
+                required_cuda = "cu124"
+
+        if required_cuda is None:
+            continue
+
+        if desired_cuda is None:
+            desired_cuda = required_cuda
+        elif desired_cuda != required_cuda:
+            # Conflicting version pins, leave platform unchanged and let pip resolve/fail.
+            return None
+
+    return desired_cuda
+
+
+def _maybe_override_nvidia_cuda_platform(torch_platform, packages, gpu_infos):
+    """
+    Adjust cu124/cu128 index selection based on pinned torch/torchvision/torchaudio versions.
+    """
+    if not torch_platform or not CUDA_REGEX.match(torch_platform):
+        return torch_platform
+
+    desired_cuda = _get_cuda_platform_for_pytorch_packages(packages)
+    if desired_cuda not in ("cu124", "cu128"):
+        return torch_platform
+
+    current_cuda = torch_platform.split("/", 1)[-1]
+    if current_cuda == desired_cuda:
+        return torch_platform
+
+    # Do not demote Blackwell GPUs from cu128 -> cu124; older torch versions won't support them anyway.
+    if current_cuda == "cu128" and desired_cuda == "cu124":
+        device_names = set(gpu.device_name for gpu in (gpu_infos or []))
+        arch_version = get_nvidia_arch(device_names) if device_names else 0
+        if arch_version == 12:
+            return torch_platform
+
+    return _cuda_platform_with_prefix(torch_platform, desired_cuda)
+
+
 def install(packages=[], use_uv=False):
     """
     packages: a list of strings with package names (and optionally their versions in pip-format). e.g. ["torch", "torchvision"] or ["torch>=2.0", "torchaudio==0.16.0"]. Defaults to ["torch", "torchvision", "torchaudio"].
@@ -99,6 +226,7 @@ def install(packages=[], use_uv=False):
 
     gpu_infos = get_gpus()
     torch_platform = get_torch_platform(gpu_infos)
+    torch_platform = _maybe_override_nvidia_cuda_platform(torch_platform, packages, gpu_infos)
     cmds = get_install_commands(torch_platform, packages)
     cmds = get_pip_commands(cmds, use_uv=use_uv)
     run_commands(cmds)