Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tests/test_platform_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,23 @@ def test_nvidia_gpu_linux(monkeypatch):
assert get_torch_platform(gpu_infos) == expected


def test_nvidia_gpu_demotes_to_cu124_for_pinned_torch_below_2_7(monkeypatch):
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
monkeypatch.setattr("torchruntime.platform_detection.py_version", (3, 11))
monkeypatch.setattr("torchruntime.platform_detection.get_nvidia_arch", lambda device_names: 8.6)

gpu_infos = [GPU(NVIDIA, "NVIDIA", 0x1234, "GeForce", True)]

assert get_torch_platform(gpu_infos) == "cu128"
assert get_torch_platform(gpu_infos, packages=["torch==2.6.0"]) == "cu124"
assert get_torch_platform(gpu_infos, packages=["torch<2.7.0"]) == "cu124"
assert get_torch_platform(gpu_infos, packages=["torch<=2.7.0"]) == "cu128"
assert get_torch_platform(gpu_infos, packages=["torch!=2.7.0"]) == "cu128"
assert get_torch_platform(gpu_infos, packages=["torch>=2.7.0,!=2.7.0,!=2.7.1,<2.8.0"]) == "cu128"
assert get_torch_platform(gpu_infos, packages=["torchvision==0.21.0"]) == "cu124"


def test_nvidia_gpu_mac(monkeypatch):
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Darwin")
monkeypatch.setattr("torchruntime.platform_detection.arch", "arm64")
Expand Down
2 changes: 1 addition & 1 deletion torchruntime/installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def install(packages=[], use_uv=False):
"""

gpu_infos = get_gpus()
torch_platform = get_torch_platform(gpu_infos)
torch_platform = get_torch_platform(gpu_infos, packages=packages)
cmds = get_install_commands(torch_platform, packages)
cmds = get_pip_commands(cmds, use_uv=use_uv)
run_commands(cmds)
53 changes: 49 additions & 4 deletions torchruntime/platform_detection.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,63 @@
import re
import sys
import platform

from packaging.requirements import Requirement
from packaging.version import Version

from .gpu_db import get_nvidia_arch, get_amd_gfx_info
from .consts import AMD, INTEL, NVIDIA, CONTACT_LINK

os_name = platform.system()
arch = platform.machine().lower()
py_version = sys.version_info

_CUDA_12_8_MIN_VERSIONS = {
"torch": Version("2.7.0"),
"torchaudio": Version("2.7.0"),
"torchvision": Version("0.22.0"),
}


def _packages_require_cuda_12_4(packages):
if not packages:
return False

for package in packages:
try:
requirement = Requirement(package)
except Exception:
continue

name = requirement.name.lower().replace("_", "-")
threshold = _CUDA_12_8_MIN_VERSIONS.get(name)
if not threshold or not requirement.specifier:
continue

test_versions = [
threshold,
Version(f"{threshold.major}.{threshold.minor}.{threshold.micro + 1}"),
Version(f"{threshold.major}.{threshold.minor}.{threshold.micro + 2}"),
Version(f"{threshold.major}.{threshold.minor}.{threshold.micro + 3}"),
Version(f"{threshold.major}.{threshold.minor + 1}.0"),
Version(f"{threshold.major + 1}.0.0"),
]

allows_threshold_or_higher = any(
requirement.specifier.contains(str(version), prereleases=True) for version in test_versions
)
if not allows_threshold_or_higher:
return True

return False

def get_torch_platform(gpu_infos):

def get_torch_platform(gpu_infos, packages=[]):
"""
Determine the appropriate PyTorch platform to use based on the system architecture, OS, and GPU information.

Args:
gpu_infos (list of `torchruntime.device_db.GPU` instances)
packages (list of str): Optional list of torch/torchvision/torchaudio requirement strings.

Returns:
str: A string representing the platform to use. Possible values:
Expand Down Expand Up @@ -53,12 +95,12 @@ def get_torch_platform(gpu_infos):
integrated_devices.append(device)

if discrete_devices:
return _get_platform_for_discrete(discrete_devices)
return _get_platform_for_discrete(discrete_devices, packages=packages)

return _get_platform_for_integrated(integrated_devices)


def _get_platform_for_discrete(gpu_infos):
def _get_platform_for_discrete(gpu_infos, packages=None):
vendor_ids = set(gpu.vendor_id for gpu in gpu_infos)

if len(vendor_ids) > 1:
Expand Down Expand Up @@ -126,6 +168,9 @@ def _get_platform_for_discrete(gpu_infos):
if (arch_version > 3.7 and arch_version < 7.5) or py_version < (3, 9):
return "cu124"

if _packages_require_cuda_12_4(packages):
return "cu124"

return "cu128"
elif os_name == "Darwin":
raise NotImplementedError(
Expand Down