diff --git a/API.md b/API.md index 7180da5..a26fb52 100644 --- a/API.md +++ b/API.md @@ -13,6 +13,8 @@ Or you can use the library: torchruntime.install(["torch", "torchvision<0.20"]) ``` +On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`). + ## Get device info You can use the device database built into `torchruntime` for your projects: ```py diff --git a/README.md b/README.md index c5e6c5f..6108c95 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,8 @@ Supports Windows, Linux, and Mac. This will install `torch`, `torchvision`, and `torchaudio`, and will decide the variant based on the user's OS, GPU manufacturer and GPU model number. See [customizing packages](#customizing-packages) for more options. +On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`). + **Tip:** You can also add the `--uv` flag to install packages using [uv](https://docs.astral.sh/uv/) (instead of `pip`). For e.g. `python -m torchruntime install --uv` ### Step 2. Configure torch diff --git a/setup.py b/setup.py index a55ca5f..c193e07 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ import setuptools setuptools.setup( - install_requires=[], + install_requires=["packaging"], ) diff --git a/tests/test_installer.py b/tests/test_installer.py index ade0358..ab8856f 100644 --- a/tests/test_installer.py +++ b/tests/test_installer.py @@ -16,20 +16,38 @@ def test_cpu_platform(): assert result == [packages] -def test_cuda_platform(): +def test_cuda_platform(monkeypatch): + monkeypatch.setattr("torchruntime.installer.os_name", "Linux") packages = ["torch", "torchvision"] result = get_install_commands("cu112", packages) expected_url = "https://download.pytorch.org/whl/cu112" assert result == [packages + ["--index-url", expected_url]] -def test_cuda_nightly_platform(): +def test_cuda_platform_windows_installs_triton(monkeypatch): + monkeypatch.setattr("torchruntime.installer.os_name", "Windows") + packages = ["torch", "torchvision"] + result = get_install_commands("cu112", packages) + expected_url = "https://download.pytorch.org/whl/cu112" + assert result == [packages + ["--index-url", expected_url], ["triton-windows"]] + + +def test_cuda_nightly_platform(monkeypatch): + monkeypatch.setattr("torchruntime.installer.os_name", "Linux") packages = ["torch", "torchvision"] result = get_install_commands("nightly/cu112", packages) expected_url = "https://download.pytorch.org/whl/nightly/cu112" assert result == [packages + ["--index-url", expected_url]] +def test_cuda_nightly_platform_windows_installs_triton(monkeypatch): + monkeypatch.setattr("torchruntime.installer.os_name", "Windows") + packages = ["torch", "torchvision"] + result = get_install_commands("nightly/cu112", packages) + expected_url = "https://download.pytorch.org/whl/nightly/cu112" + assert result == [packages + ["--index-url", expected_url], ["triton-windows"]] + + def test_rocm_platform(): packages = ["torch", "torchvision"] result = get_install_commands("rocm4.2", packages) @@ -37,6 +55,18 @@ def test_rocm_platform(): assert result == [packages + ["--index-url", expected_url]] +def test_rocm_platform_linux_installs_triton(monkeypatch): + monkeypatch.setattr("torchruntime.installer.os_name", "Linux") + packages = ["torch", "torchvision"] + result = get_install_commands("rocm6.2", packages) + expected_url = "https://download.pytorch.org/whl/rocm6.2" + triton_index_url = "https://download.pytorch.org/whl" + assert result == [ + packages + ["--index-url", expected_url], + ["pytorch-triton-rocm", "--index-url", triton_index_url], + ] + + def test_xpu_platform_windows_with_torch_only(monkeypatch): monkeypatch.setattr("torchruntime.installer.os_name", "Windows") packages = ["torch"] @@ -60,7 +90,11 @@ def test_xpu_platform_linux(monkeypatch): packages = ["torch", "torchvision"] result = get_install_commands("xpu", packages) expected_url = "https://download.pytorch.org/whl/test/xpu" - assert result == [packages + ["--index-url", expected_url]] + triton_index_url = "https://download.pytorch.org/whl" + assert result == [ + packages + ["--index-url", expected_url], + ["pytorch-triton-xpu", "--index-url", triton_index_url], + ] def test_directml_platform(): diff --git a/torchruntime/installer.py b/torchruntime/installer.py index d7f17d3..1153188 100644 --- a/torchruntime/installer.py +++ b/torchruntime/installer.py @@ -12,6 +12,7 @@ PIP_PREFIX = [sys.executable, "-m", "pip", "install"] CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$") ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$") +ROCM_VERSION_REGEX = re.compile(r"^(?:nightly/)?rocm(?P\d+)\.(?P\d+)$") def get_install_commands(torch_platform, packages): @@ -43,6 +44,9 @@ def get_install_commands(torch_platform, packages): - For "xpu" on Windows, if torchvision or torchaudio are included, the function switches to nightly builds. - For "directml", the "torch-directml" package is returned as part of the installation commands. - For "ipex", the "intel-extension-for-pytorch" package is returned as part of the installation commands. + - For Windows CUDA, the function also installs "triton-windows" (for torch.compile and Triton kernels). + - For Linux ROCm 6.x, the function also installs "pytorch-triton-rocm". + - For Linux XPU, the function also installs "pytorch-triton-xpu". """ if not packages: packages = ["torch", "torchaudio", "torchvision"] @@ -52,7 +56,17 @@ def get_install_commands(torch_platform, packages): if CUDA_REGEX.match(torch_platform) or ROCM_REGEX.match(torch_platform): index_url = f"https://download.pytorch.org/whl/{torch_platform}" - return [packages + ["--index-url", index_url]] + cmds = [packages + ["--index-url", index_url]] + + if os_name == "Windows" and CUDA_REGEX.match(torch_platform): + cmds.append(["triton-windows"]) + + if os_name == "Linux" and ROCM_REGEX.match(torch_platform): + match = ROCM_VERSION_REGEX.match(torch_platform) + if match and int(match.group("major")) >= 6: + cmds.append(["pytorch-triton-rocm", "--index-url", "https://download.pytorch.org/whl"]) + + return cmds if torch_platform == "xpu": if os_name == "Windows" and ("torchvision" in packages or "torchaudio" in packages): @@ -65,7 +79,10 @@ def get_install_commands(torch_platform, packages): else: index_url = f"https://download.pytorch.org/whl/test/{torch_platform}" - return [packages + ["--index-url", index_url]] + cmds = [packages + ["--index-url", index_url]] + if os_name == "Linux": + cmds.append(["pytorch-triton-xpu", "--index-url", "https://download.pytorch.org/whl"]) + return cmds if torch_platform == "directml": return [["torch-directml"], packages]