From 9816abac738165ab95ff329a92e6f1ea2c2fb250 Mon Sep 17 00:00:00 2001 From: Godnight1006 Date: Fri, 26 Dec 2025 14:09:06 +0800 Subject: [PATCH 1/3] feat: install Triton on more platforms --- API.md | 2 ++ README.md | 2 ++ tests/test_installer.py | 40 ++++++++++++++++++++++++++++++++++++--- torchruntime/installer.py | 21 ++++++++++++++++++-- 4 files changed, 60 insertions(+), 5 deletions(-) diff --git a/API.md b/API.md index 7180da5..a26fb52 100644 --- a/API.md +++ b/API.md @@ -13,6 +13,8 @@ Or you can use the library: torchruntime.install(["torch", "torchvision<0.20"]) ``` +On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`). + ## Get device info You can use the device database built into `torchruntime` for your projects: ```py diff --git a/README.md b/README.md index c5e6c5f..6108c95 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,8 @@ Supports Windows, Linux, and Mac. This will install `torch`, `torchvision`, and `torchaudio`, and will decide the variant based on the user's OS, GPU manufacturer and GPU model number. See [customizing packages](#customizing-packages) for more options. +On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`). + **Tip:** You can also add the `--uv` flag to install packages using [uv](https://docs.astral.sh/uv/) (instead of `pip`). For e.g. `python -m torchruntime install --uv` ### Step 2. 
Configure torch diff --git a/tests/test_installer.py b/tests/test_installer.py index ade0358..ab8856f 100644 --- a/tests/test_installer.py +++ b/tests/test_installer.py @@ -16,20 +16,38 @@ def test_cpu_platform(): assert result == [packages] -def test_cuda_platform(): +def test_cuda_platform(monkeypatch): + monkeypatch.setattr("torchruntime.installer.os_name", "Linux") packages = ["torch", "torchvision"] result = get_install_commands("cu112", packages) expected_url = "https://download.pytorch.org/whl/cu112" assert result == [packages + ["--index-url", expected_url]] -def test_cuda_nightly_platform(): +def test_cuda_platform_windows_installs_triton(monkeypatch): + monkeypatch.setattr("torchruntime.installer.os_name", "Windows") + packages = ["torch", "torchvision"] + result = get_install_commands("cu112", packages) + expected_url = "https://download.pytorch.org/whl/cu112" + assert result == [packages + ["--index-url", expected_url], ["triton-windows"]] + + +def test_cuda_nightly_platform(monkeypatch): + monkeypatch.setattr("torchruntime.installer.os_name", "Linux") packages = ["torch", "torchvision"] result = get_install_commands("nightly/cu112", packages) expected_url = "https://download.pytorch.org/whl/nightly/cu112" assert result == [packages + ["--index-url", expected_url]] +def test_cuda_nightly_platform_windows_installs_triton(monkeypatch): + monkeypatch.setattr("torchruntime.installer.os_name", "Windows") + packages = ["torch", "torchvision"] + result = get_install_commands("nightly/cu112", packages) + expected_url = "https://download.pytorch.org/whl/nightly/cu112" + assert result == [packages + ["--index-url", expected_url], ["triton-windows"]] + + def test_rocm_platform(): packages = ["torch", "torchvision"] result = get_install_commands("rocm4.2", packages) @@ -37,6 +55,18 @@ def test_rocm_platform(): assert result == [packages + ["--index-url", expected_url]] +def test_rocm_platform_linux_installs_triton(monkeypatch): + 
monkeypatch.setattr("torchruntime.installer.os_name", "Linux") + packages = ["torch", "torchvision"] + result = get_install_commands("rocm6.2", packages) + expected_url = "https://download.pytorch.org/whl/rocm6.2" + triton_index_url = "https://download.pytorch.org/whl" + assert result == [ + packages + ["--index-url", expected_url], + ["pytorch-triton-rocm", "--index-url", triton_index_url], + ] + + def test_xpu_platform_windows_with_torch_only(monkeypatch): monkeypatch.setattr("torchruntime.installer.os_name", "Windows") packages = ["torch"] @@ -60,7 +90,11 @@ def test_xpu_platform_linux(monkeypatch): packages = ["torch", "torchvision"] result = get_install_commands("xpu", packages) expected_url = "https://download.pytorch.org/whl/test/xpu" - assert result == [packages + ["--index-url", expected_url]] + triton_index_url = "https://download.pytorch.org/whl" + assert result == [ + packages + ["--index-url", expected_url], + ["pytorch-triton-xpu", "--index-url", triton_index_url], + ] def test_directml_platform(): diff --git a/torchruntime/installer.py b/torchruntime/installer.py index d7f17d3..1153188 100644 --- a/torchruntime/installer.py +++ b/torchruntime/installer.py @@ -12,6 +12,7 @@ PIP_PREFIX = [sys.executable, "-m", "pip", "install"] CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$") ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$") +ROCM_VERSION_REGEX = re.compile(r"^(?:nightly/)?rocm(?P<major>\d+)\.(?P<minor>\d+)$") def get_install_commands(torch_platform, packages): @@ -43,6 +44,9 @@ def get_install_commands(torch_platform, packages): - For "xpu" on Windows, if torchvision or torchaudio are included, the function switches to nightly builds. - For "directml", the "torch-directml" package is returned as part of the installation commands. - For "ipex", the "intel-extension-for-pytorch" package is returned as part of the installation commands. + - For Windows CUDA, the function also installs "triton-windows" (for torch.compile and Triton kernels). 
+ - For Linux ROCm 6.x, the function also installs "pytorch-triton-rocm". + - For Linux XPU, the function also installs "pytorch-triton-xpu". """ if not packages: packages = ["torch", "torchaudio", "torchvision"] @@ -52,7 +56,17 @@ def get_install_commands(torch_platform, packages): if CUDA_REGEX.match(torch_platform) or ROCM_REGEX.match(torch_platform): index_url = f"https://download.pytorch.org/whl/{torch_platform}" - return [packages + ["--index-url", index_url]] + cmds = [packages + ["--index-url", index_url]] + + if os_name == "Windows" and CUDA_REGEX.match(torch_platform): + cmds.append(["triton-windows"]) + + if os_name == "Linux" and ROCM_REGEX.match(torch_platform): + match = ROCM_VERSION_REGEX.match(torch_platform) + if match and int(match.group("major")) >= 6: + cmds.append(["pytorch-triton-rocm", "--index-url", "https://download.pytorch.org/whl"]) + + return cmds if torch_platform == "xpu": if os_name == "Windows" and ("torchvision" in packages or "torchaudio" in packages): @@ -65,7 +79,10 @@ def get_install_commands(torch_platform, packages): else: index_url = f"https://download.pytorch.org/whl/test/{torch_platform}" - return [packages + ["--index-url", index_url]] + cmds = [packages + ["--index-url", index_url]] + if os_name == "Linux": + cmds.append(["pytorch-triton-xpu", "--index-url", "https://download.pytorch.org/whl"]) + return cmds if torch_platform == "directml": return [["torch-directml"], packages] From 5e76c4c3669fb37b09eb48f75b58705effac92f7 Mon Sep 17 00:00:00 2001 From: Godnight1006 Date: Fri, 26 Dec 2025 14:34:10 +0800 Subject: [PATCH 2/3] fix: declare packaging dependency --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a55ca5f..c193e07 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ import setuptools setuptools.setup( - install_requires=[], + install_requires=["packaging"], ) From 1cacbf81a923decffd64b8f549b57b92a79e9174 Mon Sep 17 00:00:00 2001 From: Godnight1006 Date: Sat, 
27 Dec 2025 11:21:03 +0800 Subject: [PATCH 3/3] feat: add torch.compile triton self-test --- API.md | 7 +++ README.md | 2 +- torchruntime/__main__.py | 5 +- torchruntime/utils/torch_test/__init__.py | 77 ++++++++++++++++++++++- 4 files changed, 86 insertions(+), 5 deletions(-) diff --git a/API.md b/API.md index a26fb52..07992d6 100644 --- a/API.md +++ b/API.md @@ -15,6 +15,13 @@ torchruntime.install(["torch", "torchvision<0.20"]) On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`). +## Test torch +Run: +`python -m torchruntime test` + +To specifically verify `torch.compile` / Triton: +`python -m torchruntime test compile` + ## Get device info You can use the device database built into `torchruntime` for your projects: ```py diff --git a/README.md b/README.md index 6108c95..4bc8787 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ torchruntime.configure() ``` ### (Optional) Step 3. Test torch -Run `python -m torchruntime test` to run a set of tests to check whether the installed version of torch is working correctly. +Run `python -m torchruntime test` to run a set of tests to check whether the installed version of torch is working correctly (including a `torch.compile` / Triton check on CUDA/XPU systems). You can also run `python -m torchruntime test compile` to run only the compile check. ## Customizing packages By default, `python -m torchruntime install` will install the latest available `torch`, `torchvision` and `torchaudio` suitable on the user's platform. 
diff --git a/torchruntime/__main__.py b/torchruntime/__main__.py index 97d1acd..fefd567 100644 --- a/torchruntime/__main__.py +++ b/torchruntime/__main__.py @@ -10,7 +10,7 @@ def print_usage(entry_command: str): Commands: install Install PyTorch packages - test [subcommand] Run tests (subcommands: all, devices, math, functions) + test [subcommand] Run tests (subcommands: all, import, devices, compile, math, functions) --help Show this help message Examples: @@ -20,10 +20,11 @@ def print_usage(entry_command: str): {entry_command} install --uv torch>=2.0.0 torchaudio {entry_command} install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0 - {entry_command} test # Runs all tests (import, devices, math, functions) + {entry_command} test # Runs all tests (import, devices, compile, math, functions) {entry_command} test all # Same as above {entry_command} test import # Test only import {entry_command} test devices # Test only devices + {entry_command} test compile # Test torch.compile (Triton) {entry_command} test math # Test only math {entry_command} test functions # Test only functions diff --git a/torchruntime/utils/torch_test/__init__.py b/torchruntime/utils/torch_test/__init__.py index dafe3d8..a407c9f 100644 --- a/torchruntime/utils/torch_test/__init__.py +++ b/torchruntime/utils/torch_test/__init__.py @@ -1,6 +1,8 @@ +import importlib.util +import platform import time -from ..torch_device_utils import get_installed_torch_platform, get_device_count, get_device_name, get_device +from ..torch_device_utils import get_device, get_device_count, get_device_name, get_installed_torch_platform def test(subcommand): @@ -16,7 +18,7 @@ def test(subcommand): def test_all(): - for fn in (test_import, test_devices, test_math, test_functions): + for fn in (test_import, test_devices, test_compile, test_math, test_functions): fn() print("") @@ -101,3 +103,74 @@ def test_functions(): t.run_all_tests() print("--- / FUNCTIONAL TEST ---") + + +def test_compile(): + print("--- COMPILE 
TEST ---") + + try: + import torch + except ImportError: + print("torch.compile: SKIPPED (torch not installed)") + print("--- / COMPILE TEST ---") + return + + if not hasattr(torch, "compile"): + print("torch.compile: SKIPPED (requires torch>=2.0)") + print("--- / COMPILE TEST ---") + return + + torch_platform_name, _ = get_installed_torch_platform() + if torch_platform_name not in ("cuda", "xpu"): + print(f"torch.compile: SKIPPED (unsupported backend: {torch_platform_name})") + print("--- / COMPILE TEST ---") + return + + if importlib.util.find_spec("triton") is None: + print("triton: NOT INSTALLED") + else: + print("triton: installed") + + device = get_device(0) + print("On torch device:", device) + + def f(x): + return x * 2 + 1 + + try: + compiled_f = torch.compile(f) + x = torch.randn((1024,), device=device) + y = compiled_f(x) + expected = f(x) + if not torch.allclose(y, expected): + print("torch.compile: FAILED (output mismatch)") + else: + if torch_platform_name == "cuda": + torch.cuda.synchronize() + if torch_platform_name == "xpu" and hasattr(torch, "xpu") and hasattr(torch.xpu, "synchronize"): + torch.xpu.synchronize() + print("torch.compile: PASSED") + except Exception as e: + print(f"torch.compile: FAILED ({type(e).__name__}: {e})") + + hint = None + os_name = platform.system() + if torch_platform_name == "cuda" and os_name == "Windows": + hint = "pip install triton-windows (or: python -m torchruntime install)" + elif torch_platform_name == "cuda" and os_name == "Linux": + if getattr(torch.version, "hip", None): + hint = ( + "pip install pytorch-triton-rocm --index-url https://download.pytorch.org/whl " + "(or: python -m torchruntime install)" + ) + elif torch_platform_name == "xpu" and os_name == "Linux": + hint = ( + "pip install pytorch-triton-xpu --index-url https://download.pytorch.org/whl " + "(or: python -m torchruntime install)" + ) + + if hint: + print("If this failed due to Triton, try:") + print(" ", hint) + + print("--- / COMPILE TEST ---")