From c92e4b88f506b81e83f2f91f1bd3f235a2feda88 Mon Sep 17 00:00:00 2001 From: Huamin Li <3ericli@gmail.com> Date: Thu, 20 Nov 2025 23:53:30 -0800 Subject: [PATCH 1/7] [CI Failure] Fix Gemma3 RoPE configuration for sliding attention layers (#29111) Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Cyrus Leung Signed-off-by: PatchouliTaisa --- vllm/model_executor/models/gemma3.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 565719ae7fae..4ad6fc89dcaf 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -166,10 +166,12 @@ def __init__( else: # Transformers v4 rope config. # Global attention. Use the values in config.json. - rope_parameters = config.rope_parameters.copy() + rope_parameters = config.rope_parameters # Local attention. Override the values in config.json. if self.is_sliding: - rope_parameters["rope_theta"] = config.rope_local_base_freq + rope_parameters = dict( + rope_type="default", rope_theta=config.rope_local_base_freq + ) self.rotary_emb = get_rope( self.head_dim, From d183dcb1ef318f024d606e43a05a76cee976fa54 Mon Sep 17 00:00:00 2001 From: PatchouliTaisa Date: Fri, 21 Nov 2025 19:20:18 +0800 Subject: [PATCH 2/7] fix typo error Signed-off-by: PatchouliTaisa --- tests/v1/e2e/test_async_scheduling.py | 43 +- vllm/config/speculative.py | 9 +- vllm/config/vllm.py | 8 +- vllm/v1/spec_decode/ngram_proposer_gpu.py | 503 ++++++++++++++++++++++ vllm/v1/worker/gpu_input_batch.py | 6 + vllm/v1/worker/gpu_model_runner.py | 275 +++++++++++- 6 files changed, 826 insertions(+), 18 deletions(-) create mode 100644 vllm/v1/spec_decode/ngram_proposer_gpu.py diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 00d93e1ba0b5..33e562a31096 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -73,7 +73,7 @@ def test_without_spec_decoding( run_tests(monkeypatch, MODEL, test_configs, test_sampling_params) -def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): +def test_with_eagle3_spec_decoding(monkeypatch: pytest.MonkeyPatch): """Test consistency and acceptance rates with some different combos of preemption, executor, async scheduling, prefill chunking, spec decoding model length. @@ -106,6 +106,42 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): run_tests(monkeypatch, MTP_MODEL, test_configs, [{}]) +def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch): + """Test ngram_gpu speculative decoding with different configurations. + + This test specifically validates ngram_gpu behavior with various: + - Number of speculative tokens (2-6) + - Prompt lookup window sizes (min/max) + - Async scheduling enabled (as in production) + - Different executors and chunking settings + """ + + # Variant with larger speculation window + ngram_gpu_config = { + "method": "ngram_gpu", + "num_speculative_tokens": 3, + "prompt_lookup_max": 3, + "prompt_lookup_min": 2, + } + + # Test configurations covering various scenarios + # test_preemption, executor, async_scheduling, + # spec_config, test_prefill_chunking + test_configs = [ + (False, "mp", False, None, False), + (False, "mp", False, ngram_gpu_config, False), + (True, "mp", False, ngram_gpu_config, True), + (False, "mp", True, ngram_gpu_config, False), + (True, "mp", True, ngram_gpu_config, False), + (True, "uni", True, ngram_gpu_config, False), + (True, "mp", True, ngram_gpu_config, True), + ] + + # Use MODEL (Qwen) for ngram_gpu tests as it's lighter weight + # and ngram_gpu doesn't require a specific draft model + run_tests(monkeypatch, MODEL, test_configs, [{}]) + + @dynamo_config.patch(cache_size_limit=16) def run_tests( monkeypatch: pytest.MonkeyPatch, @@ -217,18 +253,19 @@ def run_test( else dict(gpu_memory_utilization=0.9) ) spec_mml = (spec_config or {}).get("max_model_len") + spec_method = (spec_config or {}).get("method", "none") test_config = ( f"executor={executor}, preemption={test_preemption}, " f"async_sched={async_scheduling}, " f"chunk_prefill={test_prefill_chunking}, " - f"spec_decoding={spec_decoding}, spec_mml={spec_mml}" + f"spec_decoding={spec_decoding}, spec_method={spec_method}, spec_mml={spec_mml}" ) print("-" * 80) print(f"---- TESTING {test_str}: {test_config}") print("-" * 80) with VllmRunner( model, - max_model_len=512, + max_model_len=4096, enable_chunked_prefill=test_prefill_chunking, # Force prefill chunking max_num_batched_tokens=48 if test_prefill_chunking else None, diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index a0c65b6049e1..d924e9f1c991 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -40,6 +40,7 @@ "pangu_ultra_moe_mtp", ] EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] +NgramGPUTypes = Literal["ngram_gpu"] SpeculativeMethod = Literal[ "ngram", "medusa", @@ -47,6 +48,7 @@ "draft_model", "suffix", EagleModelTypes, + NgramGPUTypes, ] @@ -260,6 +262,8 @@ def __post_init__(self): self.quantization = self.target_model_config.quantization elif self.method in ("ngram", "[ngram]"): self.model = "ngram" + elif self.method == "ngram_gpu": + self.model = "ngram_gpu" elif self.method == "suffix": self.model = "suffix" else: @@ -274,9 +278,10 @@ def __post_init__(self): ): self.method = "ngram" - if self.method in ("ngram", "[ngram]"): + if self.method in ("ngram", "[ngram]", "ngram_gpu"): # Unified to "ngram" internally - self.method = "ngram" + if self.method in ("ngram", "[ngram]"): + self.method = "ngram" # Set default values if not provided if self.prompt_lookup_min is None and self.prompt_lookup_max is None: # TODO(woosuk): Tune these values. They are arbitrarily chosen. diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index d64e315b4fe3..3c883079a6b3 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -21,7 +21,7 @@ from pydantic.dataclasses import dataclass import vllm.envs as envs -from vllm.config.speculative import EagleModelTypes +from vllm.config.speculative import EagleModelTypes, NgramGPUTypes from vllm.logger import enable_trace_function_call, init_logger from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.utils import random_uuid @@ -378,10 +378,12 @@ def __post_init__(self): # Currently, async scheduling only support eagle speculative # decoding. if self.speculative_config is not None: - if self.speculative_config.method not in get_args(EagleModelTypes): + if self.speculative_config.method not in get_args( + EagleModelTypes + ) and self.speculative_config.method not in get_args(NgramGPUTypes): raise ValueError( "Currently, async scheduling is only supported " - "with EAGLE/MTP kind of speculative decoding" + "with EAGLE/MTP/NGram GPU kind of speculative decoding" ) if self.speculative_config.disable_padded_drafter_batch: raise ValueError( diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py new file mode 100644 index 000000000000..1dccc8c3e8cc --- /dev/null +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -0,0 +1,503 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +GPU-accelerated N-gram proposer using fully async PyTorch tensor operations. + +This version uses a fully vectorized approach with unfold and argmax for +finding the first match across all sequences in parallel. +""" + +import os + +import numpy as np +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, VllmConfig +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, +) +from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch + +if int(os.environ.get("NVPROF", "0")) == 1: + pass +else: + pass + +import logging + +from vllm.config import set_current_vllm_config +from vllm.forward_context import set_forward_context + +logger = logging.getLogger(__name__) + + +@support_torch_compile( + dynamic_arg_dims={ + "num_tokens_no_spec": 0, # batch dimension is dynamic + "token_ids_gpu": [0, 1], # both batch and sequence length are dynamic + "sampled_flags": 0, # batch dimension is dynamic + "valid_mask": 0, # batch dimension is dynamic + } +) +class NgramGPUKernel(nn.Module): + """ + GPU-accelerated N-gram proposer using fully async tensor operations. + + Interface: All inputs are GPU tensors (no lists, no numpy arrays) + + PERFORMANCE OPTIMIZATION WITH TORCH.COMPILE: + + 1. Tensor Allocation Strategy: + - DO: Allocate tensors inside forward() - torch.compile will optimize this + - DON'T: Pre-allocate buffers as class attributes - breaks compilation + - WHY: torch.compile fuses allocations into the compiled graph for efficiency + + 2. Dynamic Shapes: + - Batch size (dim 0) and sequence length (dim 1) are marked as dynamic + - torch.compile generates specialized kernels for different shapes + - The first call with a new shape will trigger recompilation (cached) + + 3. Graph Compilation: + - Uses fullgraph=True mode for maximum optimization + - All operations are tensor-based (no Python loops or conditionals) + - The entire forward pass is compiled into a single CUDA graph + + 4. Memory Efficiency: + - torch.compile's memory planning optimizes temporary allocations + - Fusion of operations reduces memory bandwidth requirements + - No manual memory management needed - compiler handles it + """ + + def __init__( + self, vllm_config: VllmConfig, prefix: str = "", device: torch.device = "cuda" + ): + super().__init__() + + assert vllm_config.speculative_config is not None + assert vllm_config.speculative_config.prompt_lookup_min is not None + assert vllm_config.speculative_config.prompt_lookup_max is not None + + self.min_n = vllm_config.speculative_config.prompt_lookup_min + self.max_n = vllm_config.speculative_config.prompt_lookup_max + self.k = vllm_config.speculative_config.num_speculative_tokens + self.max_model_len = vllm_config.model_config.max_model_len + self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.vocab_size = vllm_config.model_config.get_vocab_size() + self.device = device + + def _find_first_and_extract_all_n_parallel( + self, + data: torch.Tensor, + seq_lengths: torch.Tensor, + min_pattern_len: int, + max_pattern_len: int, + result_len: int, + ) -> torch.Tensor: + """ + Process all pattern lengths in parallel, selecting the longest match. + Completely free of data-dependent control flow, suitable for + torch.compile optimization. + """ + batch_size = data.shape[0] + device = data.device + max_seq_len = data.shape[1] + num_patterns = max_pattern_len - min_pattern_len + 1 + + # Create sliding windows once + all_windows = data.unfold(1, max_pattern_len, 1) # [B, num_windows, max_n] + num_windows = all_windows.shape[1] + window_starts = torch.arange(num_windows, device=device) + + # Store the first match position for each pattern length + all_first_matches = torch.full( + (batch_size, num_patterns), -1, dtype=torch.long, device=device + ) + + for i, pattern_len in enumerate(range(min_pattern_len, max_pattern_len + 1)): + offset = max_pattern_len - pattern_len + + # Extract pattern from the end of each sequence + pattern_starts = seq_lengths - pattern_len + pattern_indices = pattern_starts.unsqueeze(1) + torch.arange( + pattern_len, device=device + ) + patterns = torch.gather(data, 1, pattern_indices.clamp(min=0)) + + # Slice windows and perform matching + current_windows = all_windows[..., offset:] + matches = (current_windows == patterns.unsqueeze(1)).all(dim=-1) + + # Validity check: ensure enough space for result extraction + max_valid_start = seq_lengths - pattern_len - result_len + valid_mask = window_starts <= max_valid_start.unsqueeze(1) + final_matches = matches & valid_mask + + # Find first match + # (if no match, argmax returns 0, but we verify with has_match) + first_indices = torch.argmax(final_matches.int(), dim=1) + has_match = final_matches[torch.arange(batch_size), first_indices] + + # Store valid match positions + all_first_matches[:, i] = torch.where(has_match, first_indices, -1) + + # Select the longest valid match, + # from back to front, prioritizing longer patterns + best_pattern_idx = (all_first_matches >= 0).int().flip(dims=[1]).argmax(dim=1) + best_pattern_idx = num_patterns - 1 - best_pattern_idx # Flip back + best_pattern_len = min_pattern_len + best_pattern_idx + + # Extract corresponding results + batch_idx = torch.arange(batch_size, device=device) + best_match_pos = all_first_matches[batch_idx, best_pattern_idx] + + # Handle matched cases - completely avoid data-dependent branching + has_any_match = best_match_pos >= 0 + + # Calculate result start positions, invalid positions will be + # clamped to valid range. Since all windows have size max_pattern_len, + # and patterns are matched at the END of windows (due to offset), + # the result starts after the full window + result_starts = torch.where( + has_any_match, + best_match_pos + + max_pattern_len, # Use max_pattern_len, not best_pattern_len + torch.zeros_like(best_match_pos), # Use 0 for no match + ) + + # Create gather indices + result_indices = result_starts.unsqueeze(1) + torch.arange( + result_len, device=device + ) + # Ensure indices are within valid range + result_indices = result_indices.clamp(min=0, max=max_seq_len - 1) + + # Always execute gather (even for invalid data) + extracted_sequences = torch.gather(data, 1, result_indices) + + # Use where to zero out invalid results + results = torch.where( + has_any_match.unsqueeze(1), + extracted_sequences, + torch.zeros_like(extracted_sequences), + ) + + return results + + def forward( + self, + num_tokens_no_spec: torch.Tensor, # [batch_size] on GPU + token_ids_gpu: torch.Tensor, # [batch_size, max_len] on GPU + sampled_flags: torch.Tensor, # [batch_size] bool on GPU + valid_mask: torch.Tensor, # [batch_size] bool on GPU + ) -> torch.Tensor: + """ + Forward pass for N-gram proposal using GPU tensor operations. + + This is the core computation method that will be compiled by torch.compile + via the @support_torch_compile decorator. + + Args: + num_tokens_no_spec: Number of tokens for each sequence [batch_size] + token_ids_gpu: Token IDs [batch_size, max_len] + sampled_flags: Whether each sequence has sampled tokens [batch_size] + valid_mask: Whether each sequence is valid for spec decode [batch_size] + + Returns: + draft_tokens: [batch_size, k] on GPU + """ + assert token_ids_gpu.device == self.device + assert num_tokens_no_spec.device == self.device + assert sampled_flags.device == self.device + assert valid_mask.device == self.device + + # Compute combined mask for valid sequences + combined_mask = ( + sampled_flags + & valid_mask + & (num_tokens_no_spec < self.max_model_len) + & (num_tokens_no_spec >= self.min_n) + ) + + batch_size = token_ids_gpu.size(0) + device = token_ids_gpu.device + + # Initialize output tensor - torch.compile will optimize this allocation + # NOTE: Do NOT pre-allocate this as a buffer - it would break torch.compile + draft_tokens = torch.zeros( + (batch_size, self.k), dtype=torch.int32, device=device + ) + + # Use the async find and extract method with max_n pattern length + # This will find the first match and extract k tokens + results = self._find_first_and_extract_all_n_parallel( + token_ids_gpu, + num_tokens_no_spec, + min_pattern_len=self.min_n, + max_pattern_len=self.max_n, + result_len=self.k, + ) + + # Apply combined mask to results + draft_tokens = torch.where(combined_mask.unsqueeze(1), results, draft_tokens) + + return draft_tokens + + def load_model(self, *args, **kwargs): + """No model to load for N-gram proposer.""" + pass + + +class NgramProposerGPU: + def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): + assert vllm_config.speculative_config is not None + assert vllm_config.speculative_config.prompt_lookup_min is not None + assert vllm_config.speculative_config.prompt_lookup_max is not None + + # Create optimized compilation config for ngram kernel + compilation_config = CompilationConfig( + level=3, + custom_ops=["none"], + splitting_ops=[], + compile_sizes=[], + inductor_compile_config={ + "enable_auto_functionalized_v2": False, + }, + use_cudagraph=False, + cudagraph_mode=CUDAGraphMode.NONE, + mode=CompilationMode.VLLM_COMPILE, + ) + + self.vllm_config = VllmConfig(compilation_config=compilation_config) + + self.min_n = vllm_config.speculative_config.prompt_lookup_min + self.max_n = vllm_config.speculative_config.prompt_lookup_max + self.k = vllm_config.speculative_config.num_speculative_tokens + self.max_model_len = vllm_config.model_config.max_model_len + self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs + self.vocab_size = vllm_config.model_config.get_vocab_size() + self.device = device + + with set_current_vllm_config(self.vllm_config, check_compile=False): + self.kernel = NgramGPUKernel( + vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device + ) + self.device = device + self.kernel.to(device) + self.kernel.eval() + max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True, + ) + + self._dummy_run() + + def _dummy_run(self): + # with set_current_vllm_config(self.vllm_config, check_compile=False): + # Get warmup iterations from config or use default + token_ids, num_tokens, sampled_flags, valid_mask = ( + self._generate_dummy_data( + batch_size=self.max_num_seqs, + max_seq_len=min( + self.max_model_len, 1024 + ), # Use reasonable seq len for warmup + vocab_size=self.vocab_size, + pattern_len=self.k, + repetition_rate=0.5, + device=self.device, + ) + ) + + for _ in range(3): + with set_forward_context(None, self.vllm_config): + _ = self.kernel( + num_tokens, token_ids, sampled_flags, valid_mask + ) + + def _generate_dummy_data( + self, + batch_size: int, + max_seq_len: int, + vocab_size: int = 152064, + pattern_len: int = 3, + repetition_rate: float = 0.5, + device: str = "cuda", + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Generate random test data with n-gram repetitions. + + Args: + batch_size: Number of sequences in the batch + max_seq_len: Maximum sequence length + vocab_size: Vocabulary size for random token generation + pattern_len: Length of patterns to inject for matching + repetition_rate: Rate of n-gram repetitions to inject + device: Device to place tensors on + + Returns: + token_ids: [batch_size, max_seq_len] tensor + num_tokens: [batch_size] tensor + sampled_flags: [batch_size] bool tensor + valid_mask: [batch_size] bool tensor + """ + # Generate random token IDs + token_ids = torch.randint( + 0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device=device + ) + + # Generate random sequence lengths + min_len = max(pattern_len * 2 + 3, max_seq_len // 2) + num_tokens = torch.randint( + min_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device + ) + + # Inject n-gram repetitions using the tail pattern of each sequence + for i in range(batch_size): + seq_len = num_tokens[i].item() + if seq_len > pattern_len * 2: + # Pattern is the last pattern_len tokens of the valid sequence + src_pos = seq_len - pattern_len + num_reps = int(seq_len * repetition_rate / pattern_len) + for _ in range(num_reps): + # Place the copied tail pattern somewhere before the tail + tgt_pos = torch.randint(0, seq_len - pattern_len, (1,)).item() + if tgt_pos == src_pos: + continue + + token_ids[i, tgt_pos : tgt_pos + pattern_len] = token_ids[ + i, src_pos : src_pos + pattern_len + ].clone() + + # All sequences have sampled tokens and are valid + sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device) + valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device) + + return token_ids, num_tokens, sampled_flags, valid_mask + + def propose( + self, + num_tokens_no_spec: torch.Tensor, # [batch_size] on GPU + token_ids_gpu: torch.Tensor, # [batch_size, max_len] on GPU + sampled_flags: torch.Tensor, # [batch_size] bool on GPU + valid_mask: torch.Tensor, # [batch_size] bool on GPU + ) -> torch.Tensor: + # with set_current_vllm_config(self.vllm_config, check_compile=False): + with set_forward_context(None, self.vllm_config): + return self.kernel( + num_tokens_no_spec, token_ids_gpu, sampled_flags, valid_mask + ) + + def prepare_next_token_ids_cpu( + self, + sampled_token_ids: list[np.ndarray], + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int], + ) -> torch.Tensor: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids for each request based on the sampled + token ids from the CPU. If a request has no sampled token ids (e.g., + during the initial decoding steps), it falls back to using the request + state to get the next token id. + """ + req_ids = gpu_input_batch.req_ids + next_token_ids: list[int] = [] + for i, token_ids in enumerate(sampled_token_ids): + if token_ids.shape[0] > 0: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = req_ids[i] + req_state = requests[req_id] + seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id] + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + return torch.tensor(next_token_ids, dtype=torch.int32, device=self.device) + + def prepare_next_token_ids_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids and the number of valid sampled tokens + for each request, considering the "discarded" requests whose next token + is not sampled and comes from `request.get_token_id()` instead. + It also accounts for the rejected tokens in `sampled_token_ids`. + This function must use device functions to operate on the inputs, and + should not introduce any blocking CPU-GPU synchronization. + """ + # TODO(Ben): Combine this into a custom fused kernel + + # Precompute get_token_id for when there is no valid next token + num_reqs = gpu_input_batch.num_reqs + # Batch convert seq_lens to avoid multiple .item() calls + # This performs a single synchronization for all lengths + # instead of one per request + seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist() + + # Now use the pre-converted list to avoid .item() calls in the loop + self.backup_next_token_ids.np[:num_reqs] = np.array( + [ + requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i]) + for i in range(num_reqs) + ] + ) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + + # Mask out the sampled tokens indices that should not be sampled. + discard_sampled_tokens_req_indices = discard_request_indices[ + :num_discarded_requests + ] + + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + valid_sampled_token_ids_gpu.index_fill_( + 0, discard_sampled_tokens_req_indices, -1 + ) + + # Generate a mask for all valid tokens within those requests + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size + ) + + # Count the number of valid tokens in each request + valid_sampled_tokens_count = valid_mask.sum(dim=1) + + # Get the rightmost valid index per row + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1) + ).squeeze(1) + + # Use last token if valid, pre-computed backup if not + batch_size = valid_sampled_token_ids_gpu.shape[0] + next_token_ids = torch.where( + last_valid_indices != -1, + selected_tokens, + self.backup_next_token_ids.gpu[:batch_size], + ) + + return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu + + def load_model(self, *args, **kwargs): + with set_current_vllm_config(self.vllm_config, check_compile=False): + self.kernel.load_model(*args, **kwargs) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 7b4bc1d2a224..160a4372589f 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -112,6 +112,9 @@ def __init__( pin_memory=False, ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.token_ids_gpu_tensor = torch.zeros( + max_num_reqs, max_model_len, dtype=torch.int32, device=device + ) self.is_token_ids_tensor = torch.zeros( (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False ) @@ -122,6 +125,9 @@ def __init__( self.req_prompt_embeds: dict[int, torch.Tensor] = {} self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) + self.num_tokens_no_spec_gpu = torch.zeros( + max_num_reqs, dtype=torch.int32, device=device + ) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu_tensor = torch.zeros( (max_num_reqs,), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4c65a5e9b029..ce90aded2785 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -135,6 +135,7 @@ from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.ngram_proposer_gpu import NgramProposerGPU from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext @@ -366,10 +367,16 @@ def __init__( # layers in the draft model. if self.speculative_config and get_pp_group().is_last_rank: self.drafter: ( - NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer + NgramProposer + | NgramProposerGPU + | SuffixDecodingProposer + | EagleProposer + | MedusaProposer ) if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.method == "ngram_gpu": + self.drafter = NgramProposerGPU(self.vllm_config, self.device, self) elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): @@ -959,6 +966,146 @@ def _update_states_after_model_execute( for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + def _update_ngram_gpu_tensors(self, scheduler_output: "SchedulerOutput") -> None: + """Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu + for ngram GPU proposer to avoid redundant CPU-GPU transfers. + + This follows a similar pattern to _prepare_input_ids for efficient + batch updates when requests change between iterations. + """ + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + curr_req_id_to_index = self.input_batch.req_id_to_index + + # If no previous batch or batch is empty, initialize all from scratch + if prev_req_id_to_index is None or not curr_req_id_to_index: + if curr_req_id_to_index: + # Initialize all token_ids from requests + for req_id, idx in curr_req_id_to_index.items(): + req_state = self.requests[req_id] + # Get prompt_token_ids + output_token_ids + prompt_token_ids = ( + req_state.prompt_token_ids + if req_state.prompt_token_ids is not None + else [] + ) + all_token_ids = prompt_token_ids + req_state.output_token_ids + num_tokens = len(all_token_ids) + # Copy to GPU tensor + self.input_batch.token_ids_gpu_tensor[idx, :num_tokens].copy_( + torch.tensor( + all_token_ids, dtype=torch.int32, device=self.device + ), + non_blocking=True, + ) + self.input_batch.num_tokens_no_spec_gpu[idx] = num_tokens + return + + # Case 1: Batch hasn't changed at all (same req_ids and same indices) + if prev_req_id_to_index == curr_req_id_to_index: + return + + # Case 2, 3 & 4: Batch has changed - analyze the changes + common_req_indices = [] + prev_indices = [] + new_req_indices = [] + indices_match = True + + for req_id, curr_idx in curr_req_id_to_index.items(): + if req_id in prev_req_id_to_index: + prev_idx = prev_req_id_to_index[req_id] + common_req_indices.append(curr_idx) + prev_indices.append(prev_idx) + indices_match &= prev_idx == curr_idx + else: + new_req_indices.append((req_id, curr_idx)) + + # Case 2: Only common requests (subset or same set), may need reordering or clearing + if not new_req_indices: + # If indices haven't changed and it's the exact same set, already handled by Case 1 + # So here we either have reordering or a subset (some requests finished) + if not indices_match or len(common_req_indices) < len(prev_req_id_to_index): + # Need to reorder or clear finished requests + curr_indices_tensor = torch.tensor( + common_req_indices, dtype=torch.long, device=self.device + ) + prev_indices_tensor = torch.tensor( + prev_indices, dtype=torch.long, device=self.device + ) + + # Create temporary tensors for scatter operation (zeros will clear unused positions) + temp_token_ids = torch.zeros_like(self.input_batch.token_ids_gpu_tensor) + temp_num_tokens = torch.zeros_like( + self.input_batch.num_tokens_no_spec_gpu + ) + + # Scatter token_ids - copy entire rows (already up-to-date from prepare_next_token_ids_padded) + temp_token_ids[curr_indices_tensor] = ( + self.input_batch.token_ids_gpu_tensor[prev_indices_tensor] + ) + temp_num_tokens[curr_indices_tensor] = ( + self.input_batch.num_tokens_no_spec_gpu[prev_indices_tensor] + ) + + # Update in-place + self.input_batch.token_ids_gpu_tensor.copy_( + temp_token_ids, non_blocking=True + ) + self.input_batch.num_tokens_no_spec_gpu.copy_( + temp_num_tokens, non_blocking=True + ) + return + + # Case 3: Has new requests (or preempted requests that are resuming) + if new_req_indices: + # First handle common requests with scatter if any + if common_req_indices: + curr_indices_tensor = torch.tensor( + common_req_indices, dtype=torch.long, device=self.device + ) + prev_indices_tensor = torch.tensor( + prev_indices, dtype=torch.long, device=self.device + ) + + # Create temporary tensors for vectorized update + temp_token_ids = torch.zeros_like(self.input_batch.token_ids_gpu_tensor) + temp_num_tokens = torch.zeros_like( + self.input_batch.num_tokens_no_spec_gpu + ) + + # Scatter existing requests to new positions + temp_token_ids[curr_indices_tensor] = ( + self.input_batch.token_ids_gpu_tensor[prev_indices_tensor] + ) + temp_num_tokens[curr_indices_tensor] = ( + self.input_batch.num_tokens_no_spec_gpu[prev_indices_tensor] + ) + + # Copy back to persistent tensors + self.input_batch.token_ids_gpu_tensor.copy_( + temp_token_ids, non_blocking=True + ) + self.input_batch.num_tokens_no_spec_gpu.copy_( + temp_num_tokens, non_blocking=True + ) + + # Then handle new requests + for req_id, curr_idx in new_req_indices: + req_state = self.requests[req_id] + # Get prompt_token_ids + output_token_ids + prompt_token_ids = ( + req_state.prompt_token_ids + if req_state.prompt_token_ids is not None + else [] + ) + all_token_ids = prompt_token_ids + req_state.output_token_ids + num_tokens = len(all_token_ids) + # Copy to GPU tensor with non-blocking + self.input_batch.token_ids_gpu_tensor[curr_idx, :num_tokens].copy_( + torch.tensor(all_token_ids, dtype=torch.int32, device=self.device), + non_blocking=True, + ) + self.input_batch.num_tokens_no_spec_gpu[curr_idx] = num_tokens + def _init_mrope_positions(self, req_state: CachedRequestState): model = self.get_model() assert supports_mrope(model), "M-RoPE support is not implemented." @@ -1345,6 +1492,10 @@ def _prepare_inputs( cu_num_tokens, ) + # For ngram GPU proposer: update token_ids and num_tokens incrementally + if self.speculative_config and self.speculative_config.method == "ngram_gpu": + self._update_ngram_gpu_tensors(scheduler_output) + if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( @@ -2950,6 +3101,11 @@ def propose_draft_token_ids(sampled_token_ids): and spec_config.use_eagle() and not spec_config.disable_padded_drafter_batch ) + use_padded_batch_for_ngram = ( + self.speculative_config + and self.speculative_config.method == "ngram_gpu" + and not self.speculative_config.disable_padded_drafter_batch + ) effective_drafter_max_model_len = self.max_model_len if effective_drafter_max_model_len is None: effective_drafter_max_model_len = self.model_config.max_model_len @@ -2989,6 +3145,27 @@ def propose_draft_token_ids(sampled_token_ids): next_token_ids, valid_sampled_tokens_count ) + if use_padded_batch_for_ngram: + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + # Fast path: GPU-only operation when input fits in drafter + propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + # Slow path: prepare tokens with async transfer + next_token_ids, valid_sampled_tokens_count, _ = ( + self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( num_nans_in_logits, @@ -3009,7 +3186,7 @@ def propose_draft_token_ids(sampled_token_ids): if ( self.speculative_config - and not use_padded_batch_for_eagle + and not (use_padded_batch_for_eagle or use_padded_batch_for_ngram) and input_fits_in_drafter ): # ngram and other speculative decoding methods use the sampled @@ -3114,15 +3291,93 @@ def propose_draft_token_ids( num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None - if spec_config.method == "ngram": - assert isinstance(sampled_token_ids, list) - assert isinstance(self.drafter, NgramProposer) + if self.speculative_config.method == "ngram": + # TODO:(patchy) NGram GPU proposal + if isinstance(self.drafter, NgramProposer): + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list whenngram is used." + ) + draft_token_ids = self.drafter.propose( + sampled_token_ids, + self.input_batch.req_ids, + self.input_batch.num_tokens_no_spec, + self.input_batch.token_ids_cpu, + self.input_batch.spec_decode_unsupported_reqs, + ) + elif self.speculative_config.method == "ngram_gpu": + # GPU-accelerated ngram proposer + assert isinstance(self.drafter, NgramProposerGPU) + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor for ngram_gpu" + ) + next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu = ( + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + batch_size = next_token_ids.shape[0] + max_new_tokens = valid_sampled_token_ids_gpu.shape[1] # num_spec_tokens + 1 + + current_lens = self.input_batch.num_tokens_no_spec_gpu[:batch_size] + offsets = torch.arange(max_new_tokens, device=self.device) + + write_positions = current_lens.unsqueeze(1) + offsets.unsqueeze(0) + valid_write_mask = offsets.unsqueeze( + 0 + ) < valid_sampled_tokens_count.unsqueeze(1) + combined_mask = valid_write_mask & (valid_sampled_token_ids_gpu != -1) + + token_ids_slice = self.input_batch.token_ids_gpu_tensor[:batch_size] + write_positions_long = write_positions.long() + existing_values = token_ids_slice.gather(1, write_positions_long) + + tokens_cast = valid_sampled_token_ids_gpu.to(token_ids_slice.dtype) + tokens_to_scatter = torch.where( + combined_mask, + tokens_cast, + existing_values, + ) + token_ids_slice.scatter_(1, write_positions_long, tokens_to_scatter) + + self.input_batch.num_tokens_no_spec_gpu[:batch_size] += ( + valid_sampled_tokens_count + ) + + sampled_flags = valid_sampled_tokens_count > 0 + valid_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device) + + if self.input_batch.spec_decode_unsupported_reqs: + unsupported_ids = torch.tensor( + list(self.input_batch.spec_decode_unsupported_reqs), + dtype=torch.long, + device=self.device, + ) + + batch_req_ids = torch.tensor( + self.input_batch.req_ids[:batch_size], + dtype=torch.long, + device=self.device, + ) + + is_unsupported = ( + batch_req_ids.unsqueeze(1) == unsupported_ids.unsqueeze(0) + ).any(dim=1) + valid_mask = valid_mask & ~is_unsupported + draft_token_ids = self.drafter.propose( - sampled_token_ids, - self.input_batch.req_ids, - self.input_batch.num_tokens_no_spec, - self.input_batch.token_ids_cpu, - self.input_batch.spec_decode_unsupported_reqs, + self.input_batch.num_tokens_no_spec_gpu[:batch_size], + self.input_batch.token_ids_gpu_tensor[:batch_size], + sampled_flags, + valid_mask, ) elif spec_config.method == "suffix": assert isinstance(sampled_token_ids, list) From 293e3aef0f85392348920fa3ce5635c1b8539ea5 Mon Sep 17 00:00:00 2001 From: PatchouliTaisa Date: Fri, 21 Nov 2025 19:55:28 +0800 Subject: [PATCH 3/7] fix return values in ngram gpu Signed-off-by: PatchouliTaisa --- vllm/v1/spec_decode/ngram_proposer_gpu.py | 96 ++++++++--------------- 1 file changed, 34 insertions(+), 62 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py index 1dccc8c3e8cc..47f880e10c7a 100644 --- a/vllm/v1/spec_decode/ngram_proposer_gpu.py +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -6,12 +6,9 @@ This version uses a fully vectorized approach with unfold and argmax for finding the first match across all sequences in parallel. """ - -import os - -import numpy as np import torch from torch import nn +import numpy as np from vllm.compilation.decorators import support_torch_compile from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, VllmConfig @@ -19,21 +16,11 @@ from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, ) -from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -if int(os.environ.get("NVPROF", "0")) == 1: - pass -else: - pass - -import logging - from vllm.config import set_current_vllm_config from vllm.forward_context import set_forward_context - -logger = logging.getLogger(__name__) - +from vllm.v1.utils import CpuGpuBuffer @support_torch_compile( dynamic_arg_dims={ @@ -107,12 +94,10 @@ def _find_first_and_extract_all_n_parallel( max_seq_len = data.shape[1] num_patterns = max_pattern_len - min_pattern_len + 1 - # Create sliding windows once all_windows = data.unfold(1, max_pattern_len, 1) # [B, num_windows, max_n] num_windows = all_windows.shape[1] window_starts = torch.arange(num_windows, device=device) - # Store the first match position for each pattern length all_first_matches = torch.full( (batch_size, num_patterns), -1, dtype=torch.long, device=device ) @@ -148,7 +133,6 @@ def _find_first_and_extract_all_n_parallel( # from back to front, prioritizing longer patterns best_pattern_idx = (all_first_matches >= 0).int().flip(dims=[1]).argmax(dim=1) best_pattern_idx = num_patterns - 1 - best_pattern_idx # Flip back - best_pattern_len = min_pattern_len + best_pattern_idx # Extract corresponding results batch_idx = torch.arange(batch_size, device=device) @@ -163,9 +147,8 @@ def _find_first_and_extract_all_n_parallel( # the result starts after the full window result_starts = torch.where( has_any_match, - best_match_pos - + max_pattern_len, # Use max_pattern_len, not best_pattern_len - torch.zeros_like(best_match_pos), # Use 0 for no match + best_match_pos + max_pattern_len, + torch.zeros_like(best_match_pos), ) # Create gather indices @@ -226,13 +209,12 @@ def forward( device = token_ids_gpu.device # Initialize output tensor - torch.compile will optimize this allocation - # NOTE: Do NOT pre-allocate this as a buffer - it would break torch.compile + # NOTE(patchy): Do NOT pre-allocate this as a buffer + # it would break torch.compile draft_tokens = torch.zeros( (batch_size, self.k), dtype=torch.int32, device=device ) - # Use the async find and extract method with max_n pattern length - # This will find the first match and extract k tokens results = self._find_first_and_extract_all_n_parallel( token_ids_gpu, num_tokens_no_spec, @@ -250,28 +232,31 @@ def load_model(self, *args, **kwargs): """No model to load for N-gram proposer.""" pass - class NgramProposerGPU: def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): assert vllm_config.speculative_config is not None assert vllm_config.speculative_config.prompt_lookup_min is not None assert vllm_config.speculative_config.prompt_lookup_max is not None - # Create optimized compilation config for ngram kernel compilation_config = CompilationConfig( - level=3, + level=3, custom_ops=["none"], splitting_ops=[], compile_sizes=[], inductor_compile_config={ - "enable_auto_functionalized_v2": False, + "enable_auto_functionalized_v2": False, + "max_autotune": True, + "aggressive_fusion": True, + "triton.autotune_pointwise": True, + "coordinate_descent_tuning": True, + "use_mixed_mm": False, }, use_cudagraph=False, - cudagraph_mode=CUDAGraphMode.NONE, - mode=CompilationMode.VLLM_COMPILE, ) - self.vllm_config = VllmConfig(compilation_config=compilation_config) + self.vllm_config = VllmConfig( + compilation_config=compilation_config + ) self.min_n = vllm_config.speculative_config.prompt_lookup_min self.max_n = vllm_config.speculative_config.prompt_lookup_max @@ -282,9 +267,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): self.device = device with set_current_vllm_config(self.vllm_config, check_compile=False): - self.kernel = NgramGPUKernel( - vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device - ) + self.kernel = NgramGPUKernel(vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device) self.device = device self.kernel.to(device) self.kernel.eval() @@ -300,26 +283,19 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): self._dummy_run() def _dummy_run(self): - # with set_current_vllm_config(self.vllm_config, check_compile=False): - # Get warmup iterations from config or use default - token_ids, num_tokens, sampled_flags, valid_mask = ( - self._generate_dummy_data( + with set_current_vllm_config(self.vllm_config, check_compile=False): + token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data( batch_size=self.max_num_seqs, - max_seq_len=min( - self.max_model_len, 1024 - ), # Use reasonable seq len for warmup + max_seq_len=min(self.max_model_len, 1024), vocab_size=self.vocab_size, pattern_len=self.k, repetition_rate=0.5, - device=self.device, + device=self.device ) - ) - for _ in range(3): - with set_forward_context(None, self.vllm_config): - _ = self.kernel( - num_tokens, token_ids, sampled_flags, valid_mask - ) + for _ in range(3): + with set_forward_context(None, self.vllm_config): + output = self.kernel(num_tokens, token_ids, sampled_flags, valid_mask) def _generate_dummy_data( self, @@ -349,13 +325,15 @@ def _generate_dummy_data( """ # Generate random token IDs token_ids = torch.randint( - 0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device=device + 0, vocab_size, (batch_size, max_seq_len), + dtype=torch.int32, device=device ) # Generate random sequence lengths min_len = max(pattern_len * 2 + 3, max_seq_len // 2) num_tokens = torch.randint( - min_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device + min_len, max_seq_len, (batch_size,), + dtype=torch.int32, device=device ) # Inject n-gram repetitions using the tail pattern of each sequence @@ -371,9 +349,8 @@ def _generate_dummy_data( if tgt_pos == src_pos: continue - token_ids[i, tgt_pos : tgt_pos + pattern_len] = token_ids[ - i, src_pos : src_pos + pattern_len - ].clone() + token_ids[i, tgt_pos:tgt_pos + pattern_len] = \ + token_ids[i, src_pos:src_pos + pattern_len].clone() # All sequences have sampled tokens and are valid sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device) @@ -388,11 +365,9 @@ def propose( sampled_flags: torch.Tensor, # [batch_size] bool on GPU valid_mask: torch.Tensor, # [batch_size] bool on GPU ) -> torch.Tensor: - # with set_current_vllm_config(self.vllm_config, check_compile=False): - with set_forward_context(None, self.vllm_config): - return self.kernel( - num_tokens_no_spec, token_ids_gpu, sampled_flags, valid_mask - ) + with set_current_vllm_config(self.vllm_config, check_compile=False): + with set_forward_context(None, self.vllm_config): + return self.kernel(num_tokens_no_spec, token_ids_gpu, sampled_flags, valid_mask) def prepare_next_token_ids_cpu( self, @@ -443,12 +418,9 @@ def prepare_next_token_ids_padded( should not introduce any blocking CPU-GPU synchronization. """ # TODO(Ben): Combine this into a custom fused kernel - # Precompute get_token_id for when there is no valid next token num_reqs = gpu_input_batch.num_reqs # Batch convert seq_lens to avoid multiple .item() calls - # This performs a single synchronization for all lengths - # instead of one per request seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist() # Now use the pre-converted list to avoid .item() calls in the loop @@ -500,4 +472,4 @@ def prepare_next_token_ids_padded( def load_model(self, *args, **kwargs): with set_current_vllm_config(self.vllm_config, check_compile=False): - self.kernel.load_model(*args, **kwargs) + self.kernel.load_model(*args, **kwargs) \ No newline at end of file From 4534c886e4f19d35ece44bcdffbcde570a0bcbca Mon Sep 17 00:00:00 2001 From: PatchouliTaisa Date: Mon, 24 Nov 2025 10:42:39 +0800 Subject: [PATCH 4/7] python3.13 pre-commit check Signed-off-by: PatchouliTaisa --- vllm/v1/spec_decode/ngram_proposer_gpu.py | 122 +++++++++++----------- vllm/v1/worker/gpu_model_runner.py | 12 ++- 2 files changed, 71 insertions(+), 63 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py index 47f880e10c7a..d748bfea1ad5 100644 --- a/vllm/v1/spec_decode/ngram_proposer_gpu.py +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -6,21 +6,24 @@ This version uses a fully vectorized approach with unfold and argmax for finding the first match across all sequences in parallel. """ + +import numpy as np import torch from torch import nn -import numpy as np from vllm.compilation.decorators import support_torch_compile -from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, VllmConfig +from vllm.config import ( + CompilationConfig, + VllmConfig, +) +from vllm.forward_context import set_forward_context from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, ) +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch -from vllm.config import set_current_vllm_config -from vllm.forward_context import set_forward_context -from vllm.v1.utils import CpuGpuBuffer @support_torch_compile( dynamic_arg_dims={ @@ -121,7 +124,7 @@ def _find_first_and_extract_all_n_parallel( valid_mask = window_starts <= max_valid_start.unsqueeze(1) final_matches = matches & valid_mask - # Find first match + # Find first match # (if no match, argmax returns 0, but we verify with has_match) first_indices = torch.argmax(final_matches.int(), dim=1) has_match = final_matches[torch.arange(batch_size), first_indices] @@ -129,7 +132,7 @@ def _find_first_and_extract_all_n_parallel( # Store valid match positions all_first_matches[:, i] = torch.where(has_match, first_indices, -1) - # Select the longest valid match, + # Select the longest valid match, # from back to front, prioritizing longer patterns best_pattern_idx = (all_first_matches >= 0).int().flip(dims=[1]).argmax(dim=1) best_pattern_idx = num_patterns - 1 - best_pattern_idx # Flip back @@ -141,14 +144,14 @@ def _find_first_and_extract_all_n_parallel( # Handle matched cases - completely avoid data-dependent branching has_any_match = best_match_pos >= 0 - # Calculate result start positions, invalid positions will be - # clamped to valid range. Since all windows have size max_pattern_len, - # and patterns are matched at the END of windows (due to offset), + # Calculate result start positions, invalid positions will be + # clamped to valid range. Since all windows have size max_pattern_len, + # and patterns are matched at the END of windows (due to offset), # the result starts after the full window result_starts = torch.where( has_any_match, - best_match_pos + max_pattern_len, - torch.zeros_like(best_match_pos), + best_match_pos + max_pattern_len, + torch.zeros_like(best_match_pos), ) # Create gather indices @@ -232,6 +235,7 @@ def load_model(self, *args, **kwargs): """No model to load for N-gram proposer.""" pass + class NgramProposerGPU: def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): assert vllm_config.speculative_config is not None @@ -239,24 +243,22 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): assert vllm_config.speculative_config.prompt_lookup_max is not None compilation_config = CompilationConfig( - level=3, + level=3, custom_ops=["none"], splitting_ops=[], compile_sizes=[], inductor_compile_config={ - "enable_auto_functionalized_v2": False, - "max_autotune": True, - "aggressive_fusion": True, - "triton.autotune_pointwise": True, - "coordinate_descent_tuning": True, - "use_mixed_mm": False, + "enable_auto_functionalized_v2": False, + "max_autotune": True, + "aggressive_fusion": True, + "triton.autotune_pointwise": True, + "coordinate_descent_tuning": True, + "use_mixed_mm": False, }, use_cudagraph=False, ) - self.vllm_config = VllmConfig( - compilation_config=compilation_config - ) + self.vllm_config = VllmConfig(compilation_config=compilation_config) self.min_n = vllm_config.speculative_config.prompt_lookup_min self.max_n = vllm_config.speculative_config.prompt_lookup_max @@ -266,36 +268,36 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None): self.vocab_size = vllm_config.model_config.get_vocab_size() self.device = device - with set_current_vllm_config(self.vllm_config, check_compile=False): - self.kernel = NgramGPUKernel(vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device) - self.device = device - self.kernel.to(device) - self.kernel.eval() - max_batch_size = vllm_config.scheduler_config.max_num_seqs - self.backup_next_token_ids = CpuGpuBuffer( - max_batch_size, - dtype=torch.int32, - pin_memory=is_pin_memory_available(), - device=device, - with_numpy=True, - ) + self.kernel = NgramGPUKernel( + vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device + ) + self.device = device + self.kernel.to(device) + self.kernel.eval() + max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True, + ) - self._dummy_run() + self._dummy_run() def _dummy_run(self): - with set_current_vllm_config(self.vllm_config, check_compile=False): - token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data( - batch_size=self.max_num_seqs, - max_seq_len=min(self.max_model_len, 1024), - vocab_size=self.vocab_size, - pattern_len=self.k, - repetition_rate=0.5, - device=self.device - ) + token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data( + batch_size=self.max_num_seqs, + max_seq_len=min(self.max_model_len, 1024), + vocab_size=self.vocab_size, + pattern_len=self.k, + repetition_rate=0.5, + device=self.device, + ) - for _ in range(3): - with set_forward_context(None, self.vllm_config): - output = self.kernel(num_tokens, token_ids, sampled_flags, valid_mask) + for _ in range(3): + with set_forward_context(None, self.vllm_config): + _ = self.kernel(num_tokens, token_ids, sampled_flags, valid_mask) def _generate_dummy_data( self, @@ -325,15 +327,13 @@ def _generate_dummy_data( """ # Generate random token IDs token_ids = torch.randint( - 0, vocab_size, (batch_size, max_seq_len), - dtype=torch.int32, device=device + 0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device=device ) # Generate random sequence lengths min_len = max(pattern_len * 2 + 3, max_seq_len // 2) num_tokens = torch.randint( - min_len, max_seq_len, (batch_size,), - dtype=torch.int32, device=device + min_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device ) # Inject n-gram repetitions using the tail pattern of each sequence @@ -349,8 +349,9 @@ def _generate_dummy_data( if tgt_pos == src_pos: continue - token_ids[i, tgt_pos:tgt_pos + pattern_len] = \ - token_ids[i, src_pos:src_pos + pattern_len].clone() + token_ids[i, tgt_pos : tgt_pos + pattern_len] = token_ids[ + i, src_pos : src_pos + pattern_len + ].clone() # All sequences have sampled tokens and are valid sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device) @@ -365,9 +366,13 @@ def propose( sampled_flags: torch.Tensor, # [batch_size] bool on GPU valid_mask: torch.Tensor, # [batch_size] bool on GPU ) -> torch.Tensor: - with set_current_vllm_config(self.vllm_config, check_compile=False): - with set_forward_context(None, self.vllm_config): - return self.kernel(num_tokens_no_spec, token_ids_gpu, sampled_flags, valid_mask) + with set_forward_context(None, self.vllm_config): + return self.kernel( + num_tokens_no_spec, + token_ids_gpu, + sampled_flags, + valid_mask, + ) def prepare_next_token_ids_cpu( self, @@ -471,5 +476,4 @@ def prepare_next_token_ids_padded( return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu def load_model(self, *args, **kwargs): - with set_current_vllm_config(self.vllm_config, check_compile=False): - self.kernel.load_model(*args, **kwargs) \ No newline at end of file + self.kernel.load_model(*args, **kwargs) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ce90aded2785..70c73c8ce480 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1019,9 +1019,11 @@ def _update_ngram_gpu_tensors(self, scheduler_output: "SchedulerOutput") -> None else: new_req_indices.append((req_id, curr_idx)) - # Case 2: Only common requests (subset or same set), may need reordering or clearing + # Case 2: Only common requests (subset or same set), + # may need reordering or clearing if not new_req_indices: - # If indices haven't changed and it's the exact same set, already handled by Case 1 + # If indices haven't changed and it's the exact same set, + # already handled by Case 1 # So here we either have reordering or a subset (some requests finished) if not indices_match or len(common_req_indices) < len(prev_req_id_to_index): # Need to reorder or clear finished requests @@ -1032,13 +1034,15 @@ def _update_ngram_gpu_tensors(self, scheduler_output: "SchedulerOutput") -> None prev_indices, dtype=torch.long, device=self.device ) - # Create temporary tensors for scatter operation (zeros will clear unused positions) + # Create temporary tensors for scatter operation + # (zeros will clear unused positions) temp_token_ids = torch.zeros_like(self.input_batch.token_ids_gpu_tensor) temp_num_tokens = torch.zeros_like( self.input_batch.num_tokens_no_spec_gpu ) - # Scatter token_ids - copy entire rows (already up-to-date from prepare_next_token_ids_padded) + # Scatter token_ids - copy entire rows + # (already up-to-date from prepare_next_token_ids_padded) temp_token_ids[curr_indices_tensor] = ( self.input_batch.token_ids_gpu_tensor[prev_indices_tensor] ) From 07e6b8a0e72bcf95472714bd3bfbf672e81f0d68 Mon Sep 17 00:00:00 2001 From: PatchouliTaisa Date: Mon, 24 Nov 2025 11:03:28 +0800 Subject: [PATCH 5/7] fix pre-commit and sign-off Signed-off-by: PatchouliTaisa --- vllm/v1/worker/gpu_model_runner.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 70c73c8ce480..cd12bd414273 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3150,6 +3150,8 @@ def propose_draft_token_ids(sampled_token_ids): ) if use_padded_batch_for_ngram: + assert self.speculative_config is not None + assert isinstance(self.drafter, NgramProposerGPU) sampled_token_ids = sampler_output.sampled_token_ids if input_fits_in_drafter: # Fast path: GPU-only operation when input fits in drafter @@ -3295,7 +3297,7 @@ def propose_draft_token_ids( num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None - if self.speculative_config.method == "ngram": + if spec_config.method == "ngram": # TODO:(patchy) NGram GPU proposal if isinstance(self.drafter, NgramProposer): assert isinstance(sampled_token_ids, list), ( @@ -3308,7 +3310,7 @@ def propose_draft_token_ids( self.input_batch.token_ids_cpu, self.input_batch.spec_decode_unsupported_reqs, ) - elif self.speculative_config.method == "ngram_gpu": + elif spec_config.method == "ngram_gpu": # GPU-accelerated ngram proposer assert isinstance(self.drafter, NgramProposerGPU) assert isinstance(sampled_token_ids, torch.Tensor), ( From e70b0606c27c3b51bac877c602bb53380ee1b1d1 Mon Sep 17 00:00:00 2001 From: PatchouliTaisa Date: Tue, 25 Nov 2025 20:04:08 +0800 Subject: [PATCH 6/7] fix ngram gpu kernel compile issue Signed-off-by: PatchouliTaisa --- vllm/compilation/backends.py | 7 ++ vllm/v1/spec_decode/ngram_proposer_gpu.py | 113 ++++++++++++++-------- vllm/v1/worker/gpu_model_runner.py | 3 +- 3 files changed, 83 insertions(+), 40 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 1e66f21ff638..4fe84ed267ce 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -642,6 +642,13 @@ def __call__( self.compilation_config.inductor_compile_config ) + # TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors. + is_ngram_gpu_enabled = ( + vllm_config.speculative_config + and vllm_config.speculative_config.method == "ngram_gpu" + ) + disable_cache = disable_cache or is_ngram_gpu_enabled + if disable_cache: logger.info_once("vLLM's torch.compile cache is disabled.", scope="local") else: diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py index d748bfea1ad5..29366f3ca486 100644 --- a/vllm/v1/spec_decode/ngram_proposer_gpu.py +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -27,10 +27,9 @@ @support_torch_compile( dynamic_arg_dims={ - "num_tokens_no_spec": 0, # batch dimension is dynamic - "token_ids_gpu": [0, 1], # both batch and sequence length are dynamic - "sampled_flags": 0, # batch dimension is dynamic - "valid_mask": 0, # batch dimension is dynamic + "num_tokens_no_spec": 0, + "token_ids_gpu": [0, 1], + "combined_mask": 0, } ) class NgramGPUKernel(nn.Module): @@ -100,6 +99,10 @@ def _find_first_and_extract_all_n_parallel( all_windows = data.unfold(1, max_pattern_len, 1) # [B, num_windows, max_n] num_windows = all_windows.shape[1] window_starts = torch.arange(num_windows, device=device) + pattern_lengths = torch.arange( + min_pattern_len, max_pattern_len + 1, device=device + ) + batch_indices = torch.arange(batch_size, device=device) all_first_matches = torch.full( (batch_size, num_patterns), -1, dtype=torch.long, device=device @@ -120,17 +123,42 @@ def _find_first_and_extract_all_n_parallel( matches = (current_windows == patterns.unsqueeze(1)).all(dim=-1) # Validity check: ensure enough space for result extraction - max_valid_start = seq_lengths - pattern_len - result_len - valid_mask = window_starts <= max_valid_start.unsqueeze(1) + max_valid_pattern_start = seq_lengths - pattern_len - result_len + pattern_start_positions = window_starts + offset + valid_mask = pattern_start_positions <= max_valid_pattern_start.unsqueeze(1) final_matches = matches & valid_mask + # Handle prefix positions that fall before the available windows + prefix_positions = torch.arange(offset, device=device) + gather_indices = prefix_positions.view(1, -1, 1) + torch.arange( + pattern_len, device=device + ).view(1, 1, -1) + gather_indices = gather_indices.clamp(min=0, max=max_seq_len - 1) + expanded_indices = gather_indices.expand(batch_size, -1, -1) + prefix_tokens = torch.gather( + data.unsqueeze(1).expand(-1, offset, -1), + 2, + expanded_indices, + ) + prefix_matches = ( + prefix_tokens == patterns.unsqueeze(1).expand(-1, offset, -1) + ).all(dim=-1) + prefix_valid_mask = prefix_positions <= max_valid_pattern_start.unsqueeze(1) + prefix_final_matches = prefix_matches & prefix_valid_mask + + combined_matches = torch.cat([prefix_final_matches, final_matches], dim=1) + start_positions = torch.cat( + [prefix_positions, pattern_start_positions], dim=0 + ) + # Find first match # (if no match, argmax returns 0, but we verify with has_match) - first_indices = torch.argmax(final_matches.int(), dim=1) - has_match = final_matches[torch.arange(batch_size), first_indices] + first_indices = torch.argmax(combined_matches.int(), dim=1) + has_match = combined_matches[batch_indices, first_indices] + match_positions = start_positions[first_indices] # Store valid match positions - all_first_matches[:, i] = torch.where(has_match, first_indices, -1) + all_first_matches[:, i] = torch.where(has_match, match_positions, -1) # Select the longest valid match, # from back to front, prioritizing longer patterns @@ -138,19 +166,19 @@ def _find_first_and_extract_all_n_parallel( best_pattern_idx = num_patterns - 1 - best_pattern_idx # Flip back # Extract corresponding results - batch_idx = torch.arange(batch_size, device=device) - best_match_pos = all_first_matches[batch_idx, best_pattern_idx] + best_match_pos = all_first_matches[batch_indices, best_pattern_idx] # Handle matched cases - completely avoid data-dependent branching has_any_match = best_match_pos >= 0 + best_pattern_lengths = pattern_lengths[best_pattern_idx] + # Calculate result start positions, invalid positions will be - # clamped to valid range. Since all windows have size max_pattern_len, - # and patterns are matched at the END of windows (due to offset), - # the result starts after the full window + # clamped to valid range. We now track true start positions, so the + # result starts right after the matched n-gram result_starts = torch.where( has_any_match, - best_match_pos + max_pattern_len, + best_match_pos + best_pattern_lengths, torch.zeros_like(best_match_pos), ) @@ -177,8 +205,7 @@ def forward( self, num_tokens_no_spec: torch.Tensor, # [batch_size] on GPU token_ids_gpu: torch.Tensor, # [batch_size, max_len] on GPU - sampled_flags: torch.Tensor, # [batch_size] bool on GPU - valid_mask: torch.Tensor, # [batch_size] bool on GPU + combined_mask: torch.Tensor, # [batch_size] bool on GPU ) -> torch.Tensor: """ Forward pass for N-gram proposal using GPU tensor operations. @@ -189,33 +216,23 @@ def forward( Args: num_tokens_no_spec: Number of tokens for each sequence [batch_size] token_ids_gpu: Token IDs [batch_size, max_len] - sampled_flags: Whether each sequence has sampled tokens [batch_size] - valid_mask: Whether each sequence is valid for spec decode [batch_size] + combined_mask: Whether each sequence is valid for spec decode [batch_size] + batch_size: Deprecated parameter, will be inferred from tensor shape Returns: draft_tokens: [batch_size, k] on GPU """ - assert token_ids_gpu.device == self.device - assert num_tokens_no_spec.device == self.device - assert sampled_flags.device == self.device - assert valid_mask.device == self.device - - # Compute combined mask for valid sequences - combined_mask = ( - sampled_flags - & valid_mask - & (num_tokens_no_spec < self.max_model_len) - & (num_tokens_no_spec >= self.min_n) - ) - batch_size = token_ids_gpu.size(0) device = token_ids_gpu.device + # Infer batch_size from the input tensor shape to maintain dynamic shape + actual_batch_size = token_ids_gpu.shape[0] + # Initialize output tensor - torch.compile will optimize this allocation # NOTE(patchy): Do NOT pre-allocate this as a buffer # it would break torch.compile draft_tokens = torch.zeros( - (batch_size, self.k), dtype=torch.int32, device=device + (actual_batch_size, self.k), dtype=torch.int32, device=device ) results = self._find_first_and_extract_all_n_parallel( @@ -226,8 +243,10 @@ def forward( result_len=self.k, ) - # Apply combined mask to results - draft_tokens = torch.where(combined_mask.unsqueeze(1), results, draft_tokens) + # Apply combined mask to results. Expand mask explicitly to avoid + # relying on broadcasting behavior that can confuse torch.compile. + mask = combined_mask.unsqueeze(1).expand(-1, self.k) + draft_tokens = torch.where(mask, results, draft_tokens) return draft_tokens @@ -295,9 +314,16 @@ def _dummy_run(self): device=self.device, ) + combined_mask = ( + sampled_flags + & valid_mask + & (num_tokens < self.max_model_len) + & (num_tokens >= self.min_n) + ) + for _ in range(3): with set_forward_context(None, self.vllm_config): - _ = self.kernel(num_tokens, token_ids, sampled_flags, valid_mask) + _ = self.kernel(num_tokens, token_ids, combined_mask) def _generate_dummy_data( self, @@ -366,12 +392,23 @@ def propose( sampled_flags: torch.Tensor, # [batch_size] bool on GPU valid_mask: torch.Tensor, # [batch_size] bool on GPU ) -> torch.Tensor: + assert token_ids_gpu.device == self.device + assert num_tokens_no_spec.device == self.device + assert sampled_flags.device == self.device + assert valid_mask.device == self.device + with set_forward_context(None, self.vllm_config): + combined_mask = ( + sampled_flags + & valid_mask + & (num_tokens_no_spec < self.max_model_len) + & (num_tokens_no_spec >= self.min_n) + ) + return self.kernel( num_tokens_no_spec, token_ids_gpu, - sampled_flags, - valid_mask, + combined_mask, ) def prepare_next_token_ids_cpu( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d1acdf7fed06..1a7e8f2fd976 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3141,10 +3141,9 @@ def propose_draft_token_ids(sampled_token_ids): assert isinstance(self.drafter, NgramProposerGPU) sampled_token_ids = sampler_output.sampled_token_ids if input_fits_in_drafter: - # Fast path: GPU-only operation when input fits in drafter propose_draft_token_ids(sampled_token_ids) elif self.valid_sampled_token_count_event is not None: - # Slow path: prepare tokens with async transfer + assert spec_decode_common_attn_metadata is not None next_token_ids, valid_sampled_tokens_count, _ = ( self.drafter.prepare_next_token_ids_padded( spec_decode_common_attn_metadata, From 25d36b1742dc2d8b13d39cc7a1ab2d1ca5b65349 Mon Sep 17 00:00:00 2001 From: PatchouliTaisa Date: Wed, 26 Nov 2025 16:18:02 +0800 Subject: [PATCH 7/7] fix docs bug Signed-off-by: PatchouliTaisa --- vllm/v1/spec_decode/ngram_proposer_gpu.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer_gpu.py b/vllm/v1/spec_decode/ngram_proposer_gpu.py index 29366f3ca486..bf16d41a692b 100644 --- a/vllm/v1/spec_decode/ngram_proposer_gpu.py +++ b/vllm/v1/spec_decode/ngram_proposer_gpu.py @@ -203,9 +203,9 @@ def _find_first_and_extract_all_n_parallel( def forward( self, - num_tokens_no_spec: torch.Tensor, # [batch_size] on GPU - token_ids_gpu: torch.Tensor, # [batch_size, max_len] on GPU - combined_mask: torch.Tensor, # [batch_size] bool on GPU + num_tokens_no_spec: torch.Tensor, + token_ids_gpu: torch.Tensor, + combined_mask: torch.Tensor, ) -> torch.Tensor: """ Forward pass for N-gram proposal using GPU tensor operations. @@ -217,7 +217,6 @@ def forward( num_tokens_no_spec: Number of tokens for each sequence [batch_size] token_ids_gpu: Token IDs [batch_size, max_len] combined_mask: Whether each sequence is valid for spec decode [batch_size] - batch_size: Deprecated parameter, will be inferred from tensor shape Returns: draft_tokens: [batch_size, k] on GPU