From f08bdee130e0a32c359704c522f321bf9faa4128 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Sat, 4 Oct 2025 12:06:43 -0700
Subject: [PATCH] Add L2 cache clearing to do_bench_cudagraph, for more
 realistic timing measurement

---
 tritonbench/components/do_bench/run.py | 56 +++++++++++++++++++++++++-
 1 file changed, 55 insertions(+), 1 deletion(-)

diff --git a/tritonbench/components/do_bench/run.py b/tritonbench/components/do_bench/run.py
index 3fbfddec3..be681428b 100644
--- a/tritonbench/components/do_bench/run.py
+++ b/tritonbench/components/do_bench/run.py
@@ -166,6 +166,60 @@ def _do_bench_inductor(fn, warmup, rep, return_mode="all", grad_to_none=None):
     return _summarize_statistics(times, quantiles=None, return_mode=return_mode)
 
 
+def _do_bench_cudagraph_with_cache_clear(
+    fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"
+):
+    """Clone of triton.testing.do_bench_cudagraph with explicit L2 cache clearing."""
+    assert return_mode in ["min", "max", "mean", "median", "all"]
+
+    cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
+
+    with torch.cuda.stream(torch.cuda.Stream()):
+        cache.zero_()
+        fn()
+        if grad_to_none is not None:
+            for x in grad_to_none:
+                x.detach_()
+                x.requires_grad_(True)
+                x.grad = None
+
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        for _ in range(5):
+            cache.zero_()
+            fn()
+        end_event.record()
+        torch.cuda.synchronize()
+        estimate_ms = start_event.elapsed_time(end_event) / 5
+
+        n_repeat = 1000 if estimate_ms == 0 else max(1, int(rep / estimate_ms))
+
+        g = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(g):
+            for _ in range(n_repeat):
+                if grad_to_none is not None:
+                    for x in grad_to_none:
+                        x.grad = None
+                cache.zero_()
+                fn()
+        torch.cuda.synchronize()
+
+        ret = []
+        n_retries = 10
+        for _ in range(n_retries):
+            start_event = torch.cuda.Event(enable_timing=True)
+            end_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
+            g.replay()
+            end_event.record()
+            torch.cuda.synchronize()
+            ret.append(start_event.elapsed_time(end_event) / n_repeat)
+
+    times = torch.tensor(ret, dtype=torch.float)
+    return _summarize_statistics(times, quantiles, return_mode)
+
+
 def _do_bench_profiler(
     fn, warmup, rep, return_mode="all", grad_to_none=None, use_cudagraph=False
 ):
@@ -383,7 +437,7 @@ def do_bench_wrapper(
         if latency_measure_mode == "profiler":
             bench_fn = partial(_do_bench_profiler, warmup=1, use_cudagraph=True)
         else:
-            bench_fn = triton.testing.do_bench_cudagraph
+            bench_fn = _do_bench_cudagraph_with_cache_clear
 
         return Latency(
             times=bench_fn(
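
Reviewer note (not part of the patch): a minimal usage sketch of the new helper, assuming the patch is applied and a CUDA device is available. The matmul workload, tensor shapes, and printed label below are illustrative assumptions, not anything from the patch. The L2-evicting buffer itself comes from triton's `get_empty_cache_for_benchmark()`, which returns a benchmark buffer sized to cover the device's L2 (256 MB in triton's CUDA backend), so `cache.zero_()` evicts whatever the previous call left resident.

import torch

from tritonbench.components.do_bench.run import (
    _do_bench_cudagraph_with_cache_clear,
)

# Illustrative workload (hypothetical, not from the patch): a half-precision
# matmul whose input tiles could otherwise stay resident in L2 across
# back-to-back calls, flattering the measured latency.
a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

# Each captured iteration runs cache.zero_() before fn(), so every timed
# call starts with a cold L2 instead of reusing data left by the previous
# iteration; that is what makes the measurement "more realistic".
mean_ms = _do_bench_cudagraph_with_cache_clear(
    lambda: torch.mm(a, b),
    rep=20,
    return_mode="mean",
)
print(f"cold-L2 mean latency: {mean_ms:.3f} ms")

One tradeoff worth noting: because cache.zero_() is captured inside the CUDA graph alongside fn(), its cost is included in the reported per-iteration time. That overhead matters most when the kernel under test is itself very short.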