Skip to content

Conversation

@godnight10061
Copy link
Contributor

Fixes #16.

Problem

On NVIDIA systems torchruntime may pick the cu128 wheel index by default. CUDA 12.8 wheels only exist for:

  • torch/torchaudio >= 2.7.0
  • torchvision >= 0.22.0

If a user (or a dependency resolver) pins/caps any of these below those versions (e.g. torch==2.6.0), torchruntime still uses cu128 and pip fails with No matching distribution found.

Change

  • install() now inspects requested package specifiers. When the detected platform is cu128 and any requested torch/torchvision/torchaudio requirement has an upper bound below the first cu128-available release, it automatically demotes the platform to cu124.
  • Adds a regression test covering the torch==2.6.0 case.

Testing

  • python -m pytest

When a system would normally use cu128, but the requested torch/torchvision/torchaudio versions are capped below the first cu128 wheels, automatically use cu124 instead. Adds a regression test for issue easydiffusion#16.

gpu_infos = get_gpus()
torch_platform = get_torch_platform(gpu_infos)
torch_platform = _adjust_cuda_platform_for_requested_packages(torch_platform, packages)
Copy link
Contributor

@cmdr2 cmdr2 Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please move this logic into platform_detection.py, and add an optional packages=[] arg in get_torch_platform()?

That would preserve the separation of concerns, since we're effectively fixing a problem with platform detection, not installation commands.

The unit tests would change accordingly.

return name, spec


def _upper_bound_for_specifier(specifier):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to write a complete version parser ourselves? Isn't there any built-in library inside python that can do this?

This feels like a lot of lines of code just for this purpose.

Copy link
Contributor

@cmdr2 cmdr2 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking this up! I've added some comments. Looks fine overall - my comments are mainly around the location of the code, and for trying to reduce the complexity of this change (the complexity-to-value ratio of this PR feels a bit off).

@iwr-redmond
Copy link
Contributor

I reckon that the version parser is going to give Torchruntime a lot more horse sense. There are other scenarios where user-specified Torch versions may conflict with available packages, e.g. when DirectML is used.

@godnight10061 godnight10061 requested a review from cmdr2 December 24, 2025 04:02


def _adjust_cuda_platform_for_requested_packages(torch_platform, packages):
if torch_platform == _CUDA_12_8_PLATFORM and _packages_require_cuda_12_4(packages):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check can be moved directly inside the discrete function, just before we return cu128.

CUDA isn't integrated, so this check doesn't need to occupy anything more than it needs.

For e.g.:

if (arch_version > 3.7 and arch_version < 7.5) or py_version < (3, 9):
    return "cu124"

if _packages_require_cuda_12_4(packages):
    return "cu124"

return "cu128"

@cmdr2
Copy link
Contributor

cmdr2 commented Dec 24, 2025

@godnight10061 Would this work? Passes the tests as well.

import sys
import platform

from packaging.requirements import Requirement
from packaging.version import Version

os_name = platform.system()
arch = platform.machine().lower()
py_version = sys.version_info

_CUDA_12_8_MIN_VERSIONS = {
    "torch": Version("2.7.0"),
    "torchaudio": Version("2.7.0"),
    "torchvision": Version("0.22.0"),
}


def _packages_require_cuda_12_4(packages):
    """Check if any package requires CUDA 12.4 (cu124) instead of CUDA 12.8 (cu128).

    Returns True if any package version constraint excludes the minimum CUDA 12.8 version.
    """
    if not packages:
        return False

    for package in packages:
        try:
            requirement = Requirement(package)
        except Exception:
            continue

        name = requirement.name.lower().replace("_", "-")
        threshold = _CUDA_12_8_MIN_VERSIONS.get(name)

        # If no threshold for this package, skip it
        if not threshold:
            continue

        # If no version specifier (e.g., just "torch"), it doesn't require cu124
        if not requirement.specifier:
            continue

        # Check if the specifier allows ANY version >= threshold
        # Strategy: filter a set of versions >= threshold through the specifier
        # If any version >= threshold is allowed, then cu128 works
        test_versions = [
            str(threshold),  # The exact threshold (e.g., 2.7.0)
            f"{threshold.major}.{threshold.minor}.{threshold.micro + 1}",  # Patch version above (e.g., 2.7.1)
            f"{threshold.major}.{threshold.minor + 1}.0",  # Minor version above (e.g., 2.8.0)
            f"{threshold.major + 1}.0.0",  # Major version above (e.g., 3.0.0)
        ]

        allows_threshold_or_higher = any(requirement.specifier.contains(v, prereleases=True) for v in test_versions)

        if not allows_threshold_or_higher:
            return True

    return False


# Test methods
def test_no_packages():
    """Test with no packages."""
    assert _packages_require_cuda_12_4([]) == False
    assert _packages_require_cuda_12_4(None) == False


def test_unversioned_packages():
    """Test packages without version specifiers."""
    assert _packages_require_cuda_12_4(["torch"]) == False
    assert _packages_require_cuda_12_4(["torchaudio"]) == False
    assert _packages_require_cuda_12_4(["torchvision"]) == False


def test_exact_version_above_threshold():
    """Test exact versions at or above CUDA 12.8 minimum."""
    assert _packages_require_cuda_12_4(["torch==2.8.0"]) == False
    assert _packages_require_cuda_12_4(["torch==2.7.0"]) == False
    assert _packages_require_cuda_12_4(["torch==2.7.1"]) == False
    assert _packages_require_cuda_12_4(["torchaudio==2.7.0"]) == False
    assert _packages_require_cuda_12_4(["torchvision==0.22.0"]) == False
    assert _packages_require_cuda_12_4(["torchvision==0.23.0"]) == False


def test_exact_version_below_threshold():
    """Test exact versions below CUDA 12.8 minimum."""
    assert _packages_require_cuda_12_4(["torch==2.6.0"]) == True
    assert _packages_require_cuda_12_4(["torch==2.5.0"]) == True
    assert _packages_require_cuda_12_4(["torch==2.0.0"]) == True
    assert _packages_require_cuda_12_4(["torchaudio==2.6.0"]) == True
    assert _packages_require_cuda_12_4(["torchvision==0.21.0"]) == True
    assert _packages_require_cuda_12_4(["torchvision==0.20.0"]) == True


def test_compatible_release_operator():
    """Test ~= (compatible release) operator."""
    assert _packages_require_cuda_12_4(["torch~=2.5.0"]) == True  # allows 2.5.x only
    assert _packages_require_cuda_12_4(["torch~=2.6.0"]) == True  # allows 2.6.x only
    assert _packages_require_cuda_12_4(["torch~=2.7.0"]) == False  # allows 2.7.x
    assert _packages_require_cuda_12_4(["torch~=2.8.0"]) == False  # allows 2.8.x


def test_greater_than_operators():
    """Test > and >= operators."""
    assert _packages_require_cuda_12_4(["torch>2.6.0"]) == False  # allows 2.7.0+
    assert _packages_require_cuda_12_4(["torch>=2.7.0"]) == False
    assert _packages_require_cuda_12_4(["torch>=2.8.0"]) == False
    assert _packages_require_cuda_12_4(["torch>2.7.0"]) == False


def test_less_than_operators():
    """Test < and <= operators."""
    assert _packages_require_cuda_12_4(["torch<2.7.0"]) == True  # excludes 2.7.0
    assert _packages_require_cuda_12_4(["torch<=2.6.0"]) == True
    assert _packages_require_cuda_12_4(["torch<2.8.0"]) == False  # includes 2.7.0
    assert _packages_require_cuda_12_4(["torch<=2.7.0"]) == False  # includes 2.7.0


def test_range_specifiers():
    """Test version ranges."""
    assert _packages_require_cuda_12_4(["torch>=2.5.0,<2.7.0"]) == True  # excludes 2.7.0
    assert _packages_require_cuda_12_4(["torch>=2.5.0,<=2.6.0"]) == True
    assert _packages_require_cuda_12_4(["torch>=2.6.0,<2.8.0"]) == False  # includes 2.7.0
    assert _packages_require_cuda_12_4(["torch>=2.7.0,<3.0.0"]) == False
    assert _packages_require_cuda_12_4(["torch>=2.5.0,<3.0.0"]) == False  # includes 2.7.0


def test_multiple_packages():
    """Test with multiple packages in the list."""
    assert _packages_require_cuda_12_4(["torch==2.8.0", "numpy"]) == False
    assert _packages_require_cuda_12_4(["numpy", "torch==2.6.0"]) == True
    assert _packages_require_cuda_12_4(["torch==2.8.0", "torchaudio==2.6.0"]) == True  # one requires cu124
    assert _packages_require_cuda_12_4(["torch==2.7.0", "torchaudio==2.7.0"]) == False


def test_non_tracked_packages():
    """Test packages not in the CUDA 12.8 threshold dict."""
    assert _packages_require_cuda_12_4(["numpy==1.24.0"]) == False
    assert _packages_require_cuda_12_4(["pandas>=2.0.0"]) == False
    assert _packages_require_cuda_12_4(["scikit-learn"]) == False


def test_invalid_package_strings():
    """Test with invalid package specifications."""
    assert _packages_require_cuda_12_4(["invalid package spec!!!"]) == False
    assert _packages_require_cuda_12_4(["torch==2.6.0", "bad spec", "numpy"]) == True


def test_not_equal_operator():
    """Test != operator."""
    assert _packages_require_cuda_12_4(["torch!=2.7.0"]) == False  # allows 2.7.1, 2.8.0, etc.
    assert _packages_require_cuda_12_4(["torch!=2.6.0"]) == False  # allows 2.7.0


def test_wildcard_versions():
    """Test wildcard version specifiers."""
    assert _packages_require_cuda_12_4(["torch==2.6.*"]) == True  # only 2.6.x
    assert _packages_require_cuda_12_4(["torch==2.7.*"]) == False  # includes 2.7.0
    assert _packages_require_cuda_12_4(["torch==2.*"]) == False  # includes 2.7.0


if __name__ == "__main__":
    # Run all tests
    test_no_packages()
    test_unversioned_packages()
    test_exact_version_above_threshold()
    test_exact_version_below_threshold()
    test_compatible_release_operator()
    test_greater_than_operators()
    test_less_than_operators()
    test_range_specifiers()
    test_multiple_packages()
    test_non_tracked_packages()
    test_invalid_package_strings()
    test_not_equal_operator()
    test_wildcard_versions()

    print("All tests passed!")

@cmdr2 cmdr2 merged commit fce59bd into easydiffusion:main Dec 25, 2025
2 checks passed
@cmdr2
Copy link
Contributor

cmdr2 commented Dec 25, 2025

Thanks @godnight10061 ! Looks good! :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Difficulties with NVIDIA 50xx Support

3 participants