From 03ebfc2b69fade42682c5c37461f2c1a1c665f61 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Fri, 7 Nov 2025 12:38:16 -0800 Subject: [PATCH] Fix rotary embedding benchmark script Signed-off-by: Xin Yang --- benchmarks/kernels/benchmark_rope.py | 154 +++++++++++---------------- 1 file changed, 64 insertions(+), 90 deletions(-) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 29ef6409bb16..074b7a440b61 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -1,97 +1,76 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from itertools import accumulate +import itertools -import nvtx import torch -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope -from vllm.platforms import current_platform +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.triton_utils import triton from vllm.utils.argparse_utils import FlexibleArgumentParser +batch_size_range = [2**i for i in range(0, 8, 2)] +seq_len_range = [2**i for i in range(6, 10, 1)] +num_heads_range = [32, 48] +configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range)) -def benchmark_rope_kernels_multi_lora( - is_neox_style: bool, - batch_size: int, - seq_len: int, - num_heads: int, - head_size: int, - rotary_dim: int | None, - dtype: torch.dtype, - seed: int, - device: str, - max_position: int = 8192, - base: float = 10000, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - if rotary_dim is None: - rotary_dim = head_size - # silulating serving 4 LoRAs - scaling_factors = [1, 2, 4, 8] - # batched RoPE can take multiple scaling factors - batched_rope = get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - {"rope_type": "linear", "factor": tuple(scaling_factors)}, + +def get_benchmark(head_size, rotary_dim, is_neox_style, device): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "num_heads"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["torch", "flashinfer", "vllm"], + line_names=["PyTorch", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}", + args={}, + ) ) - # non-batched RoPE takes only one scaling factor, we create multiple - # instances to simulate the same behavior - non_batched_ropes: list[RotaryEmbedding] = [] - for scaling_factor in scaling_factors: - non_batched_ropes.append( - get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - {"rope_type": "linear", "factor": (scaling_factor,)}, - ) + def benchmark(batch_size, seq_len, num_heads, provider): + dtype = torch.bfloat16 + max_position = 8192 + base = 10000 + rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) + rope = rope.to(dtype=dtype, device=device) + cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device) + + positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) + query = torch.randn( + (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device ) + key = torch.randn_like(query) - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype) - key = torch.randn_like(query) + quantiles = [0.5, 0.2, 0.8] - # create query offsets for batched RoPE, we concat multiple kv cache - # together and each query needs to find the right kv cache of its type - offset_map = torch.tensor( - list( - accumulate( - [0] - + [ - max_position * scaling_factor * 2 - for scaling_factor in scaling_factors[:-1] - ] + if provider == "torch": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rope.forward_native(positions, query.clone(), key.clone()), + quantiles=quantiles, ) - ) - ) - query_types = torch.randint( - 0, len(scaling_factors), (batch_size, seq_len), device=device - ) - # map query types to offsets - query_offsets = offset_map[query_types] - # the kernel takes flattened offsets - flatten_offsets = query_offsets.flatten() + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch.ops.vllm.flashinfer_rotary_embedding( + positions, + query.clone(), + key.clone(), + head_size, + cos_sin_cache, + is_neox_style, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rope.forward_cuda(positions, query.clone(), key.clone()), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms - # batched queries of the same type together for non-batched RoPE - queries = [query[query_types == i] for i in range(len(scaling_factors))] - keys = [key[query_types == i] for i in range(len(scaling_factors))] - packed_qkr = zip(queries, keys, non_batched_ropes) - # synchronize before start timing - torch.cuda.synchronize() - with nvtx.annotate("non-batched", color="yellow"): - for q, k, r in packed_qkr: - r.forward(positions, q, k) - torch.cuda.synchronize() - with nvtx.annotate("batched", color="green"): - batched_rope.forward(positions, query, key, flatten_offsets) - torch.cuda.synchronize() + return benchmark if __name__ == "__main__": @@ -116,17 +95,12 @@ def benchmark_rope_kernels_multi_lora( parser.add_argument( "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0" ) + parser.add_argument("--save-path", type=str, default="./configs/rope/") args = parser.parse_args() - print(args) - benchmark_rope_kernels_multi_lora( - is_neox_style=args.is_neox_style, - batch_size=args.batch_size, - seq_len=args.seq_len, - num_heads=args.num_heads, - head_size=args.head_size, - rotary_dim=args.rotary_dim, - dtype=getattr(torch, args.dtype), - seed=args.seed, - device=args.device, + # Get the benchmark function + benchmark = get_benchmark( + args.head_size, args.rotary_dim, args.is_neox_style, args.device ) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path)