tritonbench/components/do_bench/run.py (56 changes: 55 additions & 1 deletion)

@@ -166,6 +166,60 @@ def _do_bench_inductor(fn, warmup, rep, return_mode="all", grad_to_none=None):
    return _summarize_statistics(times, quantiles=None, return_mode=return_mode)


def _do_bench_cudagraph_with_cache_clear(
    fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"
):
    """Clone of triton.testing.do_bench_cudagraph with explicit L2 cache clearing."""
    assert return_mode in ["min", "max", "mean", "median", "all"]

    # Large device buffer; zeroing it between invocations evicts the L2 cache.
    cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()

    with torch.cuda.stream(torch.cuda.Stream()):
        # Warmup invocation, then reset gradients so capture starts clean.
        cache.zero_()
        fn()
        if grad_to_none is not None:
            for x in grad_to_none:
                x.detach_()
                x.requires_grad_(True)
                x.grad = None

        # Step 1: estimate the runtime of one (cache clear + fn) iteration.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for _ in range(5):
            cache.zero_()
            fn()
        end_event.record()
        torch.cuda.synchronize()
        estimate_ms = start_event.elapsed_time(end_event) / 5

        # Enough iterations to fill roughly `rep` ms of graph replay.
        n_repeat = 1000 if estimate_ms == 0 else max(1, int(rep / estimate_ms))

        # Step 2: capture n_repeat iterations into a single CUDA graph,
        # clearing the L2 cache before every invocation.
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            for _ in range(n_repeat):
                if grad_to_none is not None:
                    for x in grad_to_none:
                        x.grad = None
                cache.zero_()
                fn()
        torch.cuda.synchronize()

        # Step 3: replay the graph n_retries times; record per-iteration time.
        ret = []
        n_retries = 10
        for _ in range(n_retries):
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            g.replay()
            end_event.record()
            torch.cuda.synchronize()
            ret.append(start_event.elapsed_time(end_event) / n_repeat)

        times = torch.tensor(ret, dtype=torch.float)
        return _summarize_statistics(times, quantiles, return_mode)


def _do_bench_profiler(
    fn, warmup, rep, return_mode="all", grad_to_none=None, use_cudagraph=False
):
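
For reference, the cache-clear trick works by zeroing a device buffer larger than the GPU's L2 cache, which evicts any data a previous invocation left cached. Below is a minimal standalone sketch of what the driver helper provides; the 256 MiB size is an assumption based on Triton's CUDA backend, and _empty_l2_sized_cache is a hypothetical name, not part of this PR.

import torch

def _empty_l2_sized_cache():
    # Hypothetical stand-in for
    # triton.runtime.driver.active.get_empty_cache_for_benchmark().
    # Assumption: a 256 MiB int32 buffer, comfortably larger than L2
    # on current NVIDIA GPUs.
    cache_size = 256 * 1024 * 1024  # bytes
    return torch.empty(cache_size // 4, dtype=torch.int, device="cuda")

cache = _empty_l2_sized_cache()
cache.zero_()  # writing the whole buffer evicts prior contents from L2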
@@ -383,7 +437,7 @@ def do_bench_wrapper(
    if latency_measure_mode == "profiler":
        bench_fn = partial(_do_bench_profiler, warmup=1, use_cudagraph=True)
    else:
-       bench_fn = triton.testing.do_bench_cudagraph
+       bench_fn = _do_bench_cudagraph_with_cache_clear

    return Latency(
        times=bench_fn(
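
A minimal usage sketch of the new path; the x * x workload and tensor shape are hypothetical, and only _do_bench_cudagraph_with_cache_clear comes from this PR.

import torch

x = torch.randn(4096, 4096, device="cuda")
fn = lambda: x * x  # hypothetical zero-argument workload launching CUDA work

# Mean latency in ms per iteration, with L2 flushed before every invocation,
# so each replayed call sees a cold cache instead of reusing x from L2.
mean_ms = _do_bench_cudagraph_with_cache_clear(fn, rep=20, return_mode="mean")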