66This version uses a fully vectorized approach with unfold and argmax for
77finding the first match across all sequences in parallel.
88"""
9+
10+ import numpy as np
911import torch
1012from torch import nn
11- import numpy as np
1213
1314from vllm .compilation .decorators import support_torch_compile
14- from vllm .config import CompilationConfig , CompilationMode , CUDAGraphMode , VllmConfig
15+ from vllm .config import (
16+ CompilationConfig ,
17+ VllmConfig ,
18+ )
19+ from vllm .forward_context import set_forward_context
1520from vllm .utils .platform_utils import is_pin_memory_available
1621from vllm .v1 .attention .backends .utils import (
1722 CommonAttentionMetadata ,
1823)
24+ from vllm .v1 .utils import CpuGpuBuffer
1925from vllm .v1 .worker .gpu_input_batch import CachedRequestState , InputBatch
2026
21- from vllm .config import set_current_vllm_config
22- from vllm .forward_context import set_forward_context
23- from vllm .v1 .utils import CpuGpuBuffer
2427
2528@support_torch_compile (
2629 dynamic_arg_dims = {
@@ -121,15 +124,15 @@ def _find_first_and_extract_all_n_parallel(
121124 valid_mask = window_starts <= max_valid_start .unsqueeze (1 )
122125 final_matches = matches & valid_mask
123126
124- # Find first match
127+ # Find first match
125128 # (if no match, argmax returns 0, but we verify with has_match)
126129 first_indices = torch .argmax (final_matches .int (), dim = 1 )
127130 has_match = final_matches [torch .arange (batch_size ), first_indices ]
128131
129132 # Store valid match positions
130133 all_first_matches [:, i ] = torch .where (has_match , first_indices , - 1 )
131134
132- # Select the longest valid match,
135+ # Select the longest valid match,
133136 # from back to front, prioritizing longer patterns
134137 best_pattern_idx = (all_first_matches >= 0 ).int ().flip (dims = [1 ]).argmax (dim = 1 )
135138 best_pattern_idx = num_patterns - 1 - best_pattern_idx # Flip back
@@ -141,14 +144,14 @@ def _find_first_and_extract_all_n_parallel(
141144 # Handle matched cases - completely avoid data-dependent branching
142145 has_any_match = best_match_pos >= 0
143146
144- # Calculate result start positions, invalid positions will be
145- # clamped to valid range. Since all windows have size max_pattern_len,
146- # and patterns are matched at the END of windows (due to offset),
147+ # Calculate result start positions, invalid positions will be
148+ # clamped to valid range. Since all windows have size max_pattern_len,
149+ # and patterns are matched at the END of windows (due to offset),
147150 # the result starts after the full window
148151 result_starts = torch .where (
149152 has_any_match ,
150- best_match_pos + max_pattern_len ,
151- torch .zeros_like (best_match_pos ),
153+ best_match_pos + max_pattern_len ,
154+ torch .zeros_like (best_match_pos ),
152155 )
153156
154157 # Create gather indices
@@ -232,31 +235,30 @@ def load_model(self, *args, **kwargs):
232235 """No model to load for N-gram proposer."""
233236 pass
234237
238+
235239class NgramProposerGPU :
236240 def __init__ (self , vllm_config : VllmConfig , device : torch .device , runner = None ):
237241 assert vllm_config .speculative_config is not None
238242 assert vllm_config .speculative_config .prompt_lookup_min is not None
239243 assert vllm_config .speculative_config .prompt_lookup_max is not None
240244
241245 compilation_config = CompilationConfig (
242- level = 3 ,
246+ level = 3 ,
243247 custom_ops = ["none" ],
244248 splitting_ops = [],
245249 compile_sizes = [],
246250 inductor_compile_config = {
247- "enable_auto_functionalized_v2" : False ,
248- "max_autotune" : True ,
249- "aggressive_fusion" : True ,
250- "triton.autotune_pointwise" : True ,
251- "coordinate_descent_tuning" : True ,
252- "use_mixed_mm" : False ,
251+ "enable_auto_functionalized_v2" : False ,
252+ "max_autotune" : True ,
253+ "aggressive_fusion" : True ,
254+ "triton.autotune_pointwise" : True ,
255+ "coordinate_descent_tuning" : True ,
256+ "use_mixed_mm" : False ,
253257 },
254258 use_cudagraph = False ,
255259 )
256260
257- self .vllm_config = VllmConfig (
258- compilation_config = compilation_config
259- )
261+ self .vllm_config = VllmConfig (compilation_config = compilation_config )
260262
261263 self .min_n = vllm_config .speculative_config .prompt_lookup_min
262264 self .max_n = vllm_config .speculative_config .prompt_lookup_max
@@ -266,36 +268,36 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
266268 self .vocab_size = vllm_config .model_config .get_vocab_size ()
267269 self .device = device
268270
269- with set_current_vllm_config (self .vllm_config , check_compile = False ):
270- self .kernel = NgramGPUKernel (vllm_config = vllm_config , prefix = "ngram_gpu_kernel" , device = device )
271- self .device = device
272- self .kernel .to (device )
273- self .kernel .eval ()
274- max_batch_size = vllm_config .scheduler_config .max_num_seqs
275- self .backup_next_token_ids = CpuGpuBuffer (
276- max_batch_size ,
277- dtype = torch .int32 ,
278- pin_memory = is_pin_memory_available (),
279- device = device ,
280- with_numpy = True ,
281- )
271+ self .kernel = NgramGPUKernel (
272+ vllm_config = vllm_config , prefix = "ngram_gpu_kernel" , device = device
273+ )
274+ self .device = device
275+ self .kernel .to (device )
276+ self .kernel .eval ()
277+ max_batch_size = vllm_config .scheduler_config .max_num_seqs
278+ self .backup_next_token_ids = CpuGpuBuffer (
279+ max_batch_size ,
280+ dtype = torch .int32 ,
281+ pin_memory = is_pin_memory_available (),
282+ device = device ,
283+ with_numpy = True ,
284+ )
282285
283- self ._dummy_run ()
286+ self ._dummy_run ()
284287
285288 def _dummy_run (self ):
286- with set_current_vllm_config (self .vllm_config , check_compile = False ):
287- token_ids , num_tokens , sampled_flags , valid_mask = self ._generate_dummy_data (
288- batch_size = self .max_num_seqs ,
289- max_seq_len = min (self .max_model_len , 1024 ),
290- vocab_size = self .vocab_size ,
291- pattern_len = self .k ,
292- repetition_rate = 0.5 ,
293- device = self .device
294- )
289+ token_ids , num_tokens , sampled_flags , valid_mask = self ._generate_dummy_data (
290+ batch_size = self .max_num_seqs ,
291+ max_seq_len = min (self .max_model_len , 1024 ),
292+ vocab_size = self .vocab_size ,
293+ pattern_len = self .k ,
294+ repetition_rate = 0.5 ,
295+ device = self .device ,
296+ )
295297
296- for _ in range (3 ):
297- with set_forward_context (None , self .vllm_config ):
298- output = self .kernel (num_tokens , token_ids , sampled_flags , valid_mask )
298+ for _ in range (3 ):
299+ with set_forward_context (None , self .vllm_config ):
300+ _ = self .kernel (num_tokens , token_ids , sampled_flags , valid_mask )
299301
300302 def _generate_dummy_data (
301303 self ,
@@ -325,15 +327,13 @@ def _generate_dummy_data(
325327 """
326328 # Generate random token IDs
327329 token_ids = torch .randint (
328- 0 , vocab_size , (batch_size , max_seq_len ),
329- dtype = torch .int32 , device = device
330+ 0 , vocab_size , (batch_size , max_seq_len ), dtype = torch .int32 , device = device
330331 )
331332
332333 # Generate random sequence lengths
333334 min_len = max (pattern_len * 2 + 3 , max_seq_len // 2 )
334335 num_tokens = torch .randint (
335- min_len , max_seq_len , (batch_size ,),
336- dtype = torch .int32 , device = device
336+ min_len , max_seq_len , (batch_size ,), dtype = torch .int32 , device = device
337337 )
338338
339339 # Inject n-gram repetitions using the tail pattern of each sequence
@@ -349,8 +349,9 @@ def _generate_dummy_data(
349349 if tgt_pos == src_pos :
350350 continue
351351
352- token_ids [i , tgt_pos :tgt_pos + pattern_len ] = \
353- token_ids [i , src_pos :src_pos + pattern_len ].clone ()
352+ token_ids [i , tgt_pos : tgt_pos + pattern_len ] = token_ids [
353+ i , src_pos : src_pos + pattern_len
354+ ].clone ()
354355
355356 # All sequences have sampled tokens and are valid
356357 sampled_flags = torch .ones (batch_size , dtype = torch .bool , device = device )
@@ -365,9 +366,13 @@ def propose(
365366 sampled_flags : torch .Tensor , # [batch_size] bool on GPU
366367 valid_mask : torch .Tensor , # [batch_size] bool on GPU
367368 ) -> torch .Tensor :
368- with set_current_vllm_config (self .vllm_config , check_compile = False ):
369- with set_forward_context (None , self .vllm_config ):
370- return self .kernel (num_tokens_no_spec , token_ids_gpu , sampled_flags , valid_mask )
369+ with set_forward_context (None , self .vllm_config ):
370+ return self .kernel (
371+ num_tokens_no_spec ,
372+ token_ids_gpu ,
373+ sampled_flags ,
374+ valid_mask ,
375+ )
371376
372377 def prepare_next_token_ids_cpu (
373378 self ,
@@ -471,5 +476,4 @@ def prepare_next_token_ids_padded(
471476 return next_token_ids , valid_sampled_tokens_count , valid_sampled_token_ids_gpu
472477
473478 def load_model (self , * args , ** kwargs ):
474- with set_current_vllm_config (self .vllm_config , check_compile = False ):
475- self .kernel .load_model (* args , ** kwargs )
479+ self .kernel .load_model (* args , ** kwargs )
0 commit comments