Fix cu128 index selection for pinned Torch versions #30
Conversation
When a system would normally use cu128, but the requested torch/torchvision/torchaudio versions are capped below the first cu128 wheels, automatically use cu124 instead. Adds a regression test for issue easydiffusion#16.
torchruntime/installer.py
Outdated

```python
gpu_infos = get_gpus()
torch_platform = get_torch_platform(gpu_infos)
torch_platform = _adjust_cuda_platform_for_requested_packages(torch_platform, packages)
```
Can you please move this logic into platform_detection.py, and add an optional packages=[] arg in get_torch_platform()?
That would preserve the separation of concerns, since we're effectively fixing a problem with platform detection, not installation commands.
The unit tests would change accordingly.
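A rough sketch of the suggested shape, with the demotion living inside platform detection rather than the installer (the helper names and bodies here are stand-ins for illustration, not the actual torchruntime internals):

```python
def _detect_platform(gpu_infos):
    # Stand-in for the existing detection logic in platform_detection.py
    return "cu128"


def _packages_require_cuda_12_4(packages):
    # Stand-in: a real implementation would inspect the version specifiers
    return any(p.startswith("torch==2.6") for p in packages)


def get_torch_platform(gpu_infos, packages=[]):
    """Platform detection owns the cu128 -> cu124 demotion.

    `packages` is only read, never mutated, so the mutable default is safe here.
    """
    torch_platform = _detect_platform(gpu_infos)
    if torch_platform == "cu128" and _packages_require_cuda_12_4(packages):
        return "cu124"
    return torch_platform
```

With this shape, `install()` only needs to forward its `packages` list to `get_torch_platform()`, and the installer stays unaware of CUDA wheel availability.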
torchruntime/installer.py
Outdated
```python
    return name, spec


def _upper_bound_for_specifier(specifier):
```
Do we really need to write a complete version parser ourselves? Isn't there a built-in library in Python that can do this?
This feels like a lot of lines of code for this purpose.
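One direction along these lines: the `packaging` library (third-party rather than standard library, but already installed alongside pip and setuptools in most environments) can parse and evaluate PEP 440 version specifiers without any hand-rolled parsing. A minimal sketch:

```python
from packaging.requirements import Requirement
from packaging.specifiers import SpecifierSet

# Parse a full requirement string into a name plus a specifier set
req = Requirement("torch>=2.5.0,<2.7.0")
print(req.name)                         # torch
print(req.specifier.contains("2.6.0"))  # True
print(req.specifier.contains("2.7.0"))  # False

# Or work with a bare specifier set directly
spec = SpecifierSet("~=2.5.0")  # compatible release: >=2.5.0, ==2.5.*
print("2.5.3" in spec)  # True
print("2.6.0" in spec)  # False
```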
cmdr2
left a comment
Thanks for taking this up! I've added some comments. Looks fine overall - my comments are mainly around the location of the code, and for trying to reduce the complexity of this change (the complexity-to-value ratio of this PR feels a bit off).
I reckon that the version parser is going to give Torchruntime a lot more horse sense. There are other scenarios where user-specified Torch versions may conflict with available packages, e.g. when DirectML is used.
torchruntime/platform_detection.py
Outdated
```python
def _adjust_cuda_platform_for_requested_packages(torch_platform, packages):
    if torch_platform == _CUDA_12_8_PLATFORM and _packages_require_cuda_12_4(packages):
```
This check can be moved directly inside the discrete-GPU function, just before we return cu128.
The check only applies to CUDA, so it doesn't need to occupy anything more than it needs.
For e.g.:

```python
if (arch_version > 3.7 and arch_version < 7.5) or py_version < (3, 9):
    return "cu124"
if _packages_require_cuda_12_4(packages):
    return "cu124"
return "cu128"
```
@godnight10061 Would this work? Passes the tests as well.

```python
import sys
import platform

from packaging.requirements import Requirement
from packaging.version import Version

os_name = platform.system()
arch = platform.machine().lower()
py_version = sys.version_info

_CUDA_12_8_MIN_VERSIONS = {
    "torch": Version("2.7.0"),
    "torchaudio": Version("2.7.0"),
    "torchvision": Version("0.22.0"),
}


def _packages_require_cuda_12_4(packages):
    """Check if any package requires CUDA 12.4 (cu124) instead of CUDA 12.8 (cu128).

    Returns True if any package version constraint excludes the minimum CUDA 12.8 version.
    """
    if not packages:
        return False

    for package in packages:
        try:
            requirement = Requirement(package)
        except Exception:
            continue

        name = requirement.name.lower().replace("_", "-")
        threshold = _CUDA_12_8_MIN_VERSIONS.get(name)

        # If no threshold for this package, skip it
        if not threshold:
            continue

        # If no version specifier (e.g., just "torch"), it doesn't require cu124
        if not requirement.specifier:
            continue

        # Check if the specifier allows ANY version >= threshold.
        # Strategy: filter a set of versions >= threshold through the specifier.
        # If any version >= threshold is allowed, then cu128 works.
        test_versions = [
            str(threshold),  # The exact threshold (e.g., 2.7.0)
            f"{threshold.major}.{threshold.minor}.{threshold.micro + 1}",  # Patch version above (e.g., 2.7.1)
            f"{threshold.major}.{threshold.minor + 1}.0",  # Minor version above (e.g., 2.8.0)
            f"{threshold.major + 1}.0.0",  # Major version above (e.g., 3.0.0)
        ]
        allows_threshold_or_higher = any(requirement.specifier.contains(v, prereleases=True) for v in test_versions)
        if not allows_threshold_or_higher:
            return True

    return False


# Test methods
def test_no_packages():
    """Test with no packages."""
    assert _packages_require_cuda_12_4([]) == False
    assert _packages_require_cuda_12_4(None) == False


def test_unversioned_packages():
    """Test packages without version specifiers."""
    assert _packages_require_cuda_12_4(["torch"]) == False
    assert _packages_require_cuda_12_4(["torchaudio"]) == False
    assert _packages_require_cuda_12_4(["torchvision"]) == False


def test_exact_version_above_threshold():
    """Test exact versions at or above CUDA 12.8 minimum."""
    assert _packages_require_cuda_12_4(["torch==2.8.0"]) == False
    assert _packages_require_cuda_12_4(["torch==2.7.0"]) == False
    assert _packages_require_cuda_12_4(["torch==2.7.1"]) == False
    assert _packages_require_cuda_12_4(["torchaudio==2.7.0"]) == False
    assert _packages_require_cuda_12_4(["torchvision==0.22.0"]) == False
    assert _packages_require_cuda_12_4(["torchvision==0.23.0"]) == False


def test_exact_version_below_threshold():
    """Test exact versions below CUDA 12.8 minimum."""
    assert _packages_require_cuda_12_4(["torch==2.6.0"]) == True
    assert _packages_require_cuda_12_4(["torch==2.5.0"]) == True
    assert _packages_require_cuda_12_4(["torch==2.0.0"]) == True
    assert _packages_require_cuda_12_4(["torchaudio==2.6.0"]) == True
    assert _packages_require_cuda_12_4(["torchvision==0.21.0"]) == True
    assert _packages_require_cuda_12_4(["torchvision==0.20.0"]) == True


def test_compatible_release_operator():
    """Test ~= (compatible release) operator."""
    assert _packages_require_cuda_12_4(["torch~=2.5.0"]) == True  # allows 2.5.x only
    assert _packages_require_cuda_12_4(["torch~=2.6.0"]) == True  # allows 2.6.x only
    assert _packages_require_cuda_12_4(["torch~=2.7.0"]) == False  # allows 2.7.x
    assert _packages_require_cuda_12_4(["torch~=2.8.0"]) == False  # allows 2.8.x


def test_greater_than_operators():
    """Test > and >= operators."""
    assert _packages_require_cuda_12_4(["torch>2.6.0"]) == False  # allows 2.7.0+
    assert _packages_require_cuda_12_4(["torch>=2.7.0"]) == False
    assert _packages_require_cuda_12_4(["torch>=2.8.0"]) == False
    assert _packages_require_cuda_12_4(["torch>2.7.0"]) == False


def test_less_than_operators():
    """Test < and <= operators."""
    assert _packages_require_cuda_12_4(["torch<2.7.0"]) == True  # excludes 2.7.0
    assert _packages_require_cuda_12_4(["torch<=2.6.0"]) == True
    assert _packages_require_cuda_12_4(["torch<2.8.0"]) == False  # includes 2.7.0
    assert _packages_require_cuda_12_4(["torch<=2.7.0"]) == False  # includes 2.7.0


def test_range_specifiers():
    """Test version ranges."""
    assert _packages_require_cuda_12_4(["torch>=2.5.0,<2.7.0"]) == True  # excludes 2.7.0
    assert _packages_require_cuda_12_4(["torch>=2.5.0,<=2.6.0"]) == True
    assert _packages_require_cuda_12_4(["torch>=2.6.0,<2.8.0"]) == False  # includes 2.7.0
    assert _packages_require_cuda_12_4(["torch>=2.7.0,<3.0.0"]) == False
    assert _packages_require_cuda_12_4(["torch>=2.5.0,<3.0.0"]) == False  # includes 2.7.0


def test_multiple_packages():
    """Test with multiple packages in the list."""
    assert _packages_require_cuda_12_4(["torch==2.8.0", "numpy"]) == False
    assert _packages_require_cuda_12_4(["numpy", "torch==2.6.0"]) == True
    assert _packages_require_cuda_12_4(["torch==2.8.0", "torchaudio==2.6.0"]) == True  # one requires cu124
    assert _packages_require_cuda_12_4(["torch==2.7.0", "torchaudio==2.7.0"]) == False


def test_non_tracked_packages():
    """Test packages not in the CUDA 12.8 threshold dict."""
    assert _packages_require_cuda_12_4(["numpy==1.24.0"]) == False
    assert _packages_require_cuda_12_4(["pandas>=2.0.0"]) == False
    assert _packages_require_cuda_12_4(["scikit-learn"]) == False


def test_invalid_package_strings():
    """Test with invalid package specifications."""
    assert _packages_require_cuda_12_4(["invalid package spec!!!"]) == False
    assert _packages_require_cuda_12_4(["torch==2.6.0", "bad spec", "numpy"]) == True


def test_not_equal_operator():
    """Test != operator."""
    assert _packages_require_cuda_12_4(["torch!=2.7.0"]) == False  # allows 2.7.1, 2.8.0, etc.
    assert _packages_require_cuda_12_4(["torch!=2.6.0"]) == False  # allows 2.7.0


def test_wildcard_versions():
    """Test wildcard version specifiers."""
    assert _packages_require_cuda_12_4(["torch==2.6.*"]) == True  # only 2.6.x
    assert _packages_require_cuda_12_4(["torch==2.7.*"]) == False  # includes 2.7.0
    assert _packages_require_cuda_12_4(["torch==2.*"]) == False  # includes 2.7.0


if __name__ == "__main__":
    # Run all tests
    test_no_packages()
    test_unversioned_packages()
    test_exact_version_above_threshold()
    test_exact_version_below_threshold()
    test_compatible_release_operator()
    test_greater_than_operators()
    test_less_than_operators()
    test_range_specifiers()
    test_multiple_packages()
    test_non_tracked_packages()
    test_invalid_package_strings()
    test_not_equal_operator()
    test_wildcard_versions()
    print("All tests passed!")
```
Thanks @godnight10061 ! Looks good! :)
Fixes #16.

Problem

On NVIDIA systems torchruntime may pick the `cu128` wheel index by default. CUDA 12.8 wheels only exist for:
- `torch` / `torchaudio` >= 2.7.0
- `torchvision` >= 0.22.0

If a user (or a dependency resolver) pins/caps any of these below those versions (e.g. `torch==2.6.0`), torchruntime still uses `cu128` and `pip` fails with `No matching distribution found`.

Change

`install()` now inspects requested package specifiers. When the detected platform is `cu128` and any requested `torch`/`torchvision`/`torchaudio` requirement has an upper bound below the first `cu128`-available release, it automatically demotes the platform to `cu124`. Adds a regression test for the `torch==2.6.0` case.

Testing

`python -m pytest`
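The demotion rule described in the Change section can be sketched compactly. The version thresholds below come from the PR description; the helper name `demote_platform` and the version-tuple representation are assumptions for this illustration, not the actual torchruntime code:

```python
# First releases that have cu128 wheels, per the PR description
FIRST_CU128 = {
    "torch": (2, 7, 0),
    "torchaudio": (2, 7, 0),
    "torchvision": (0, 22, 0),
}


def demote_platform(torch_platform, pinned_versions):
    """Demote cu128 to cu124 when a pin predates the first cu128 wheel.

    pinned_versions: package name -> pinned version tuple, e.g. {"torch": (2, 6, 0)}.
    """
    if torch_platform != "cu128":
        return torch_platform  # only cu128 is affected
    for name, version in pinned_versions.items():
        minimum = FIRST_CU128.get(name)
        if minimum is not None and version < minimum:
            return "cu124"  # no cu128 wheel exists for this pin
    return torch_platform
```

The real implementation additionally has to handle specifier ranges (e.g. `torch>=2.5.0,<2.7.0`) rather than only exact pins, which is what the specifier-based check in the PR provides.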