diff --git a/python/test/unit/runtime/test_autotuner.py b/python/test/unit/runtime/test_autotuner.py
index d9b972d6bfd2..90bb4f8790ec 100644
--- a/python/test/unit/runtime/test_autotuner.py
+++ b/python/test/unit/runtime/test_autotuner.py
@@ -9,14 +9,15 @@
 from triton._internal_testing import is_cuda
 
 
-def do_bench(kernel_call, quantiles, use_cuda_graph=False):
+def do_bench(kernel_call, quantiles, use_cuda_graph=False, clear_cache=False):
     if use_cuda_graph:
-        return triton.testing.do_bench_cudagraph(kernel_call, quantiles=quantiles)
+        return triton.testing.do_bench_cudagraph(kernel_call, quantiles=quantiles, clear_cache=clear_cache)
+    assert not clear_cache, "clear_cache arg is only meaningful with do_bench_cudagraph"
     return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1)
 
 
-@pytest.mark.parametrize('use_cuda_graph', [False, True])
-def test_kwargs(use_cuda_graph: bool, device: str):
+@pytest.mark.parametrize('use_cuda_graph,clear_cache', [(False, False), (True, False), (True, True)])
+def test_kwargs(use_cuda_graph: bool, clear_cache: bool, device: str):
     if use_cuda_graph and not torch.cuda.is_available():
         pytest.xfail("CUDA is not available")
 
@@ -26,8 +27,11 @@ def test_kwargs(use_cuda_graph: bool, device: str):
 
     configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})]
 
-    @triton.autotune(configs=configs, key=["M"],
-                     do_bench=lambda kernel, quantiles: do_bench(kernel, quantiles, use_cuda_graph))
+    @triton.autotune(
+        configs=configs,
+        key=["M"],
+        do_bench=lambda kernel, quantiles: do_bench(kernel, quantiles, use_cuda_graph, clear_cache),
+    )
     @triton.jit
     def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr):
         offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M)
diff --git a/python/triton/testing.py b/python/triton/testing.py
index 1fd88c648779..a009d8c70793 100644
--- a/python/triton/testing.py
+++ b/python/triton/testing.py
@@ -57,7 +57,14 @@ def _summarize_statistics(times, quantiles, return_mode):
         return statistics.median(times)
 
 
-def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
+def do_bench_cudagraph(
+    fn,
+    rep=20,
+    grad_to_none=None,
+    quantiles=None,
+    return_mode="mean",
+    clear_cache=False,
+):
     """
     Benchmark the runtime of the provided function.
 
@@ -69,12 +76,22 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
     :type grad_to_none: torch.tensor, optional
     :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
     :type return_mode: str
+    :param clear_cache: If True, zero the benchmarking cache tensor before each
+        invocation of `fn` to mimic the behaviour of `do_bench` with explicit L2
+        cache flushes. Default is False.
     """
     import torch
 
     assert return_mode in ["min", "max", "mean", "median", "all"]
+    cache = (runtime.driver.active.get_empty_cache_for_benchmark() if clear_cache else None)
+
+    def maybe_clear_cache():
+        if cache is not None:
+            cache.zero_()
+
     with torch.cuda.stream(torch.cuda.Stream()):
         # warmup
+        maybe_clear_cache()
         fn()
         if grad_to_none is not None:
             for x in grad_to_none:
@@ -91,6 +108,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
         end_event = torch.cuda.Event(enable_timing=True)
         start_event.record()
         for _ in range(5):
+            maybe_clear_cache()
             fn()
         end_event.record()
         torch.cuda.synchronize()
@@ -108,6 +126,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod
                 if grad_to_none is not None:
                     for x in grad_to_none:
                         x.grad = None
+                maybe_clear_cache()
                 fn()
         torch.cuda.synchronize()
         # measure time and return