From 5327187dc58f3f830a417d17b1e45b041f77b887 Mon Sep 17 00:00:00 2001 From: Godnight1006 Date: Tue, 23 Dec 2025 09:04:13 +0800 Subject: [PATCH 1/3] Fix CUDA platform selection for pinned torch versions When a system would normally use cu128, but the requested torch/torchvision/torchaudio versions are capped below the first cu128 wheels, automatically use cu124 instead. Adds a regression test for issue #16. --- tests/test_installer.py | 24 +++++++- torchruntime/installer.py | 126 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+), 1 deletion(-) diff --git a/tests/test_installer.py b/tests/test_installer.py index ade0358..1b2221a 100644 --- a/tests/test_installer.py +++ b/tests/test_installer.py @@ -1,7 +1,7 @@ import sys import pytest from unittest.mock import patch -from torchruntime.installer import get_install_commands, get_pip_commands, run_commands +from torchruntime.installer import get_install_commands, get_pip_commands, run_commands, install def test_empty_args(): @@ -125,3 +125,25 @@ def test_run_commands(): # Check that subprocess.run was called with the correct arguments mock_run.assert_any_call(cmds[0]) mock_run.assert_any_call(cmds[1]) + + +def test_install_demotes_cu128_to_cu124_for_torch_2_6(monkeypatch): + # Simulate a system where the detected platform would be cu128. + monkeypatch.setattr("torchruntime.installer.get_gpus", lambda: ["dummy_gpu"]) + monkeypatch.setattr("torchruntime.installer.get_torch_platform", lambda gpu_infos: "cu128") + + seen = {} + + def fake_get_install_commands(torch_platform, packages): + seen["torch_platform"] = torch_platform + seen["packages"] = packages + return [packages] + + monkeypatch.setattr("torchruntime.installer.get_install_commands", fake_get_install_commands) + monkeypatch.setattr("torchruntime.installer.get_pip_commands", lambda cmds, use_uv=False: cmds) + monkeypatch.setattr("torchruntime.installer.run_commands", lambda cmds: None) + + install(["torch==2.6.0"]) + + assert seen["packages"] == ["torch==2.6.0"] + assert seen["torch_platform"] == "cu124" diff --git a/torchruntime/installer.py b/torchruntime/installer.py index 475244f..dee0515 100644 --- a/torchruntime/installer.py +++ b/torchruntime/installer.py @@ -13,6 +13,131 @@ CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$") ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$") +_CUDA_12_8_PLATFORM = "cu128" +_CUDA_12_4_PLATFORM = "cu124" +_CUDA_12_8_MIN_VERSIONS = { + "torch": (2, 7, 0), + "torchaudio": (2, 7, 0), + "torchvision": (0, 22, 0), +} + + +def _parse_version_segments(text): + text = text.strip().split("+", 1)[0] + segments = [] + for part in text.split("."): + m = re.match(r"^(\d+)", part) + if not m: + break + segments.append(int(m.group(1))) + return segments + + +def _as_version_tuple(version_segments): + padded = list(version_segments[:3]) + while len(padded) < 3: + padded.append(0) + return tuple(padded) + + +def _version_lt(a, b): + return _as_version_tuple(a) < _as_version_tuple(b) + + +def _version_le(a, b): + return _as_version_tuple(a) <= _as_version_tuple(b) + + +def _get_requirement_name_and_specifier(requirement): + req = requirement.strip() + if not req or req.startswith("-") or "@" in req: + return None, None + + match = re.match(r"^([A-Za-z0-9][A-Za-z0-9_.-]*)(?:\[[^\]]+\])?", req) + if not match: + return None, None + + name = match.group(1).lower().replace("_", "-") + spec = req[match.end() :].split(";", 1)[0].strip() + return name, spec + + +def _upper_bound_for_specifier(specifier): + """ + Returns (upper_bound_segments, is_inclusive) for specifiers that impose an upper bound, + or (None, None) if there is no upper bound. + """ + + s = specifier.strip() + + if s.startswith("=="): + value = s[2:].strip() + if "*" in value: + prefix = value.split("*", 1)[0].rstrip(".") + prefix_segments = _parse_version_segments(prefix) + if not prefix_segments: + return None, None + upper = list(prefix_segments) + upper[-1] += 1 + upper.append(0) + return upper, False + + return _parse_version_segments(value), True + + if s.startswith("<="): + return _parse_version_segments(s[2:].strip()), True + + if s.startswith("<"): + return _parse_version_segments(s[1:].strip()), False + + if s.startswith("~="): + value_segments = _parse_version_segments(s[2:].strip()) + if len(value_segments) < 2: + return None, None + upper = list(value_segments[:-1]) + upper[-1] += 1 + upper.append(0) + return upper, False + + return None, None + + +def _packages_require_cuda_12_4(packages): + """ + True if the requested torch package versions cannot be satisfied by the CUDA 12.8 wheel index. + + This happens when a package is pinned (or capped) below the first version that has CUDA 12.8 wheels. + """ + + if not packages: + return False + + for package in packages: + name, spec = _get_requirement_name_and_specifier(package) + if not name or name not in _CUDA_12_8_MIN_VERSIONS or not spec: + continue + + threshold = _CUDA_12_8_MIN_VERSIONS[name] + for raw in spec.split(","): + upper, inclusive = _upper_bound_for_specifier(raw) + if not upper: + continue + + if inclusive: + if _version_lt(upper, threshold): + return True + else: + if _version_le(upper, threshold): + return True + + return False + + +def _adjust_cuda_platform_for_requested_packages(torch_platform, packages): + if torch_platform == _CUDA_12_8_PLATFORM and _packages_require_cuda_12_4(packages): + return _CUDA_12_4_PLATFORM + return torch_platform + def get_install_commands(torch_platform, packages): """ @@ -99,6 +224,7 @@ def install(packages=[], use_uv=False): gpu_infos = get_gpus() torch_platform = get_torch_platform(gpu_infos) + torch_platform = _adjust_cuda_platform_for_requested_packages(torch_platform, packages) cmds = get_install_commands(torch_platform, packages) cmds = get_pip_commands(cmds, use_uv=use_uv) run_commands(cmds) From 7f961adbbfef7d557956bdb9985aed18c09e92c9 Mon Sep 17 00:00:00 2001 From: Godnight1006 Date: Wed, 24 Dec 2025 11:47:18 +0800 Subject: [PATCH 2/3] Move CUDA package-based demotion to platform detection --- tests/test_installer.py | 24 +----- tests/test_platform_detection.py | 12 +++ torchruntime/installer.py | 128 +---------------------------- torchruntime/platform_detection.py | 103 ++++++++++++++++++++++- 4 files changed, 114 insertions(+), 153 deletions(-) diff --git a/tests/test_installer.py b/tests/test_installer.py index 1b2221a..ade0358 100644 --- a/tests/test_installer.py +++ b/tests/test_installer.py @@ -1,7 +1,7 @@ import sys import pytest from unittest.mock import patch -from torchruntime.installer import get_install_commands, get_pip_commands, run_commands, install +from torchruntime.installer import get_install_commands, get_pip_commands, run_commands def test_empty_args(): @@ -125,25 +125,3 @@ def test_run_commands(): # Check that subprocess.run was called with the correct arguments mock_run.assert_any_call(cmds[0]) mock_run.assert_any_call(cmds[1]) - - -def test_install_demotes_cu128_to_cu124_for_torch_2_6(monkeypatch): - # Simulate a system where the detected platform would be cu128. - monkeypatch.setattr("torchruntime.installer.get_gpus", lambda: ["dummy_gpu"]) - monkeypatch.setattr("torchruntime.installer.get_torch_platform", lambda gpu_infos: "cu128") - - seen = {} - - def fake_get_install_commands(torch_platform, packages): - seen["torch_platform"] = torch_platform - seen["packages"] = packages - return [packages] - - monkeypatch.setattr("torchruntime.installer.get_install_commands", fake_get_install_commands) - monkeypatch.setattr("torchruntime.installer.get_pip_commands", lambda cmds, use_uv=False: cmds) - monkeypatch.setattr("torchruntime.installer.run_commands", lambda cmds: None) - - install(["torch==2.6.0"]) - - assert seen["packages"] == ["torch==2.6.0"] - assert seen["torch_platform"] == "cu124" diff --git a/tests/test_platform_detection.py b/tests/test_platform_detection.py index d26c107..d3af40a 100644 --- a/tests/test_platform_detection.py +++ b/tests/test_platform_detection.py @@ -121,6 +121,18 @@ def test_nvidia_gpu_linux(monkeypatch): assert get_torch_platform(gpu_infos) == expected +def test_nvidia_gpu_demotes_to_cu124_for_pinned_torch_below_2_7(monkeypatch): + monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows") + monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64") + monkeypatch.setattr("torchruntime.platform_detection.py_version", (3, 11)) + monkeypatch.setattr("torchruntime.platform_detection.get_nvidia_arch", lambda device_names: 8.6) + + gpu_infos = [GPU(NVIDIA, "NVIDIA", 0x1234, "GeForce", True)] + + assert get_torch_platform(gpu_infos) == "cu128" + assert get_torch_platform(gpu_infos, packages=["torch==2.6.0"]) == "cu124" + + def test_nvidia_gpu_mac(monkeypatch): monkeypatch.setattr("torchruntime.platform_detection.os_name", "Darwin") monkeypatch.setattr("torchruntime.platform_detection.arch", "arm64") diff --git a/torchruntime/installer.py b/torchruntime/installer.py index dee0515..d7f17d3 100644 --- a/torchruntime/installer.py +++ b/torchruntime/installer.py @@ -13,131 +13,6 @@ CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$") ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$") -_CUDA_12_8_PLATFORM = "cu128" -_CUDA_12_4_PLATFORM = "cu124" -_CUDA_12_8_MIN_VERSIONS = { - "torch": (2, 7, 0), - "torchaudio": (2, 7, 0), - "torchvision": (0, 22, 0), -} - - -def _parse_version_segments(text): - text = text.strip().split("+", 1)[0] - segments = [] - for part in text.split("."): - m = re.match(r"^(\d+)", part) - if not m: - break - segments.append(int(m.group(1))) - return segments - - -def _as_version_tuple(version_segments): - padded = list(version_segments[:3]) - while len(padded) < 3: - padded.append(0) - return tuple(padded) - - -def _version_lt(a, b): - return _as_version_tuple(a) < _as_version_tuple(b) - - -def _version_le(a, b): - return _as_version_tuple(a) <= _as_version_tuple(b) - - -def _get_requirement_name_and_specifier(requirement): - req = requirement.strip() - if not req or req.startswith("-") or "@" in req: - return None, None - - match = re.match(r"^([A-Za-z0-9][A-Za-z0-9_.-]*)(?:\[[^\]]+\])?", req) - if not match: - return None, None - - name = match.group(1).lower().replace("_", "-") - spec = req[match.end() :].split(";", 1)[0].strip() - return name, spec - - -def _upper_bound_for_specifier(specifier): - """ - Returns (upper_bound_segments, is_inclusive) for specifiers that impose an upper bound, - or (None, None) if there is no upper bound. - """ - - s = specifier.strip() - - if s.startswith("=="): - value = s[2:].strip() - if "*" in value: - prefix = value.split("*", 1)[0].rstrip(".") - prefix_segments = _parse_version_segments(prefix) - if not prefix_segments: - return None, None - upper = list(prefix_segments) - upper[-1] += 1 - upper.append(0) - return upper, False - - return _parse_version_segments(value), True - - if s.startswith("<="): - return _parse_version_segments(s[2:].strip()), True - - if s.startswith("<"): - return _parse_version_segments(s[1:].strip()), False - - if s.startswith("~="): - value_segments = _parse_version_segments(s[2:].strip()) - if len(value_segments) < 2: - return None, None - upper = list(value_segments[:-1]) - upper[-1] += 1 - upper.append(0) - return upper, False - - return None, None - - -def _packages_require_cuda_12_4(packages): - """ - True if the requested torch package versions cannot be satisfied by the CUDA 12.8 wheel index. - - This happens when a package is pinned (or capped) below the first version that has CUDA 12.8 wheels. - """ - - if not packages: - return False - - for package in packages: - name, spec = _get_requirement_name_and_specifier(package) - if not name or name not in _CUDA_12_8_MIN_VERSIONS or not spec: - continue - - threshold = _CUDA_12_8_MIN_VERSIONS[name] - for raw in spec.split(","): - upper, inclusive = _upper_bound_for_specifier(raw) - if not upper: - continue - - if inclusive: - if _version_lt(upper, threshold): - return True - else: - if _version_le(upper, threshold): - return True - - return False - - -def _adjust_cuda_platform_for_requested_packages(torch_platform, packages): - if torch_platform == _CUDA_12_8_PLATFORM and _packages_require_cuda_12_4(packages): - return _CUDA_12_4_PLATFORM - return torch_platform - def get_install_commands(torch_platform, packages): """ @@ -223,8 +98,7 @@ def install(packages=[], use_uv=False): """ gpu_infos = get_gpus() - torch_platform = get_torch_platform(gpu_infos) - torch_platform = _adjust_cuda_platform_for_requested_packages(torch_platform, packages) + torch_platform = get_torch_platform(gpu_infos, packages=packages) cmds = get_install_commands(torch_platform, packages) cmds = get_pip_commands(cmds, use_uv=use_uv) run_commands(cmds) diff --git a/torchruntime/platform_detection.py b/torchruntime/platform_detection.py index 209e0e2..9f4ae7f 100644 --- a/torchruntime/platform_detection.py +++ b/torchruntime/platform_detection.py @@ -2,6 +2,9 @@ import sys import platform +from packaging.requirements import Requirement +from packaging.version import Version + from .gpu_db import get_nvidia_arch, get_amd_gfx_info from .consts import AMD, INTEL, NVIDIA, CONTACT_LINK @@ -9,13 +12,105 @@ arch = platform.machine().lower() py_version = sys.version_info +_CUDA_12_8_PLATFORM = "cu128" +_CUDA_12_4_PLATFORM = "cu124" +_CUDA_12_8_MIN_VERSIONS = { + "torch": Version("2.7.0"), + "torchaudio": Version("2.7.0"), + "torchvision": Version("0.22.0"), +} + + +def _parse_release_segments(text): + segments = [] + for part in text.split("."): + match = re.match(r"^(\d+)", part) + if not match: + break + segments.append(int(match.group(1))) + return segments + + +def _upper_bound_for_specifier(specifier): + operator = specifier.operator + version = specifier.version + + if operator == "<": + return Version(version), False + if operator == "<=": + return Version(version), True + if operator == "==": + if "*" in version: + prefix = version.split("*", 1)[0].rstrip(".") + prefix_segments = _parse_release_segments(prefix) + if not prefix_segments: + return None, None + prefix_segments[-1] += 1 + upper = Version(".".join(str(s) for s in prefix_segments)) + return upper, False + return Version(version), True + if operator == "~=": + release_segments = _parse_release_segments(version) + if len(release_segments) < 2: + return None, None + bump_index = len(release_segments) - 2 + upper_segments = release_segments[: bump_index + 1] + upper_segments[bump_index] += 1 + upper = Version(".".join(str(s) for s in upper_segments)) + return upper, False + + return None, None + + +def _packages_require_cuda_12_4(packages): + if not packages: + return False + + for package in packages: + try: + requirement = Requirement(package) + except Exception: + continue + + name = requirement.name.lower().replace("_", "-") + threshold = _CUDA_12_8_MIN_VERSIONS.get(name) + if not threshold or not requirement.specifier: + continue + + threshold_allowed = None + for specifier in requirement.specifier: + upper, inclusive = _upper_bound_for_specifier(specifier) + if not upper: + continue + + if upper < threshold: + return True + + if upper == threshold and not inclusive: + return True + + if upper == threshold and inclusive: + if threshold_allowed is None: + threshold_allowed = requirement.specifier.contains(threshold, prereleases=True) + if not threshold_allowed: + return True + + return False + + +def _adjust_cuda_platform_for_requested_packages(torch_platform, packages): + if torch_platform == _CUDA_12_8_PLATFORM and _packages_require_cuda_12_4(packages): + return _CUDA_12_4_PLATFORM + return torch_platform + -def get_torch_platform(gpu_infos): +def get_torch_platform(gpu_infos, packages=[]): """ Determine the appropriate PyTorch platform to use based on the system architecture, OS, and GPU information. Args: gpu_infos (list of `torchruntime.device_db.GPU` instances) + packages (list of str): Optional list of torch/torchvision/torchaudio requirement strings. Returns: str: A string representing the platform to use. Possible values: @@ -53,9 +148,11 @@ def get_torch_platform(gpu_infos): integrated_devices.append(device) if discrete_devices: - return _get_platform_for_discrete(discrete_devices) + torch_platform = _get_platform_for_discrete(discrete_devices) + return _adjust_cuda_platform_for_requested_packages(torch_platform, packages) - return _get_platform_for_integrated(integrated_devices) + torch_platform = _get_platform_for_integrated(integrated_devices) + return _adjust_cuda_platform_for_requested_packages(torch_platform, packages) def _get_platform_for_discrete(gpu_infos): From 9ffc57a0b790cd3004c6225e9f9bf64d8c238039 Mon Sep 17 00:00:00 2001 From: Godnight1006 Date: Wed, 24 Dec 2025 21:04:44 +0800 Subject: [PATCH 3/3] Inline cu128 demotion check --- tests/test_platform_detection.py | 5 ++ torchruntime/platform_detection.py | 92 +++++++----------------------- 2 files changed, 25 insertions(+), 72 deletions(-) diff --git a/tests/test_platform_detection.py b/tests/test_platform_detection.py index d3af40a..11de967 100644 --- a/tests/test_platform_detection.py +++ b/tests/test_platform_detection.py @@ -131,6 +131,11 @@ def test_nvidia_gpu_demotes_to_cu124_for_pinned_torch_below_2_7(monkeypatch): assert get_torch_platform(gpu_infos) == "cu128" assert get_torch_platform(gpu_infos, packages=["torch==2.6.0"]) == "cu124" + assert get_torch_platform(gpu_infos, packages=["torch<2.7.0"]) == "cu124" + assert get_torch_platform(gpu_infos, packages=["torch<=2.7.0"]) == "cu128" + assert get_torch_platform(gpu_infos, packages=["torch!=2.7.0"]) == "cu128" + assert get_torch_platform(gpu_infos, packages=["torch>=2.7.0,!=2.7.0,!=2.7.1,<2.8.0"]) == "cu128" + assert get_torch_platform(gpu_infos, packages=["torchvision==0.21.0"]) == "cu124" def test_nvidia_gpu_mac(monkeypatch): diff --git a/torchruntime/platform_detection.py b/torchruntime/platform_detection.py index 9f4ae7f..f5b785f 100644 --- a/torchruntime/platform_detection.py +++ b/torchruntime/platform_detection.py @@ -1,4 +1,3 @@ -import re import sys import platform @@ -12,8 +11,6 @@ arch = platform.machine().lower() py_version = sys.version_info -_CUDA_12_8_PLATFORM = "cu128" -_CUDA_12_4_PLATFORM = "cu124" _CUDA_12_8_MIN_VERSIONS = { "torch": Version("2.7.0"), "torchaudio": Version("2.7.0"), @@ -21,47 +18,6 @@ } -def _parse_release_segments(text): - segments = [] - for part in text.split("."): - match = re.match(r"^(\d+)", part) - if not match: - break - segments.append(int(match.group(1))) - return segments - - -def _upper_bound_for_specifier(specifier): - operator = specifier.operator - version = specifier.version - - if operator == "<": - return Version(version), False - if operator == "<=": - return Version(version), True - if operator == "==": - if "*" in version: - prefix = version.split("*", 1)[0].rstrip(".") - prefix_segments = _parse_release_segments(prefix) - if not prefix_segments: - return None, None - prefix_segments[-1] += 1 - upper = Version(".".join(str(s) for s in prefix_segments)) - return upper, False - return Version(version), True - if operator == "~=": - release_segments = _parse_release_segments(version) - if len(release_segments) < 2: - return None, None - bump_index = len(release_segments) - 2 - upper_segments = release_segments[: bump_index + 1] - upper_segments[bump_index] += 1 - upper = Version(".".join(str(s) for s in upper_segments)) - return upper, False - - return None, None - - def _packages_require_cuda_12_4(packages): if not packages: return False @@ -77,33 +33,24 @@ def _packages_require_cuda_12_4(packages): if not threshold or not requirement.specifier: continue - threshold_allowed = None - for specifier in requirement.specifier: - upper, inclusive = _upper_bound_for_specifier(specifier) - if not upper: - continue - - if upper < threshold: - return True - - if upper == threshold and not inclusive: - return True - - if upper == threshold and inclusive: - if threshold_allowed is None: - threshold_allowed = requirement.specifier.contains(threshold, prereleases=True) - if not threshold_allowed: - return True + test_versions = [ + threshold, + Version(f"{threshold.major}.{threshold.minor}.{threshold.micro + 1}"), + Version(f"{threshold.major}.{threshold.minor}.{threshold.micro + 2}"), + Version(f"{threshold.major}.{threshold.minor}.{threshold.micro + 3}"), + Version(f"{threshold.major}.{threshold.minor + 1}.0"), + Version(f"{threshold.major + 1}.0.0"), + ] + + allows_threshold_or_higher = any( + requirement.specifier.contains(str(version), prereleases=True) for version in test_versions + ) + if not allows_threshold_or_higher: + return True return False -def _adjust_cuda_platform_for_requested_packages(torch_platform, packages): - if torch_platform == _CUDA_12_8_PLATFORM and _packages_require_cuda_12_4(packages): - return _CUDA_12_4_PLATFORM - return torch_platform - - def get_torch_platform(gpu_infos, packages=[]): """ Determine the appropriate PyTorch platform to use based on the system architecture, OS, and GPU information. @@ -148,14 +95,12 @@ def get_torch_platform(gpu_infos, packages=[]): integrated_devices.append(device) if discrete_devices: - torch_platform = _get_platform_for_discrete(discrete_devices) - return _adjust_cuda_platform_for_requested_packages(torch_platform, packages) + return _get_platform_for_discrete(discrete_devices, packages=packages) - torch_platform = _get_platform_for_integrated(integrated_devices) - return _adjust_cuda_platform_for_requested_packages(torch_platform, packages) + return _get_platform_for_integrated(integrated_devices) -def _get_platform_for_discrete(gpu_infos): +def _get_platform_for_discrete(gpu_infos, packages=None): vendor_ids = set(gpu.vendor_id for gpu in gpu_infos) if len(vendor_ids) > 1: @@ -223,6 +168,9 @@ def _get_platform_for_discrete(gpu_infos): if (arch_version > 3.7 and arch_version < 7.5) or py_version < (3, 9): return "cu124" + if _packages_require_cuda_12_4(packages): + return "cu124" + return "cu128" elif os_name == "Darwin": raise NotImplementedError(