diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 945276376d66..0f3320cbf96d 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -73,7 +73,7 @@ def test_without_spec_decoding( run_tests(monkeypatch, MODEL, test_configs, test_sampling_params) -def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): +def test_with_eagle3_spec_decoding(monkeypatch: pytest.MonkeyPatch): """Test consistency and acceptance rates with some different combos of preemption, executor, async scheduling, prefill chunking, spec decoding model length. @@ -111,6 +111,42 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params) +def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch): + """Test ngram_gpu speculative decoding with different configurations. + + This test specifically validates ngram_gpu behavior with various: + - Number of speculative tokens (2-6) + - Prompt lookup window sizes (min/max) + - Async scheduling enabled (as in production) + - Different executors and chunking settings + """ + + # Variant with larger speculation window + ngram_gpu_config = { + "method": "ngram_gpu", + "num_speculative_tokens": 3, + "prompt_lookup_max": 3, + "prompt_lookup_min": 2, + } + + # Test configurations covering various scenarios + # test_preemption, executor, async_scheduling, + # spec_config, test_prefill_chunking + test_configs = [ + (False, "mp", False, None, False), + (False, "mp", False, ngram_gpu_config, False), + (True, "mp", False, ngram_gpu_config, True), + (False, "mp", True, ngram_gpu_config, False), + (True, "mp", True, ngram_gpu_config, False), + (True, "uni", True, ngram_gpu_config, False), + (True, "mp", True, ngram_gpu_config, True), + ] + + # Use MODEL (Qwen) for ngram_gpu tests as it's lighter weight + # and ngram_gpu doesn't require a specific draft model + run_tests(monkeypatch, MODEL, test_configs, [{}]) + + @dynamo_config.patch(cache_size_limit=16) def run_tests( monkeypatch: pytest.MonkeyPatch, @@ -222,18 +258,19 @@ def run_test( else dict(gpu_memory_utilization=0.9) ) spec_mml = (spec_config or {}).get("max_model_len") + spec_method = (spec_config or {}).get("method", "none") test_config = ( f"executor={executor}, preemption={test_preemption}, " f"async_sched={async_scheduling}, " f"chunk_prefill={test_prefill_chunking}, " - f"spec_decoding={spec_decoding}, spec_mml={spec_mml}" + f"spec_decoding={spec_decoding}, spec_method={spec_method}, spec_mml={spec_mml}" ) print("-" * 80) print(f"---- TESTING {test_str}: {test_config}") print("-" * 80) with VllmRunner( model, - max_model_len=512, + max_model_len=4096, enable_chunked_prefill=test_prefill_chunking, # Force prefill chunking max_num_batched_tokens=48 if test_prefill_chunking else None, diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 1773913d0b6c..f8831d02a4df 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -649,6 +649,13 @@ def __call__( # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE. disable_cache = not is_compile_cache_enabled(self.inductor_config) + # TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors. 
+        is_ngram_gpu_enabled = (
+            vllm_config.speculative_config
+            and vllm_config.speculative_config.method == "ngram_gpu"
+        )
+        disable_cache = disable_cache or is_ngram_gpu_enabled
+
         if disable_cache:
             logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
         else:
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index 80d53a543f14..dc22a87eacf5 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -39,6 +39,7 @@
     "pangu_ultra_moe_mtp",
 ]
 EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
+NgramGPUTypes = Literal["ngram_gpu"]
 SpeculativeMethod = Literal[
     "ngram",
     "medusa",
@@ -46,6 +47,7 @@
     "draft_model",
     "suffix",
     EagleModelTypes,
+    NgramGPUTypes,
 ]
 
 
@@ -259,6 +261,8 @@ def __post_init__(self):
                 self.quantization = self.target_model_config.quantization
         elif self.method in ("ngram", "[ngram]"):
             self.model = "ngram"
+        elif self.method == "ngram_gpu":
+            self.model = "ngram_gpu"
         elif self.method == "suffix":
             self.model = "suffix"
         else:
@@ -273,9 +277,11 @@ def __post_init__(self):
         ):
             self.method = "ngram"
 
-        if self.method in ("ngram", "[ngram]"):
-            # Unified to "ngram" internally
-            self.method = "ngram"
+        if self.method in ("ngram", "[ngram]", "ngram_gpu"):
+            # "ngram" and "[ngram]" are unified to "ngram" internally;
+            # "ngram_gpu" keeps its own method name.
+            if self.method in ("ngram", "[ngram]"):
+                self.method = "ngram"
             # Set default values if not provided
             if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
                 # TODO(woosuk): Tune these values. They are arbitrarily chosen.
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 615b1f8489ef..1744134da690 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -21,7 +21,7 @@
 from pydantic.dataclasses import dataclass
 
 import vllm.envs as envs
-from vllm.config.speculative import EagleModelTypes
+from vllm.config.speculative import EagleModelTypes, NgramGPUTypes
 from vllm.logger import enable_trace_function_call, init_logger
 from vllm.transformers_utils.runai_utils import is_runai_obj_uri
 from vllm.utils import random_uuid
@@ -521,10 +521,12 @@ def __post_init__(self):
-        # Currently, async scheduling only support eagle speculative
-        # decoding.
+        # Currently, async scheduling only supports EAGLE/MTP and ngram_gpu
+        # speculative decoding.
         if self.speculative_config is not None:
-            if self.speculative_config.method not in get_args(EagleModelTypes):
+            if self.speculative_config.method not in get_args(
+                EagleModelTypes
+            ) and self.speculative_config.method not in get_args(NgramGPUTypes):
                 raise ValueError(
                     "Currently, async scheduling is only supported "
-                    "with EAGLE/MTP kind of speculative decoding"
+                    "with EAGLE/MTP or ngram_gpu speculative decoding"
                 )
             if self.speculative_config.disable_padded_drafter_batch:
                 raise ValueError(
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 52b98ef65459..441cdd4ac68b 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -1037,6 +1037,7 @@ def update_from_output(
         pooler_outputs = model_runner_output.pooler_output
         num_nans_in_logits = model_runner_output.num_nans_in_logits
         kv_connector_output = model_runner_output.kv_connector_output
+        is_empty_draft_tokens = model_runner_output.is_empty_draft_tokens
 
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: SpecDecodingStats | None = None
@@ -1079,6 +1080,9 @@ def update_from_output(
                 sampled_token_ids[req_index] if sampled_token_ids else []
             )
+            req_is_empty_draft_tokens = (
+                is_empty_draft_tokens[req_index] if is_empty_draft_tokens else False
+            )
 
             scheduled_spec_token_ids = (
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id)
             )
@@ -1110,6 +1114,9 @@ def update_from_output(
             status_before_stop = request.status
 
             # Check for stop and update request status.
+            if req_is_empty_draft_tokens:
+                request.spec_token_ids = []
+
             if new_token_ids:
                 new_token_ids, stopped = self._update_request_with_output(
                     request, new_token_ids
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index 8110deb5a610..6d036b5693ed 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -169,6 +169,9 @@ class ModelRunnerOutput:
     # req_id -> num_nans_in_logits
     num_nans_in_logits: dict[str, int] | None = None
+    # [num_reqs]
+    is_empty_draft_tokens: list[bool] | None = None
+
 
 
 # ModelRunnerOutput wrapper for async scheduling.
 class AsyncModelRunnerOutput(ABC):
diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py
new file mode 100644
index 000000000000..3b25240f4cdf
--- /dev/null
+++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py
@@ -0,0 +1,449 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+GPU-accelerated N-gram proposer using fully async PyTorch tensor operations.
+ +This version uses a fully vectorized approach with unfold and argmax for +finding the first match across all sequences in parallel. +""" + +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import ( + CompilationConfig, + VllmConfig, +) +from vllm.forward_context import set_forward_context +from vllm.v1.worker.gpu_input_batch import InputBatch + + +@support_torch_compile( + dynamic_arg_dims={ + "num_tokens_no_spec": 0, + "token_ids_gpu": 0, + "combined_mask": 0, + } +) +class NgramGPUKernel(nn.Module): + """ + GPU-accelerated N-gram proposer using fully async tensor operations. + + Interface: All inputs are GPU tensors (no lists, no numpy arrays) + + PERFORMANCE OPTIMIZATION WITH TORCH.COMPILE: + + 1. Tensor Allocation Strategy: + - DO: Allocate tensors inside forward() - torch.compile will optimize this + - DON'T: Pre-allocate buffers as class attributes - breaks compilation + - WHY: torch.compile fuses allocations into the compiled graph for efficiency + + 2. Dynamic Shapes: + - Batch size (dim 0) and sequence length (dim 1) are marked as dynamic + - torch.compile generates specialized kernels for different shapes + - The first call with a new shape will trigger recompilation (cached) + + 3. Graph Compilation: + - Uses fullgraph=True mode for maximum optimization + - All operations are tensor-based (no Python loops or conditionals) + - The entire forward pass is compiled into a single CUDA graph + + 4. Memory Efficiency: + - torch.compile's memory planning optimizes temporary allocations + - Fusion of operations reduces memory bandwidth requirements + - No manual memory management needed - compiler handles it + """ + + def __init__( + self, vllm_config: VllmConfig, prefix: str = "", device: torch.device = "cuda" + ): + super().__init__() + + assert vllm_config.speculative_config is not None + assert vllm_config.speculative_config.prompt_lookup_min is not None + assert vllm_config.speculative_config.prompt_lookup_max is not None + + self.min_n = vllm_config.speculative_config.prompt_lookup_min + self.max_n = vllm_config.speculative_config.prompt_lookup_max + self.k = vllm_config.speculative_config.num_speculative_tokens + self.max_model_len = vllm_config.model_config.max_model_len + self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.vocab_size = vllm_config.model_config.get_vocab_size() + self.device = device + + def _find_first_and_extract_all_n_parallel( + self, + data: torch.Tensor, + seq_lengths: torch.Tensor, + min_pattern_len: int, + max_pattern_len: int, + result_len: int, + ) -> torch.Tensor: + """ + Process all pattern lengths in parallel, selecting the longest match. + Completely free of data-dependent control flow, suitable for + torch.compile optimization. 
+ """ + batch_size = data.shape[0] + device = data.device + max_seq_len = data.shape[1] + num_patterns = max_pattern_len - min_pattern_len + 1 + + all_windows = data.unfold(1, max_pattern_len, 1) # [B, num_windows, max_n] + num_windows = all_windows.shape[1] + window_starts = torch.arange(num_windows, device=device) + pattern_lengths = torch.arange( + min_pattern_len, max_pattern_len + 1, device=device + ) + batch_indices = torch.arange(batch_size, device=device) + + all_first_matches = torch.full( + (batch_size, num_patterns), -1, dtype=torch.long, device=device + ) + + for i, pattern_len in enumerate(range(min_pattern_len, max_pattern_len + 1)): + offset = max_pattern_len - pattern_len + + # Extract pattern from the end of each sequence + pattern_starts = seq_lengths - pattern_len + pattern_indices = pattern_starts.unsqueeze(1) + torch.arange( + pattern_len, device=device + ) + patterns = torch.gather(data, 1, pattern_indices.clamp(min=0)) + + # Slice windows and perform matching + current_windows = all_windows[..., offset:] + matches = (current_windows == patterns.unsqueeze(1)).all(dim=-1) + + # Validity check: ensure enough space for result extraction + max_valid_pattern_start = seq_lengths - pattern_len - result_len + pattern_start_positions = window_starts + offset + valid_mask = pattern_start_positions <= max_valid_pattern_start.unsqueeze(1) + final_matches = matches & valid_mask + + # Handle prefix positions that fall before the available windows + prefix_positions = torch.arange(offset, device=device) + gather_indices = prefix_positions.view(1, -1, 1) + torch.arange( + pattern_len, device=device + ).view(1, 1, -1) + gather_indices = gather_indices.clamp(min=0, max=max_seq_len - 1) + expanded_indices = gather_indices.expand(batch_size, -1, -1) + prefix_tokens = torch.gather( + data.unsqueeze(1).expand(-1, offset, -1), + 2, + expanded_indices, + ) + prefix_matches = ( + prefix_tokens == patterns.unsqueeze(1).expand(-1, offset, -1) + ).all(dim=-1) + prefix_valid_mask = prefix_positions <= max_valid_pattern_start.unsqueeze(1) + prefix_final_matches = prefix_matches & prefix_valid_mask + + combined_matches = torch.cat([prefix_final_matches, final_matches], dim=1) + start_positions = torch.cat( + [prefix_positions, pattern_start_positions], dim=0 + ) + + # Find first match + # (if no match, argmax returns 0, but we verify with has_match) + first_indices = torch.argmax(combined_matches.int(), dim=1) + has_match = combined_matches[batch_indices, first_indices] + match_positions = start_positions[first_indices] + + # Store valid match positions + all_first_matches[:, i] = torch.where(has_match, match_positions, -1) + + # Select the longest valid match, + # from back to front, prioritizing longer patterns + best_pattern_idx = (all_first_matches >= 0).int().flip(dims=[1]).argmax(dim=1) + best_pattern_idx = num_patterns - 1 - best_pattern_idx # Flip back + + # Extract corresponding results + best_match_pos = all_first_matches[batch_indices, best_pattern_idx] + + # Handle matched cases - completely avoid data-dependent branching + has_any_match = best_match_pos >= 0 + + best_pattern_lengths = pattern_lengths[best_pattern_idx] + + # Calculate result start positions, invalid positions will be + # clamped to valid range. 
We now track true start positions, so the + # result starts right after the matched n-gram + result_starts = torch.where( + has_any_match, + best_match_pos + best_pattern_lengths, + torch.zeros_like(best_match_pos), + ) + + # Create gather indices + result_indices = result_starts.unsqueeze(1) + torch.arange( + result_len, device=device + ) + # Ensure indices are within valid range + result_indices = result_indices.clamp(min=0, max=max_seq_len - 1) + + # Always execute gather (even for invalid data) + extracted_sequences = torch.gather(data, 1, result_indices) + + # Use where to zero out invalid results + results = torch.where( + has_any_match.unsqueeze(1), + extracted_sequences, + torch.full_like(extracted_sequences, 0), + ) + + return results + + def forward( + self, + num_tokens_no_spec: torch.Tensor, + token_ids_gpu: torch.Tensor, + combined_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for N-gram proposal using GPU tensor operations. + + This is the core computation method that will be compiled by torch.compile + via the @support_torch_compile decorator. + + Args: + num_tokens_no_spec: Number of tokens for each sequence [batch_size] + token_ids_gpu: Token IDs [batch_size, max_len] + combined_mask: Whether each sequence is valid for spec decode [batch_size] + + Returns: + draft_tokens: [batch_size, k] on GPU + is_empty_draft_tokens: [batch_size] bool on GPU + """ + + device = token_ids_gpu.device + + # Infer batch_size from the input tensor shape to maintain dynamic shape + actual_batch_size = token_ids_gpu.shape[0] + + # Initialize output tensor - torch.compile will optimize this allocation + # NOTE(patchy): Do NOT pre-allocate this as a buffer + # it would break torch.compile + draft_tokens = torch.full( + (actual_batch_size, self.k), -1, dtype=torch.int32, device=device + ) + + results = self._find_first_and_extract_all_n_parallel( + token_ids_gpu, + num_tokens_no_spec, + min_pattern_len=self.min_n, + max_pattern_len=self.max_n, + result_len=self.k, + ) + + # Apply combined mask to results. Expand mask explicitly to avoid + # relying on broadcasting behavior that can confuse torch.compile. 
+ mask = combined_mask.unsqueeze(1).expand(-1, self.k) + draft_tokens = torch.where(mask, results, draft_tokens) + + is_empty_draft_tokens = (draft_tokens == 0).all(dim=1) + + return draft_tokens, is_empty_draft_tokens + + def load_model(self, *args, **kwargs): + """No model to load for N-gram proposer.""" + pass + + +class NgramProposerGPU: + def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): + assert vllm_config.speculative_config is not None + assert vllm_config.speculative_config.prompt_lookup_min is not None + assert vllm_config.speculative_config.prompt_lookup_max is not None + + compilation_config = CompilationConfig( + level=3, + custom_ops=["none"], + splitting_ops=[], + compile_sizes=[], + inductor_compile_config={ + "enable_auto_functionalized_v2": False, + "max_autotune": True, + "aggressive_fusion": True, + "triton.autotune_pointwise": True, + "coordinate_descent_tuning": True, + "use_mixed_mm": False, + }, + use_cudagraph=False, + ) + + self.vllm_config = VllmConfig(compilation_config=compilation_config) + + self.min_n = vllm_config.speculative_config.prompt_lookup_min + self.max_n = vllm_config.speculative_config.prompt_lookup_max + self.k = vllm_config.speculative_config.num_speculative_tokens + self.max_model_len = vllm_config.model_config.max_model_len + self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.vocab_size = vllm_config.model_config.get_vocab_size() + self.device = device + + self.kernel = NgramGPUKernel( + vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device + ) + self.device = device + self.kernel.to(device) + self.kernel.eval() + + self._dummy_run() + + def _dummy_run(self): + token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data( + batch_size=self.max_num_seqs, + max_seq_len=self.max_model_len, + vocab_size=self.vocab_size, + pattern_len=self.k, + repetition_rate=0.5, + device=self.device, + ) + + combined_mask = ( + sampled_flags + & valid_mask + & (num_tokens < self.max_model_len) + & (num_tokens >= self.min_n) + ) + + for _ in range(3): + with set_forward_context(None, self.vllm_config): + _, _ = self.kernel(num_tokens, token_ids, combined_mask) + + def _generate_dummy_data( + self, + batch_size: int, + max_seq_len: int, + vocab_size: int = 152064, + pattern_len: int = 3, + repetition_rate: float = 0.5, + device: str = "cuda", + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Generate random test data with n-gram repetitions. 
+ + Args: + batch_size: Number of sequences in the batch + max_seq_len: Maximum sequence length + vocab_size: Vocabulary size for random token generation + pattern_len: Length of patterns to inject for matching + repetition_rate: Rate of n-gram repetitions to inject + device: Device to place tensors on + + Returns: + token_ids: [batch_size, max_seq_len] tensor + num_tokens: [batch_size] tensor + sampled_flags: [batch_size] bool tensor + valid_mask: [batch_size] bool tensor + """ + # Generate random token IDs + token_ids = torch.zeros( + batch_size, + max_seq_len, + dtype=torch.int32, + device=device, + ) + + # Generate random sequence lengths + num_tokens = torch.randint( + pattern_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device + ) + + # All sequences have sampled tokens and are valid + sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device) + valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device) + + return token_ids, num_tokens, sampled_flags, valid_mask + + def propose( + self, + num_tokens_no_spec: torch.Tensor, # [batch_size] on GPU + token_ids_gpu: torch.Tensor, # [batch_size, max_len] on GPU + sampled_flags: torch.Tensor, # [batch_size] bool on GPU + valid_mask: torch.Tensor, # [batch_size] bool on GPU + ) -> tuple[torch.Tensor, torch.Tensor]: + assert token_ids_gpu.device == self.device + assert num_tokens_no_spec.device == self.device + assert sampled_flags.device == self.device + assert valid_mask.device == self.device + + with set_forward_context(None, self.vllm_config): + combined_mask = ( + sampled_flags & valid_mask & (num_tokens_no_spec >= self.min_n) + ) + + draft_tokens, is_empty_draft_tokens = self.kernel( + num_tokens_no_spec, + token_ids_gpu, + combined_mask, + ) + + return draft_tokens, is_empty_draft_tokens + + def update_token_ids_ngram( + self, + sampled_token_ids: torch.Tensor, + gpu_input_batch: InputBatch, + token_ids_gpu: torch.Tensor, + num_tokens_no_spec: torch.Tensor, + discard_request_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids and the number of valid sampled tokens + for each request, considering the "discarded" requests whose next token + is not sampled and comes from `request.get_token_id()` instead. + It also accounts for the rejected tokens in `sampled_token_ids`. + This function must use device functions to operate on the inputs, and + should not introduce any blocking CPU-GPU synchronization. + """ + num_reqs = gpu_input_batch.num_reqs + + # Extract backup_next_token_ids from token_ids_gpu using vectorized gather + # For each request i, get token_ids_gpu[i, num_tokens_no_spec[i] - 1] + # This is the last valid token before speculative tokens + backup_indices = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long() + backup_next_token_ids = torch.gather( + token_ids_gpu[:num_reqs], dim=1, index=backup_indices.unsqueeze(1) + ).squeeze(1) + + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + # Use discard_request_mask to invalidate sampled tokens for discarded + # requests (e.g., chunked prefill partial requests that should not be + # sampled). Expand mask to match [num_reqs, num_tokens] shape. + # Use masked_fill_ to avoid creating new tensors (no CPU-GPU sync). 
+ discard_mask_expanded = discard_request_mask[:num_reqs].unsqueeze(1) + valid_sampled_token_ids_gpu.masked_fill_(discard_mask_expanded, -1) + + # Generate a mask for all valid tokens within those requests + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size + ) + + # Count the number of valid tokens in each request + valid_sampled_tokens_count = valid_mask.sum(dim=1) + + # Get the rightmost valid index per row + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1) + ).squeeze(1) + + # Use last token if valid, vectorized backup from token_ids_gpu if not + next_token_ids = torch.where( + last_valid_indices != -1, + selected_tokens, + backup_next_token_ids, + ) + + return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu + + def load_model(self, *args, **kwargs): + self.kernel.load_model(*args, **kwargs) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 516c76a5e4b1..be7e51f31414 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -123,7 +123,13 @@ def __init__( # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size) self.req_prompt_embeds: dict[int, torch.Tensor] = {} self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) - self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) + self.num_tokens_no_spec_cpu_tensor = torch.zeros( + (max_num_reqs,), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy() self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu_tensor = torch.zeros( (max_num_reqs,), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ee28f477a26a..5f87306aa963 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -141,6 +141,7 @@ from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.ngram_proposer_gpu import NgramProposerGPU from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext @@ -186,7 +187,9 @@ def __init__( logprobs_tensors: LogprobsTensors | None, invalid_req_indices: list[int], async_output_copy_stream: torch.cuda.Stream, + async_draft_copy_stream: torch.cuda.Stream | None, vocab_size: int, + is_empty_draft_tokens: torch.Tensor | None, ): self._model_runner_output = model_runner_output self._invalid_req_indices = invalid_req_indices @@ -199,9 +202,12 @@ def __init__( self._sampled_token_ids = sampled_token_ids self.vocab_size = vocab_size self._logprobs_tensors = logprobs_tensors + self._is_empty_draft_tokens = is_empty_draft_tokens # Initiate the copy on a separate stream, but do not synchronize it. 
default_stream = torch.cuda.current_stream() + + # Stream 1: Copy sampled_token_ids and logprobs with torch.cuda.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) self.sampled_token_ids_cpu = self._sampled_token_ids.to( @@ -214,6 +220,22 @@ def __init__( ) self.async_copy_ready_event.record() + # Stream 2: Copy is_empty_draft_tokens in parallel + self.async_draft_copy_ready_event: torch.Event | None = None + if ( + self._is_empty_draft_tokens is not None + and async_draft_copy_stream is not None + ): + self.async_draft_copy_ready_event = torch.Event() + with torch.cuda.stream(async_draft_copy_stream): + async_draft_copy_stream.wait_stream(default_stream) + self._is_empty_draft_tokens_cpu = self._is_empty_draft_tokens.to( + "cpu", non_blocking=True + ) + self.async_draft_copy_ready_event.record() + else: + self._is_empty_draft_tokens_cpu = None + def get_output(self) -> ModelRunnerOutput: """Copy the device tensors to the host and return a ModelRunnerOutput. @@ -225,6 +247,7 @@ def get_output(self) -> ModelRunnerOutput: # Release the device tensors once the copy has completed. del self._logprobs_tensors del self._sampled_token_ids + max_gen_len = self.sampled_token_ids_cpu.shape[-1] if max_gen_len == 1: valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() for i in self._invalid_req_indices: @@ -242,6 +265,11 @@ def get_output(self) -> ModelRunnerOutput: output.sampled_token_ids = valid_sampled_token_ids if self._logprobs_tensors_cpu: output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens) + if self.async_draft_copy_ready_event is not None: + self.async_draft_copy_ready_event.synchronize() + del self._is_empty_draft_tokens + output.is_empty_draft_tokens = self._is_empty_draft_tokens_cpu.tolist() + del self._is_empty_draft_tokens_cpu return output @@ -373,10 +401,25 @@ def __init__( # layers in the draft model. if self.speculative_config and get_pp_group().is_last_rank: self.drafter: ( - NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer + NgramProposer + | NgramProposerGPU + | SuffixDecodingProposer + | EagleProposer + | MedusaProposer ) if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.method == "ngram_gpu": + self.drafter = NgramProposerGPU(self.vllm_config, self.device, self) + self.num_tokens_no_spec_gpu = torch.zeros( + self.max_num_reqs, dtype=torch.int32, device=device + ) + self.token_ids_gpu_tensor = torch.zeros( + self.max_num_reqs, + self.max_model_len, + dtype=torch.int32, + device=device, + ) elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): @@ -450,11 +493,15 @@ def __init__( # Separate cuda stream for overlapping transfer of sampled token ids from # GPU to CPU when async scheduling is enabled. self.async_output_copy_stream: torch.cuda.Stream | None = None + # Separate cuda stream for overlapping transfer of draft tokens info + # from GPU to CPU, runs in parallel with async_output_copy_stream. + self.async_draft_copy_stream: torch.cuda.Stream | None = None # cuda event to synchronize use of reused CPU tensors between steps # when async scheduling is enabled. 
self.prepare_inputs_event: torch.Event | None = None if self.use_async_scheduling: self.async_output_copy_stream = torch.cuda.Stream() + self.async_draft_copy_stream = torch.cuda.Stream() self.prepare_inputs_event = torch.Event() # self.cudagraph_batch_sizes sorts in ascending order. @@ -572,6 +619,8 @@ def __init__( # Cached outputs. self._draft_token_ids: list[list[int]] | torch.Tensor | None = None + # For ngram_gpu: indicates which requests have empty/invalid draft tokens + self._is_empty_draft_tokens: torch.Tensor | None = None self.transfer_event = torch.Event() self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_num_reqs, 1), @@ -789,6 +838,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in unscheduled_req_ids: self.input_batch.remove_request(req_id) + # Check if ngram_gpu mode is enabled for incremental GPU tensor updates + is_ngram_gpu = ( + self.speculative_config is not None + and self.speculative_config.method == "ngram_gpu" + ) + # Collect new/resumed requests that need full GPU tensor copy + if is_ngram_gpu: + ngram_gpu_new_reqs: list[CachedRequestState] = [] + reqs_to_add: list[CachedRequestState] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: @@ -845,6 +903,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self._init_xdrope_positions(req_state) reqs_to_add.append(req_state) + # Track new requests for ngram_gpu full tensor copy + if is_ngram_gpu: + ngram_gpu_new_reqs.append(req_state) # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank @@ -941,6 +1002,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] reqs_to_add.append(req_state) + # Track resumed requests for ngram_gpu full tensor copy + if is_ngram_gpu: + ngram_gpu_new_reqs.append(req_state) continue # Update the persistent batch. @@ -985,9 +1049,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.spec_token_ids[req_index].clear() self.input_batch.spec_token_ids[req_index].extend(spec_token_ids) - # there are no draft tokens with async scheduling, - # we clear the spec_decoding info in scheduler_output and - # use normal sampling but rejection_sampling. + # For async scheduling with speculative decoding: + # When draft tokens are invalid (e.g., ngram proposer returns all + # zeros), we skip the forward computation for those tokens to save + # resources. However, we must keep scheduled_spec_decode_tokens so + # that the Scheduler can correctly adjust num_computed_tokens and + # num_output_placeholders in update_from_output(). if self.use_async_scheduling: req_state.prev_num_draft_len = num_spec_tokens if num_spec_tokens and self._draft_token_ids is None: @@ -1006,6 +1073,108 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + # Incrementally update ngram_gpu tensors after batch is stable + if is_ngram_gpu: + with record_function_or_nullcontext( + "gpu_model_runner: update_ngram_gpu_tensors_incremental" + ): + self._update_ngram_gpu_tensors_incremental(ngram_gpu_new_reqs) + + def _update_ngram_gpu_tensors_incremental( + self, + new_reqs: list[CachedRequestState], + ) -> None: + """Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu + for ngram GPU proposer. 
+ + This method handles three cases: + 1. First run (no prev_req_id_to_index): full initialization + 2. Index changes due to condense/reorder: GPU scatter to new positions + 3. New/resumed requests: full copy of prompt + output tokens + + Args: + new_reqs: List of new or resumed requests that need full tensor copy. + """ + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + curr_req_id_to_index = self.input_batch.req_id_to_index + + if not curr_req_id_to_index: + return + + # Build set of new request IDs for fast lookup + new_req_ids = {req.req_id for req in new_reqs} + + # Case 1: First run, no previous state + if prev_req_id_to_index is None: + self._ngram_gpu_full_init() + return + + # Case 2: Detect index changes, collect requests needing reorder + reorder_src: list[int] = [] + reorder_dst: list[int] = [] + + for req_id, curr_idx in curr_req_id_to_index.items(): + # Skip new requests (handled separately later) + if req_id in new_req_ids: + continue + + prev_idx = prev_req_id_to_index.get(req_id) + if prev_idx is not None and prev_idx != curr_idx: + reorder_src.append(prev_idx) + reorder_dst.append(curr_idx) + + # GPU scatter reorder + if reorder_src: + src_tensor = torch.tensor(reorder_src, dtype=torch.long, device=self.device) + dst_tensor = torch.tensor(reorder_dst, dtype=torch.long, device=self.device) + + # Clone to avoid overwriting during scatter + temp_token_ids = self.token_ids_gpu_tensor[src_tensor].clone() + temp_num_tokens = self.num_tokens_no_spec_gpu[src_tensor].clone() + + self.token_ids_gpu_tensor[dst_tensor] = temp_token_ids + self.num_tokens_no_spec_gpu[dst_tensor] = temp_num_tokens + + # Case 3: Full copy for new/resumed requests + # Use data already stored in input_batch by add_request() + for req_state in new_reqs: + new_req_idx = curr_req_id_to_index.get(req_state.req_id) + if new_req_idx is None: + continue + + num_tokens = self.input_batch.num_tokens_no_spec[new_req_idx] + if num_tokens > 0: + # Copy from input_batch.token_ids_cpu to GPU + self.token_ids_gpu_tensor[new_req_idx, :num_tokens].copy_( + self.input_batch.token_ids_cpu_tensor[new_req_idx, :num_tokens], + non_blocking=True, + ) + self.num_tokens_no_spec_gpu[new_req_idx : new_req_idx + 1].copy_( + self.input_batch.num_tokens_no_spec_cpu_tensor[ + new_req_idx : new_req_idx + 1 + ], + non_blocking=True, + ) + + def _ngram_gpu_full_init(self) -> None: + """Initialize all GPU tensors for ngram proposer from scratch. + + Called on first run when there's no previous batch state. + Uses data already stored in input_batch. + """ + for idx in self.input_batch.req_id_to_index.values(): + num_tokens = self.input_batch.num_tokens_no_spec[idx] + if num_tokens > 0: + # Copy from input_batch.token_ids_cpu to GPU + self.token_ids_gpu_tensor[idx, :num_tokens].copy_( + self.input_batch.token_ids_cpu_tensor[idx, :num_tokens], + non_blocking=True, + ) + self.num_tokens_no_spec_gpu[idx : idx + 1].copy_( + self.input_batch.num_tokens_no_spec_cpu_tensor[idx : idx + 1], + non_blocking=True, + ) + def _update_states_after_model_execute( self, output_token_ids: torch.Tensor ) -> None: @@ -1244,6 +1413,7 @@ def _prepare_input_ids( # so convert draft_token_ids to torch.int32 here. 
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32) self._draft_token_ids = None + self._is_empty_draft_tokens = None self.input_ids.gpu.scatter_( dim=0, @@ -3140,6 +3310,11 @@ def propose_draft_token_ids(sampled_token_ids): and spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch ) + use_padded_batch_for_ngram = ( + self.speculative_config + and self.speculative_config.method == "ngram_gpu" + and not self.speculative_config.disable_padded_drafter_batch + ) effective_drafter_max_model_len = self.max_model_len if effective_drafter_max_model_len is None: effective_drafter_max_model_len = self.model_config.max_model_len @@ -3178,6 +3353,27 @@ def propose_draft_token_ids(sampled_token_ids): next_token_ids, valid_sampled_tokens_count ) + if use_padded_batch_for_ngram: + assert self.speculative_config is not None + assert isinstance(self.drafter, NgramProposerGPU) + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + assert spec_decode_common_attn_metadata is not None + next_token_ids, valid_sampled_tokens_count, _ = ( + self.drafter.update_token_ids_ngram( + sampled_token_ids, + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( num_nans_in_logits, @@ -3198,7 +3394,7 @@ def propose_draft_token_ids(sampled_token_ids): if ( self.speculative_config - and not use_padded_batch_for_eagle + and not (use_padded_batch_for_eagle or use_padded_batch_for_ngram) and input_fits_in_drafter ): # ngram and other speculative decoding methods use the sampled @@ -3227,13 +3423,17 @@ def propose_draft_token_ids(sampled_token_ids): with record_function_or_nullcontext( "gpu_model_runner: AsyncGPUModelRunnerOutput" ): + sampled_token_ids = sampler_output.sampled_token_ids + async_output = AsyncGPUModelRunnerOutput( model_runner_output=output, - sampled_token_ids=sampler_output.sampled_token_ids, + sampled_token_ids=sampled_token_ids, logprobs_tensors=sampler_output.logprobs_tensors, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, + async_draft_copy_stream=self.async_draft_copy_stream, vocab_size=self.input_batch.vocab_size, + is_empty_draft_tokens=self._is_empty_draft_tokens, ) with record_function_or_nullcontext( "gpu_model_runner: set_async_sampled_token_ids" @@ -3256,6 +3456,7 @@ def take_draft_token_ids(self) -> DraftTokenIds | None: else: draft_token_ids = self._draft_token_ids self._draft_token_ids = None + self._is_empty_draft_tokens = None return DraftTokenIds(req_ids, draft_token_ids) def _copy_valid_sampled_token_count( @@ -3304,15 +3505,94 @@ def propose_draft_token_ids( spec_config = self.speculative_config assert spec_config is not None if spec_config.method == "ngram": - assert isinstance(sampled_token_ids, list) - assert isinstance(self.drafter, NgramProposer) - draft_token_ids = self.drafter.propose( + # TODO:(patchy) NGram GPU proposal + if isinstance(self.drafter, NgramProposer): + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list whenngram is used." 
+ ) + draft_token_ids = self.drafter.propose( + sampled_token_ids, + self.input_batch.req_ids, + self.input_batch.num_tokens_no_spec, + self.input_batch.token_ids_cpu, + self.input_batch.spec_decode_unsupported_reqs, + ) + elif spec_config.method == "ngram_gpu": + # GPU-accelerated ngram proposer + assert isinstance(self.drafter, NgramProposerGPU) + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor for ngram_gpu" + ) + ( + next_token_ids, + valid_sampled_tokens_count, + valid_sampled_token_ids_gpu, + ) = self.drafter.update_token_ids_ngram( sampled_token_ids, - self.input_batch.req_ids, - self.input_batch.num_tokens_no_spec, - self.input_batch.token_ids_cpu, - self.input_batch.spec_decode_unsupported_reqs, + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + self.discard_request_mask.gpu, + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + batch_size = next_token_ids.shape[0] + max_new_tokens = valid_sampled_token_ids_gpu.shape[1] # num_spec_tokens + 1 + + current_lens = self.num_tokens_no_spec_gpu[:batch_size] + offsets = torch.arange(max_new_tokens, device=self.device) + + write_positions = current_lens.unsqueeze(1) + offsets.unsqueeze(0) + valid_write_mask = offsets.unsqueeze( + 0 + ) < valid_sampled_tokens_count.unsqueeze(1) + combined_mask = valid_write_mask & (valid_sampled_token_ids_gpu != -1) + + token_ids_slice = self.token_ids_gpu_tensor[:batch_size] + write_positions_long = write_positions.long() + existing_values = token_ids_slice.gather(1, write_positions_long) + + tokens_cast = valid_sampled_token_ids_gpu.to(token_ids_slice.dtype) + tokens_to_scatter = torch.where( + combined_mask, + tokens_cast, + existing_values, + ) + token_ids_slice.scatter_(1, write_positions_long, tokens_to_scatter) + + self.num_tokens_no_spec_gpu[:batch_size] += valid_sampled_tokens_count + + sampled_flags = valid_sampled_tokens_count > 0 + valid_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device) + + if self.input_batch.spec_decode_unsupported_reqs: + unsupported_ids = torch.tensor( + list(self.input_batch.spec_decode_unsupported_reqs), + dtype=torch.long, + device=self.device, + ) + + batch_req_ids = torch.tensor( + self.input_batch.req_ids[:batch_size], + dtype=torch.long, + device=self.device, + ) + + is_unsupported = ( + batch_req_ids.unsqueeze(1) == unsupported_ids.unsqueeze(0) + ).any(dim=1) + valid_mask = valid_mask & ~is_unsupported + + draft_token_ids, is_empty_draft_tokens = self.drafter.propose( + self.num_tokens_no_spec_gpu[:batch_size], + self.token_ids_gpu_tensor[:batch_size], + sampled_flags, + valid_mask, ) + # Cache is_empty_draft_tokens for filtering in _update_states + self._is_empty_draft_tokens = is_empty_draft_tokens elif spec_config.method == "suffix": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer)