Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions API.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Or you can use the library:
torchruntime.install(["torch", "torchvision<0.20"])
```

On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`).

## Get device info
You can use the device database built into `torchruntime` for your projects:
```py
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Supports Windows, Linux, and Mac.

This will install `torch`, `torchvision`, and `torchaudio`, and will decide the variant based on the user's OS, GPU manufacturer and GPU model number. See [customizing packages](#customizing-packages) for more options.

On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`).

**Tip:** You can also add the `--uv` flag to install packages using [uv](https://docs.astral.sh/uv/) (instead of `pip`). For e.g. `python -m torchruntime install --uv`

### Step 2. Configure torch
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import setuptools

setuptools.setup(
install_requires=[],
install_requires=["packaging"],
)
40 changes: 37 additions & 3 deletions tests/test_installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,57 @@ def test_cpu_platform():
assert result == [packages]


def test_cuda_platform():
def test_cuda_platform(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
packages = ["torch", "torchvision"]
result = get_install_commands("cu112", packages)
expected_url = "https://download.pytorch.org/whl/cu112"
assert result == [packages + ["--index-url", expected_url]]


def test_cuda_nightly_platform():
def test_cuda_platform_windows_installs_triton(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
packages = ["torch", "torchvision"]
result = get_install_commands("cu112", packages)
expected_url = "https://download.pytorch.org/whl/cu112"
assert result == [packages + ["--index-url", expected_url], ["triton-windows"]]


def test_cuda_nightly_platform(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
packages = ["torch", "torchvision"]
result = get_install_commands("nightly/cu112", packages)
expected_url = "https://download.pytorch.org/whl/nightly/cu112"
assert result == [packages + ["--index-url", expected_url]]


def test_cuda_nightly_platform_windows_installs_triton(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
packages = ["torch", "torchvision"]
result = get_install_commands("nightly/cu112", packages)
expected_url = "https://download.pytorch.org/whl/nightly/cu112"
assert result == [packages + ["--index-url", expected_url], ["triton-windows"]]


def test_rocm_platform():
packages = ["torch", "torchvision"]
result = get_install_commands("rocm4.2", packages)
expected_url = "https://download.pytorch.org/whl/rocm4.2"
assert result == [packages + ["--index-url", expected_url]]


def test_rocm_platform_linux_installs_triton(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
packages = ["torch", "torchvision"]
result = get_install_commands("rocm6.2", packages)
expected_url = "https://download.pytorch.org/whl/rocm6.2"
triton_index_url = "https://download.pytorch.org/whl"
assert result == [
packages + ["--index-url", expected_url],
["pytorch-triton-rocm", "--index-url", triton_index_url],
]


def test_xpu_platform_windows_with_torch_only(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
packages = ["torch"]
Expand All @@ -60,7 +90,11 @@ def test_xpu_platform_linux(monkeypatch):
packages = ["torch", "torchvision"]
result = get_install_commands("xpu", packages)
expected_url = "https://download.pytorch.org/whl/test/xpu"
assert result == [packages + ["--index-url", expected_url]]
triton_index_url = "https://download.pytorch.org/whl"
assert result == [
packages + ["--index-url", expected_url],
["pytorch-triton-xpu", "--index-url", triton_index_url],
]


def test_directml_platform():
Expand Down
21 changes: 19 additions & 2 deletions torchruntime/installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
PIP_PREFIX = [sys.executable, "-m", "pip", "install"]
CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$")
ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$")
ROCM_VERSION_REGEX = re.compile(r"^(?:nightly/)?rocm(?P<major>\d+)\.(?P<minor>\d+)$")


def get_install_commands(torch_platform, packages):
Expand Down Expand Up @@ -43,6 +44,9 @@ def get_install_commands(torch_platform, packages):
- For "xpu" on Windows, if torchvision or torchaudio are included, the function switches to nightly builds.
- For "directml", the "torch-directml" package is returned as part of the installation commands.
- For "ipex", the "intel-extension-for-pytorch" package is returned as part of the installation commands.
- For Windows CUDA, the function also installs "triton-windows" (for torch.compile and Triton kernels).
- For Linux ROCm 6.x, the function also installs "pytorch-triton-rocm".
- For Linux XPU, the function also installs "pytorch-triton-xpu".
"""
if not packages:
packages = ["torch", "torchaudio", "torchvision"]
Expand All @@ -52,7 +56,17 @@ def get_install_commands(torch_platform, packages):

if CUDA_REGEX.match(torch_platform) or ROCM_REGEX.match(torch_platform):
index_url = f"https://download.pytorch.org/whl/{torch_platform}"
return [packages + ["--index-url", index_url]]
cmds = [packages + ["--index-url", index_url]]

if os_name == "Windows" and CUDA_REGEX.match(torch_platform):
cmds.append(["triton-windows"])

if os_name == "Linux" and ROCM_REGEX.match(torch_platform):
match = ROCM_VERSION_REGEX.match(torch_platform)
if match and int(match.group("major")) >= 6:
cmds.append(["pytorch-triton-rocm", "--index-url", "https://download.pytorch.org/whl"])

return cmds

if torch_platform == "xpu":
if os_name == "Windows" and ("torchvision" in packages or "torchaudio" in packages):
Expand All @@ -65,7 +79,10 @@ def get_install_commands(torch_platform, packages):
else:
index_url = f"https://download.pytorch.org/whl/test/{torch_platform}"

return [packages + ["--index-url", index_url]]
cmds = [packages + ["--index-url", index_url]]
if os_name == "Linux":
cmds.append(["pytorch-triton-xpu", "--index-url", "https://download.pytorch.org/whl"])
return cmds

if torch_platform == "directml":
return [["torch-directml"], packages]
Expand Down