From fdaa6c6c54f3015b24bd54c15263ac7f5c457762 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Mon, 6 Oct 2025 13:25:18 -0700
Subject: [PATCH] Support L2 cache clearing in do_bench_cudagraph

---
 python/test/unit/runtime/test_autotuner.py | 16 ++++--
 python/triton/testing.py                   | 56 ++++++++++++++++++++--
 2 files changed, 63 insertions(+), 9 deletions(-)

diff --git a/python/test/unit/runtime/test_autotuner.py b/python/test/unit/runtime/test_autotuner.py
index d9b972d6bfd2..90bb4f8790ec 100644
--- a/python/test/unit/runtime/test_autotuner.py
+++ b/python/test/unit/runtime/test_autotuner.py
@@ -9,14 +9,15 @@
 from triton._internal_testing import is_cuda
 
 
-def do_bench(kernel_call, quantiles, use_cuda_graph=False):
+def do_bench(kernel_call, quantiles, use_cuda_graph=False, clear_cache=False):
     if use_cuda_graph:
-        return triton.testing.do_bench_cudagraph(kernel_call, quantiles=quantiles)
+        return triton.testing.do_bench_cudagraph(kernel_call, quantiles=quantiles, clear_cache=clear_cache)
+    assert not clear_cache, "clear_cache arg is only meaningful with do_bench_cudagraph"
     return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1)
 
 
-@pytest.mark.parametrize('use_cuda_graph', [False, True])
-def test_kwargs(use_cuda_graph: bool, device: str):
+@pytest.mark.parametrize('use_cuda_graph,clear_cache', [(False, False), (True, False), (True, True)])
+def test_kwargs(use_cuda_graph: bool, clear_cache: bool, device: str):
     if use_cuda_graph and not torch.cuda.is_available():
         pytest.xfail("CUDA is not available")
 
@@ -26,8 +27,11 @@ def test_kwargs(use_cuda_graph: bool, device: str):
 
     configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})]
 
-    @triton.autotune(configs=configs, key=["M"],
-                     do_bench=lambda kernel, quantiles: do_bench(kernel, quantiles, use_cuda_graph))
+    @triton.autotune(
+        configs=configs,
+        key=["M"],
+        do_bench=lambda kernel, quantiles: do_bench(kernel, quantiles, use_cuda_graph, clear_cache),
+    )
     @triton.jit
     def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr):
         offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M)
diff --git a/python/triton/testing.py b/python/triton/testing.py
index 6450e3ab7308..dc3a0e5abc96 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -57,7 +57,14 @@ def _summarize_statistics(times, quantiles, return_mode):
         return statistics.median(times)
 
 
-def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
+def do_bench_cudagraph(
+    fn,
+    rep=20,
+    grad_to_none=None,
+    quantiles=None,
+    return_mode="mean",
+    clear_cache=False,
+):
     """
     Benchmark the runtime of the provided function.
 
@@ -69,12 +76,22 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
     :type grad_to_none: torch.tensor, optional
     :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
     :type return_mode: str
+    :param clear_cache: If True, zero the benchmarking cache tensor before each
+        invocation of `fn` to mimic the behaviour of `do_bench` with explicit L2
+        cache flushes. Default is False.
     """
     import torch
     assert return_mode in ["min", "max", "mean", "median", "all"]
 
+    cache = (runtime.driver.active.get_empty_cache_for_benchmark() if clear_cache else None)
+
+    def maybe_clear_cache():
+        if cache is not None:
+            cache.zero_()
+
     with torch.cuda.stream(torch.cuda.Stream()):
         # warmup
+        maybe_clear_cache()
         fn()
         if grad_to_none is not None:
             for x in grad_to_none:
@@ -91,6 +108,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
         end_event = torch.cuda.Event(enable_timing=True)
         start_event.record()
         for _ in range(5):
+            maybe_clear_cache()
             fn()
         end_event.record()
         torch.cuda.synchronize()
@@ -108,19 +126,51 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
                 if grad_to_none is not None:
                     for x in grad_to_none:
                         x.grad = None
+                maybe_clear_cache()
                 fn()
         torch.cuda.synchronize()
+
+        # step 3 - if cache clearing is enabled, create a separate graph to measure cache clearing overhead
+        cache_clear_graph = None
+        if clear_cache:
+            cache_clear_graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(cache_clear_graph):
+                for _ in range(n_repeat):
+                    maybe_clear_cache()
+            torch.cuda.synchronize()
+
         # measure time and return
-        ret = []
         n_retries = 10
+        cache_clear_times = []
+        total_times = []
         for _ in range(n_retries):
+            # measure cache clear time if applicable
+            if cache_clear_graph is not None:
+                cache_clear_start_event = torch.cuda.Event(enable_timing=True)
+                cache_clear_end_event = torch.cuda.Event(enable_timing=True)
+                cache_clear_start_event.record()
+                cache_clear_graph.replay()
+                cache_clear_end_event.record()
+                torch.cuda.synchronize()
+                cache_clear_times.append(cache_clear_start_event.elapsed_time(cache_clear_end_event) / n_repeat)
+
+            # measure total time
             start_event = torch.cuda.Event(enable_timing=True)
             end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
             g.replay()
             end_event.record()
             torch.cuda.synchronize()
-            ret += [start_event.elapsed_time(end_event) / n_repeat]
+            total_times.append(start_event.elapsed_time(end_event) / n_repeat)
+
+        # subtract cache clear overhead if applicable
+        if clear_cache:
+            ret = [
+                total_time - cache_clear_time for total_time, cache_clear_time in zip(total_times, cache_clear_times)
+            ]
+        else:
+            ret = total_times
+
         return _summarize_statistics(ret, quantiles, return_mode)
 
 