Skip to content

Commit 1ab4208

Browse files
authored
[Benchmark] Fix tritonbench auto-installation (#980)
1 parent 944e7a8 commit 1ab4208

File tree

1 file changed

+141
-92
lines changed

1 file changed

+141
-92
lines changed

benchmarks/run.py

Lines changed: 141 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import argparse
2525
import collections
26+
from contextlib import suppress
2627
import dataclasses
2728
import functools
2829
import gc
@@ -32,31 +33,50 @@
3233
import os
3334
from pathlib import Path
3435
from pprint import pformat
36+
import shutil
3537
import subprocess
3638
import sys
3739
import tempfile
38-
from typing import TYPE_CHECKING
3940
from typing import Any
4041
from typing import Callable
42+
from typing import cast
4143

4244
import torch
4345
from torch.utils._pytree import tree_leaves
4446
from torch.utils._pytree import tree_map
4547

4648
from helion._utils import counters
4749

48-
if TYPE_CHECKING:
49-
from tritonbench.utils.triton_op import BenchmarkOperator
50-
from tritonbench.utils.triton_op import BenchmarkOperatorMetrics
50+
logger: logging.Logger = logging.getLogger(__name__)
5151

52-
try:
53-
from tritonbench.utils.env_utils import get_nvidia_gpu_model
54-
from tritonbench.utils.env_utils import is_cuda
52+
StrPath = str | os.PathLike[str]
5553

56-
IS_B200 = is_cuda() and get_nvidia_gpu_model() == "NVIDIA B200"
57-
except ImportError:
58-
print("Failed B200 detection since tritonbench is not installed (yet)")
59-
IS_B200 = False
54+
if os.getenv("HELION_BENCHMARK_DISABLE_LOGGING", "0") == "1":
55+
logging.disable(logging.CRITICAL)
56+
57+
58+
def is_cuda() -> bool:
    """Report whether the installed PyTorch build carries a CUDA version.

    This only inspects ``torch.version.cuda``; it does not query any device.

    Returns:
        bool: True when ``torch.version.cuda`` is set, False when it is None.
    """
    cuda_version = torch.version.cuda
    return cuda_version is not None
60+
61+
62+
def get_nvidia_gpu_model() -> str:
    """
    Retrieves the model of the NVIDIA GPU being used.
    Will return the name of the first GPU listed.

    The name is obtained by running
    ``nvidia-smi --query-gpu=name --format=csv,noheader,nounits``.
    This function is called at module import time (to compute ``IS_B200``),
    so it must never raise: every failure mode falls back to an empty string.

    Returns:
        str: The model of the NVIDIA GPU or empty str if not found.
    """
    try:
        model = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader,nounits"]
        )
    except OSError:
        # Binary missing or not executable.
        logger.warning("nvidia-smi not found. Returning empty str.")
        return ""
    except subprocess.CalledProcessError as e:
        # nvidia-smi is installed but exited non-zero (e.g. it could not
        # talk to the driver). check_output raises in that case; treat it
        # the same as "no GPU found" instead of crashing the importer.
        logger.warning(
            "nvidia-smi exited with status %s. Returning empty str.",
            e.returncode,
        )
        return ""

    # splitlines() copes with both "\n" and "\r\n" output; guard against
    # completely empty output so indexing cannot fail.
    names = model.decode().strip().splitlines()
    return names[0].strip() if names else ""
77+
78+
79+
IS_B200 = is_cuda() and get_nvidia_gpu_model() == "NVIDIA B200"
6080

6181

6282
def log_tensor_metadata(args: tuple[object, ...], kwargs: dict[str, object]) -> None:
@@ -82,11 +102,6 @@ def describe_tensor(obj: object) -> object:
82102
)
83103

84104

85-
logger: logging.Logger = logging.getLogger(__name__)
86-
87-
if os.getenv("HELION_BENCHMARK_DISABLE_LOGGING", "0") == "1":
88-
logging.disable(logging.CRITICAL)
89-
90105
# Maximum number of inputs to use
91106
MAX_NUM_INPUTS = 20
92107

@@ -600,109 +615,141 @@ class RunResult:
600615
}
601616

602617

603-
def get_system_memory_gb() -> float:
604-
"""Get system memory in GB."""
605-
try:
606-
# Try to read from /proc/meminfo on Linux
607-
meminfo_path = Path("/proc/meminfo")
608-
if meminfo_path.exists():
609-
with open(meminfo_path) as f:
610-
for line in f:
611-
if line.startswith("MemTotal:"):
612-
# Extract memory in kB and convert to GB
613-
mem_kb = int(line.split()[1])
614-
return mem_kb / (1024 * 1024)
615-
616-
# Fallback: use psutil if available
617-
try:
618-
import psutil
618+
def check_and_setup_tritonbench() -> None:
619+
"""Ensure a usable tritonbench installation is available."""
619620

620-
return psutil.virtual_memory().total / (1024**3)
621-
except ImportError:
622-
pass
621+
benchmarks_dir = Path(__file__).parent
622+
tritonbench_path = benchmarks_dir / "tritonbench"
623+
installing_marker = (benchmarks_dir / ".tritonbench_installing").resolve()
623624

624-
except Exception:
625-
pass
625+
try:
626+
import tritonbench # pyright: ignore[reportMissingImports]
627+
628+
module_file = getattr(tritonbench, "__file__", None)
629+
tb_repo_path = tritonbench_path.resolve()
630+
631+
candidate_paths: list[Path] = []
632+
633+
def add_candidate_path(entry: object) -> None:
634+
if not isinstance(entry, (str, os.PathLike)):
635+
return
636+
path_entry = cast("StrPath", entry)
637+
with suppress(TypeError, OSError, RuntimeError):
638+
candidate_paths.append(Path(path_entry))
639+
640+
if module_file is not None:
641+
add_candidate_path(module_file)
642+
643+
module_paths = getattr(tritonbench, "__path__", None)
644+
if module_paths is not None:
645+
for entry in module_paths:
646+
add_candidate_path(entry)
647+
648+
def is_local(path: Path) -> bool:
649+
try:
650+
resolved_path = path.resolve()
651+
except (OSError, RuntimeError):
652+
return False
653+
return (
654+
resolved_path == tb_repo_path or tb_repo_path in resolved_path.parents
655+
)
626656

627-
# Default to assuming high memory if we can't detect
628-
return 32.0
657+
has_local_checkout = any(is_local(path) for path in candidate_paths)
629658

659+
if candidate_paths and not has_local_checkout:
660+
# If tritonbench is not from local checkout, assume it's a proper installation
661+
return
630662

631-
def check_and_setup_tritonbench() -> None:
632-
"""Check if tritonbench is installed and install it from GitHub if not."""
633-
# Check if tritonbench is already installed
634-
if importlib.util.find_spec("tritonbench") is not None:
635-
return # Already installed
663+
if has_local_checkout:
664+
if installing_marker.exists():
665+
print(
666+
"Detected partially installed tritonbench; reinstalling local checkout.",
667+
file=sys.stderr,
668+
)
669+
else:
670+
return
671+
else:
672+
print(
673+
"Unable to determine tritonbench import path; reinstalling local checkout.",
674+
file=sys.stderr,
675+
)
636676

637-
print("Tritonbench not found. Installing...", file=sys.stderr)
677+
except ImportError:
678+
pass
638679

639-
# Clone to benchmarks/tritonbench
640-
benchmarks_dir = Path(__file__).parent
641-
tritonbench_path = benchmarks_dir / "tritonbench"
680+
print(
681+
"Installing tritonbench from source...",
682+
file=sys.stderr,
683+
)
642684
print(f"Using tritonbench path: {tritonbench_path}")
643685

644-
try:
645-
# Clone the repository if it doesn't exist
646-
if not tritonbench_path.exists():
647-
print("Cloning tritonbench repository...", file=sys.stderr)
648-
subprocess.run(
649-
[
650-
"git",
651-
"clone",
652-
"https://github.com/meta-pytorch/tritonbench.git",
653-
str(tritonbench_path),
654-
],
655-
check=True,
656-
)
686+
if tritonbench_path.exists():
687+
print("Removing existing tritonbench checkout...", file=sys.stderr)
688+
if tritonbench_path.is_dir():
689+
shutil.rmtree(tritonbench_path)
690+
else:
691+
tritonbench_path.unlink()
657692

658-
# Initialize submodules
659-
print("Initializing tritonbench's submodules...", file=sys.stderr)
660-
subprocess.run(
661-
["git", "submodule", "update", "--init", "--recursive"],
662-
cwd=tritonbench_path,
663-
check=True,
664-
)
693+
sys.modules.pop("tritonbench", None)
665694

666-
# Detect system memory and choose install flags.
667-
# Low-memory systems can freeze when building dependencies like flash-attn,
668-
# so we only install the Liger library in that case.
669-
memory_gb = get_system_memory_gb()
670-
install_flag = "--liger" if memory_gb < 16 else "--all"
695+
installing_marker.touch()
671696

672-
# Install optional dependencies for tritonbench
673-
print(
674-
f"Running install.py {install_flag} (detected {memory_gb:.1f}GB system RAM)...",
675-
file=sys.stderr,
697+
try:
698+
print("Cloning tritonbench repository...", file=sys.stderr)
699+
subprocess.run(
700+
[
701+
"git",
702+
"clone",
703+
"https://github.com/meta-pytorch/tritonbench.git",
704+
str(tritonbench_path),
705+
],
706+
cwd=benchmarks_dir,
707+
check=True,
676708
)
677-
env = os.environ.copy()
678-
if install_flag == "--all":
679-
# Set max jobs to 4 to avoid OOM
680-
env["MAX_JOBS"] = "4"
709+
710+
print("Initializing tritonbench submodules...", file=sys.stderr)
681711
subprocess.run(
682-
[sys.executable, "install.py", install_flag],
712+
["git", "submodule", "update", "--init", "--recursive"],
713+
cwd=tritonbench_path,
714+
check=True,
715+
)
716+
717+
print("Installing tritonbench requirements...", file=sys.stderr)
718+
subprocess.run(
719+
[
720+
sys.executable,
721+
"-m",
722+
"pip",
723+
"install",
724+
"-r",
725+
"requirements.txt",
726+
],
727+
cwd=tritonbench_path,
728+
check=True,
729+
)
730+
731+
print("Running install.py --liger...", file=sys.stderr)
732+
subprocess.run(
733+
[sys.executable, "install.py", "--liger"],
683734
cwd=tritonbench_path,
684735
check=True,
685-
env=env,
686736
)
687737

688-
# Install tritonbench package
689738
print("Installing tritonbench package...", file=sys.stderr)
690739
subprocess.run(
691-
[sys.executable, "-m", "pip", "install", "-e", str(tritonbench_path)],
740+
[sys.executable, "-m", "pip", "install", "-e", "."],
741+
cwd=tritonbench_path,
692742
check=True,
693743
)
694744

695-
# Invalidate import caches to recognize newly installed package
696745
importlib.invalidate_caches()
697746

698-
# Verify installation worked
699747
try:
700-
import tritonbench # noqa: F401 # pyright: ignore[reportMissingImports]
748+
import tritonbench # pyright: ignore[reportMissingImports]
701749

702-
print(
703-
f"Tritonbench installed successfully with {install_flag}.",
704-
file=sys.stderr,
705-
)
750+
print("Tritonbench installed successfully.", file=sys.stderr)
751+
if installing_marker.exists():
752+
installing_marker.unlink()
706753
except ImportError:
707754
print(
708755
"Error: Tritonbench package installation failed. The package cannot be imported.",
@@ -789,6 +836,8 @@ def run_kernel_variants(
789836
from tritonbench.utils.parser import ( # pyright: ignore[reportMissingImports]
790837
get_parser,
791838
)
839+
from tritonbench.utils.triton_op import BenchmarkOperator
840+
from tritonbench.utils.triton_op import BenchmarkOperatorMetrics
792841

793842
# Get the tritonbench operator name, stripping -bwd suffix for backward operators
794843
operator_name = kernel_name.removesuffix("-bwd")

0 commit comments

Comments
 (0)