Commit 1a5eaec

Auto detect cuda in install script (#15027)
This pull request simplifies building ExecuTorch with CUDA support by removing the requirement to set the `CMAKE_ARGS` environment variable. CUDA support is now detected and handled automatically during installation, streamlining both CI workflows and the installation logic. Related documentation and error messages have also been updated for clarity.

**CI/CD and Installation Process Simplification**

* Removed the need to set `CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON"` in CI scripts and workflow files; `install_executorch.sh` now handles CUDA detection automatically. (`.ci/scripts/test-cuda-build.sh`, `.github/workflows/cuda.yml`)
* Deleted the `_is_cuda_enabled()` function from `install_utils.py` and refactored the logic to rely solely on CUDA detection via `nvcc`.

**Error Handling and Messaging Improvements**

* Updated error messages in `install_utils.py` to remove references to "CUDA delegate" and to clarify the instructions shown when CUDA is not detected or not supported.

**Internal Refactoring**

* Applied `functools.lru_cache` to `_get_cuda_version()` so repeated CUDA version detection is not re-run.
These changes make the CUDA build process more user-friendly and reduce the risk of misconfiguration.
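As described above, detection now rests entirely on `nvcc`. A minimal standalone sketch of that approach (the helper names and the supported-version set below are illustrative, not the actual `install_utils.py` code):

```python
import re
import subprocess

# Illustrative only: the real supported set lives in ExecuTorch's install logic.
SUPPORTED_CUDA_VERSIONS = {(12, 6), (12, 8), (12, 9)}


def parse_nvcc_release(nvcc_output):
    """Extract (major, minor) from `nvcc --version` output, or None."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        return None
    return (int(match.group(1)), int(match.group(2)))


def detect_cuda(supported=frozenset(SUPPORTED_CUDA_VERSIONS)):
    """Return a supported (major, minor) CUDA version, or None for CPU-only."""
    try:
        out = subprocess.run(
            ["nvcc", "--version"], capture_output=True, text=True, check=True
        ).stdout
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None  # nvcc missing or broken: fall back to CPU-only PyTorch
    version = parse_nvcc_release(out)
    return version if version in supported else None
```

On a machine without the CUDA toolkit, `detect_cuda()` simply returns `None`, which is the behavior the PR relies on instead of an opt-in environment variable.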
1 parent a12219d commit 1a5eaec

File tree

3 files changed: +22 −29 lines changed


.ci/scripts/test-cuda-build.sh

Lines changed: 0 additions & 3 deletions
```diff
@@ -27,9 +27,6 @@ test_executorch_cuda_build() {
   nvcc --version || echo "nvcc not found"
   nvidia-smi || echo "nvidia-smi not found"
 
-  # Set CMAKE_ARGS to enable CUDA build - ExecuTorch will handle PyTorch installation automatically
-  export CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON"
-
   echo "=== Starting ExecuTorch Installation ==="
   # Install ExecuTorch with CUDA support with timeout and error handling
   timeout 5400 ./install_executorch.sh || {
```
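The `timeout 5400 ./install_executorch.sh || { … }` guard above bounds the installation step and reports failure without aborting the surrounding script. The same pattern in Python, as a hedged sketch (the command and limit are whatever the caller passes; nothing here is ExecuTorch API):

```python
import subprocess


def run_with_timeout(cmd, limit_seconds):
    """Run cmd; return True on success, False on failure or timeout."""
    try:
        subprocess.run(cmd, check=True, timeout=limit_seconds)
        return True
    except subprocess.TimeoutExpired:
        print(f"ERROR: {' '.join(cmd)} exceeded {limit_seconds}s")
        return False
    except subprocess.CalledProcessError as e:
        print(f"ERROR: {' '.join(cmd)} failed with exit code {e.returncode}")
        return False
```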

.github/workflows/cuda.yml

Lines changed: 4 additions & 4 deletions
```diff
@@ -1,7 +1,7 @@
 # Test ExecuTorch CUDA Build Compatibility
 # This workflow tests whether ExecuTorch can be successfully built with CUDA support
 # across different CUDA versions (12.6, 12.8, 12.9) using the command:
-# CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_executorch.sh
+# ./install_executorch.sh
 #
 # Note: ExecuTorch automatically detects the system CUDA version using nvcc and
 # installs the appropriate PyTorch wheel. No manual CUDA/PyTorch installation needed.
@@ -43,7 +43,7 @@ jobs:
         set -eux
 
         # Test ExecuTorch CUDA build - ExecuTorch will automatically detect CUDA version
-        # and install the appropriate PyTorch wheel when CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON"
+        # and install the appropriate PyTorch wheel
         source .ci/scripts/test-cuda-build.sh "${{ matrix.cuda-version }}"
 
   # This job will fail if any of the CUDA versions fail
@@ -83,7 +83,7 @@ jobs:
       script: |
         set -eux
 
-        PYTHON_EXECUTABLE=python CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_executorch.sh
+        PYTHON_EXECUTABLE=python ./install_executorch.sh
         export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
         PYTHON_EXECUTABLE=python source .ci/scripts/test_model.sh "${{ matrix.model }}" cmake cuda
 
@@ -110,7 +110,7 @@ jobs:
         set -eux
 
         echo "::group::Setup ExecuTorch"
-        CMAKE_ARGS="-DEXECUTORCH_BUILD_CUDA=ON" ./install_executorch.sh
+        ./install_executorch.sh
         echo "::endgroup::"
 
         echo "::group::Setup Huggingface"
```

install_utils.py

Lines changed: 18 additions & 22 deletions
```diff
@@ -6,19 +6,12 @@
 # LICENSE file in the root directory of this source tree.
 
 import functools
-import os
 import platform
 import re
 import subprocess
 import sys
 
 
-def _is_cuda_enabled():
-    """Check if CUDA delegate is enabled via CMAKE_ARGS environment variable."""
-    cmake_args = os.environ.get("CMAKE_ARGS", "")
-    return "-DEXECUTORCH_BUILD_CUDA=ON" in cmake_args
-
-
 def _cuda_version_to_pytorch_suffix(major, minor):
     """
     Generate PyTorch CUDA wheel suffix from CUDA version numbers.
@@ -33,6 +26,7 @@ def _cuda_version_to_pytorch_suffix(major, minor):
     return f"cu{major}{minor}"
 
 
+@functools.lru_cache(maxsize=1)
 def _get_cuda_version(supported_cuda_versions):
     """
     Get the CUDA version installed on the system using nvcc command.
@@ -62,25 +56,23 @@ def _get_cuda_version(supported_cuda_versions):
             )
             raise RuntimeError(
                 f"Detected CUDA version {major}.{minor} is not supported. "
-                f"Only the following CUDA versions are supported: {available_versions}. "
-                f"Please install a supported CUDA version or try on CPU-only delegates."
+                f"Supported versions: {available_versions}."
             )
 
             return (major, minor)
         else:
             raise RuntimeError(
-                "CUDA delegate is enabled but could not parse CUDA version from nvcc output. "
-                "Please ensure CUDA is properly installed or try on CPU-only delegates."
+                "Failed to parse CUDA version from nvcc output. "
+                "Ensure CUDA is properly installed."
             )
     except FileNotFoundError:
         raise RuntimeError(
-            "CUDA delegate is enabled but nvcc (CUDA compiler) is not found in PATH. "
-            "Please install CUDA toolkit or try on CPU-only delegates."
+            "nvcc (CUDA compiler) is not found in PATH. Install the CUDA toolkit."
         )
     except subprocess.CalledProcessError as e:
         raise RuntimeError(
-            f"CUDA delegate is enabled but nvcc command failed with error: {e}. "
-            "Please ensure CUDA is properly installed or try on CPU-only delegates."
+            f"nvcc command failed with error: {e}. "
+            "Ensure CUDA is properly installed."
         )
 
 
@@ -105,7 +97,7 @@ def _get_pytorch_cuda_url(cuda_version, torch_nightly_url_base):
 @functools.lru_cache(maxsize=1)
 def determine_torch_url(torch_nightly_url_base, supported_cuda_versions):
     """
-    Determine the appropriate PyTorch installation URL based on CUDA availability and CMAKE_ARGS.
+    Determine the appropriate PyTorch installation URL based on CUDA availability.
     Uses @functools.lru_cache to avoid redundant CUDA detection and print statements.
 
     Args:
@@ -115,15 +107,19 @@ def determine_torch_url(torch_nightly_url_base, supported_cuda_versions):
     Returns:
         URL string for PyTorch packages
     """
-    # Check if CUDA delegate is enabled
-    if not _is_cuda_enabled():
-        print("CUDA delegate not enabled, using CPU-only PyTorch")
+    if platform.system().lower() == "windows":
+        print(
+            "Windows detected, using CPU-only PyTorch until CUDA support is available"
+        )
         return f"{torch_nightly_url_base}/cpu"
 
-    print("CUDA delegate enabled, detecting CUDA version...")
+    print("Attempting to detect CUDA via nvcc...")
 
-    # Get CUDA version
-    cuda_version = _get_cuda_version(supported_cuda_versions)
+    try:
+        cuda_version = _get_cuda_version(supported_cuda_versions)
+    except Exception as err:
+        print(f"CUDA detection failed ({err}), using CPU-only PyTorch")
+        return f"{torch_nightly_url_base}/cpu"
 
     major, minor = cuda_version
     print(f"Detected CUDA version: {major}.{minor}")
```
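Putting the new `install_utils.py` flow together: a detected version maps to a wheel-index suffix via `cu{major}{minor}`, and any detection failure falls back to the `/cpu` index. A small sketch of that mapping (the base URL and the `pick_torch_url` name are illustrative, not the real ExecuTorch function):

```python
# Mirrors the suffix logic shown in the diff: a CUDA (major, minor) pair maps
# to a PyTorch wheel index suffix, and no CUDA means the /cpu index.
def cuda_version_to_pytorch_suffix(major, minor):
    return f"cu{major}{minor}"


def pick_torch_url(base, cuda_version):
    """cuda_version is (major, minor) or None (CPU-only fallback)."""
    if cuda_version is None:
        return f"{base}/cpu"
    major, minor = cuda_version
    return f"{base}/{cuda_version_to_pytorch_suffix(major, minor)}"
```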

0 commit comments
