**python/test/unit/runtime/test_autotuner.py** (16 changes: 10 additions & 6 deletions)

```diff
@@ -9,14 +9,15 @@
 from triton._internal_testing import is_cuda


-def do_bench(kernel_call, quantiles, use_cuda_graph=False):
+def do_bench(kernel_call, quantiles, use_cuda_graph=False, clear_cache=False):
     if use_cuda_graph:
-        return triton.testing.do_bench_cudagraph(kernel_call, quantiles=quantiles)
+        return triton.testing.do_bench_cudagraph(kernel_call, quantiles=quantiles, clear_cache=clear_cache)
+    assert not clear_cache, "clear_cache arg is only meaningful with do_bench_cudagraph"
     return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1)


-@pytest.mark.parametrize('use_cuda_graph', [False, True])
-def test_kwargs(use_cuda_graph: bool, device: str):
+@pytest.mark.parametrize('use_cuda_graph,clear_cache', [(False, False), (True, False), (True, True)])
+def test_kwargs(use_cuda_graph: bool, clear_cache: bool, device: str):
     if use_cuda_graph and not torch.cuda.is_available():
         pytest.xfail("CUDA is not available")

@@ -26,8 +27,11 @@ def test_kwargs(use_cuda_graph: bool, device: str):

     configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})]

-    @triton.autotune(configs=configs, key=["M"],
-                     do_bench=lambda kernel, quantiles: do_bench(kernel, quantiles, use_cuda_graph))
+    @triton.autotune(
+        configs=configs,
+        key=["M"],
+        do_bench=lambda kernel, quantiles: do_bench(kernel, quantiles, use_cuda_graph, clear_cache),
+    )
     @triton.jit
     def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr):
         offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M)
```
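The test exercises the new flag only through `triton.autotune`'s `do_bench` hook. For reference, here is a minimal sketch of how end-user code might wire `clear_cache` through that hook. The kernel, its argument names, and the configs are hypothetical; only the `do_bench` hook and the `clear_cache` keyword come from this PR:

```python
import triton
import triton.language as tl
import triton.testing

# Hypothetical autotuned copy kernel; `_copy`, `n_elements`, and the configs
# are illustrative only. `clear_cache=True` is the keyword this PR adds to
# do_bench_cudagraph.
configs = [triton.Config(kwargs={'BLOCK_SIZE': 64}), triton.Config(kwargs={'BLOCK_SIZE': 256})]

@triton.autotune(
    configs=configs,
    key=["n_elements"],
    do_bench=lambda kernel, quantiles: triton.testing.do_bench_cudagraph(
        kernel, quantiles=quantiles, clear_cache=True),
)
@triton.jit
def _copy(dst, src, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    tl.store(dst + offsets, tl.load(src + offsets, mask=mask), mask=mask)
```

Autotuning (and hence the CUDA-graph benchmark) runs on the first launch of `_copy`, so each candidate config would be timed with the benchmarking cache zeroed before every invocation.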
**python/triton/testing.py** (21 changes: 20 additions & 1 deletion)

```diff
@@ -57,7 +57,14 @@ def _summarize_statistics(times, quantiles, return_mode):
         return statistics.median(times)


-def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
+def do_bench_cudagraph(
+    fn,
+    rep=20,
+    grad_to_none=None,
+    quantiles=None,
+    return_mode="mean",
+    clear_cache=False,
+):
     """
     Benchmark the runtime of the provided function.

@@ -69,12 +76,22 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
     :type grad_to_none: torch.tensor, optional
     :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
     :type return_mode: str
+    :param clear_cache: If True, zero the benchmarking cache tensor before each
+        invocation of `fn` to mimic the behaviour of `do_bench` with explicit L2
+        cache flushes. Default is False.
     """
     import torch
     assert return_mode in ["min", "max", "mean", "median", "all"]

+    cache = (runtime.driver.active.get_empty_cache_for_benchmark() if clear_cache else None)
+
+    def maybe_clear_cache():
+        if cache is not None:
+            cache.zero_()
+
     with torch.cuda.stream(torch.cuda.Stream()):
         # warmup
+        maybe_clear_cache()
         fn()
         if grad_to_none is not None:
             for x in grad_to_none:
@@ -91,6 +108,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
         end_event = torch.cuda.Event(enable_timing=True)
         start_event.record()
         for _ in range(5):
+            maybe_clear_cache()
             fn()
         end_event.record()
         torch.cuda.synchronize()
@@ -108,6 +126,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
                 if grad_to_none is not None:
                     for x in grad_to_none:
                         x.grad = None
+                maybe_clear_cache()
                 fn()
         torch.cuda.synchronize()
         # measure time and return
```

**Review comment** (Contributor), on the `fn()` call inside the graph capture:

> It's a bit weird to have the cache-flushing time included in the benchmark measurement. I suppose for autotuning purposes it should be fine, as all calls are affected the same way, but there should at least be a warning in the docstring.
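On the reviewer's point: the flush cost is a roughly constant per-iteration offset, so it cancels when configs are compared against each other but inflates the absolute numbers. A minimal sketch of how one might estimate that offset, assuming a CUDA device and reusing the same cache helper this PR calls:

```python
import torch
from triton import runtime

# Time cache.zero_() alone with CUDA events to estimate the constant
# per-iteration cost that clear_cache=True adds to each measurement.
cache = runtime.driver.active.get_empty_cache_for_benchmark()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
n = 100
for _ in range(n):
    cache.zero_()
end.record()
torch.cuda.synchronize()
print(f"cache.zero_() costs ~{start.elapsed_time(end) / n:.3f} ms per call")
```

Subtracting this figure from the reported times would give a closer estimate of the kernel cost alone, at the price of some extra noise.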