From 2eb490ca622986bdcf44ba27b3f11df546ea25d3 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Tue, 7 Oct 2025 14:57:51 -0700 Subject: [PATCH] Exclude L2 cache clear time from timing measurement --- tritonbench/components/do_bench/run.py | 28 +++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/tritonbench/components/do_bench/run.py b/tritonbench/components/do_bench/run.py index be681428b..e93f37d55 100644 --- a/tritonbench/components/do_bench/run.py +++ b/tritonbench/components/do_bench/run.py @@ -205,18 +205,40 @@ def _do_bench_cudagraph_with_cache_clear( fn() torch.cuda.synchronize() - ret = [] + cache_clear_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cache_clear_graph): + for _ in range(n_repeat): + cache.zero_() + torch.cuda.synchronize() + n_retries = 10 + cache_clear_times = [] + total_times = [] for _ in range(n_retries): + cache_clear_start_event = torch.cuda.Event(enable_timing=True) + cache_clear_end_event = torch.cuda.Event(enable_timing=True) + cache_clear_start_event.record() + cache_clear_graph.replay() + cache_clear_end_event.record() + torch.cuda.synchronize() + cache_clear_times.append( + cache_clear_start_event.elapsed_time(cache_clear_end_event) / n_repeat + ) + start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() g.replay() end_event.record() torch.cuda.synchronize() - ret.append(start_event.elapsed_time(end_event) / n_repeat) + total_times.append(start_event.elapsed_time(end_event) / n_repeat) - times = torch.tensor(ret, dtype=torch.float) + all_kernel_times = [] + for total_time, cache_clear_time in zip(total_times, cache_clear_times): + kernel_time = total_time - cache_clear_time + all_kernel_times.append(kernel_time) + + times = torch.tensor(all_kernel_times, dtype=torch.float) return _summarize_statistics(times, quantiles, return_mode)