Skip to content

Commit 4534c88

Browse files
author
PatchouliTaisa
committed
python3.13 pre-commit check
Signed-off-by: PatchouliTaisa <patchychen@tencent.com>
1 parent 293e3ae commit 4534c88

File tree

2 files changed

+71
-63
lines changed

2 files changed

+71
-63
lines changed

vllm/v1/spec_decode/ngram_proposer_gpu.py

Lines changed: 63 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,24 @@
66
This version uses a fully vectorized approach with unfold and argmax for
77
finding the first match across all sequences in parallel.
88
"""
9+
10+
import numpy as np
911
import torch
1012
from torch import nn
11-
import numpy as np
1213

1314
from vllm.compilation.decorators import support_torch_compile
14-
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, VllmConfig
15+
from vllm.config import (
16+
CompilationConfig,
17+
VllmConfig,
18+
)
19+
from vllm.forward_context import set_forward_context
1520
from vllm.utils.platform_utils import is_pin_memory_available
1621
from vllm.v1.attention.backends.utils import (
1722
CommonAttentionMetadata,
1823
)
24+
from vllm.v1.utils import CpuGpuBuffer
1925
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
2026

21-
from vllm.config import set_current_vllm_config
22-
from vllm.forward_context import set_forward_context
23-
from vllm.v1.utils import CpuGpuBuffer
2427

2528
@support_torch_compile(
2629
dynamic_arg_dims={
@@ -121,15 +124,15 @@ def _find_first_and_extract_all_n_parallel(
121124
valid_mask = window_starts <= max_valid_start.unsqueeze(1)
122125
final_matches = matches & valid_mask
123126

124-
# Find first match
127+
# Find first match
125128
# (if no match, argmax returns 0, but we verify with has_match)
126129
first_indices = torch.argmax(final_matches.int(), dim=1)
127130
has_match = final_matches[torch.arange(batch_size), first_indices]
128131

129132
# Store valid match positions
130133
all_first_matches[:, i] = torch.where(has_match, first_indices, -1)
131134

132-
# Select the longest valid match,
135+
# Select the longest valid match,
133136
# from back to front, prioritizing longer patterns
134137
best_pattern_idx = (all_first_matches >= 0).int().flip(dims=[1]).argmax(dim=1)
135138
best_pattern_idx = num_patterns - 1 - best_pattern_idx # Flip back
@@ -141,14 +144,14 @@ def _find_first_and_extract_all_n_parallel(
141144
# Handle matched cases - completely avoid data-dependent branching
142145
has_any_match = best_match_pos >= 0
143146

144-
# Calculate result start positions, invalid positions will be
145-
# clamped to valid range. Since all windows have size max_pattern_len,
146-
# and patterns are matched at the END of windows (due to offset),
147+
# Calculate result start positions, invalid positions will be
148+
# clamped to valid range. Since all windows have size max_pattern_len,
149+
# and patterns are matched at the END of windows (due to offset),
147150
# the result starts after the full window
148151
result_starts = torch.where(
149152
has_any_match,
150-
best_match_pos + max_pattern_len,
151-
torch.zeros_like(best_match_pos),
153+
best_match_pos + max_pattern_len,
154+
torch.zeros_like(best_match_pos),
152155
)
153156

154157
# Create gather indices
@@ -232,31 +235,30 @@ def load_model(self, *args, **kwargs):
232235
"""No model to load for N-gram proposer."""
233236
pass
234237

238+
235239
class NgramProposerGPU:
236240
def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
237241
assert vllm_config.speculative_config is not None
238242
assert vllm_config.speculative_config.prompt_lookup_min is not None
239243
assert vllm_config.speculative_config.prompt_lookup_max is not None
240244

241245
compilation_config = CompilationConfig(
242-
level=3,
246+
level=3,
243247
custom_ops=["none"],
244248
splitting_ops=[],
245249
compile_sizes=[],
246250
inductor_compile_config={
247-
"enable_auto_functionalized_v2": False,
248-
"max_autotune": True,
249-
"aggressive_fusion": True,
250-
"triton.autotune_pointwise": True,
251-
"coordinate_descent_tuning": True,
252-
"use_mixed_mm": False,
251+
"enable_auto_functionalized_v2": False,
252+
"max_autotune": True,
253+
"aggressive_fusion": True,
254+
"triton.autotune_pointwise": True,
255+
"coordinate_descent_tuning": True,
256+
"use_mixed_mm": False,
253257
},
254258
use_cudagraph=False,
255259
)
256260

257-
self.vllm_config = VllmConfig(
258-
compilation_config=compilation_config
259-
)
261+
self.vllm_config = VllmConfig(compilation_config=compilation_config)
260262

261263
self.min_n = vllm_config.speculative_config.prompt_lookup_min
262264
self.max_n = vllm_config.speculative_config.prompt_lookup_max
@@ -266,36 +268,36 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
266268
self.vocab_size = vllm_config.model_config.get_vocab_size()
267269
self.device = device
268270

269-
with set_current_vllm_config(self.vllm_config, check_compile=False):
270-
self.kernel = NgramGPUKernel(vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device)
271-
self.device = device
272-
self.kernel.to(device)
273-
self.kernel.eval()
274-
max_batch_size = vllm_config.scheduler_config.max_num_seqs
275-
self.backup_next_token_ids = CpuGpuBuffer(
276-
max_batch_size,
277-
dtype=torch.int32,
278-
pin_memory=is_pin_memory_available(),
279-
device=device,
280-
with_numpy=True,
281-
)
271+
self.kernel = NgramGPUKernel(
272+
vllm_config=vllm_config, prefix="ngram_gpu_kernel", device=device
273+
)
274+
self.device = device
275+
self.kernel.to(device)
276+
self.kernel.eval()
277+
max_batch_size = vllm_config.scheduler_config.max_num_seqs
278+
self.backup_next_token_ids = CpuGpuBuffer(
279+
max_batch_size,
280+
dtype=torch.int32,
281+
pin_memory=is_pin_memory_available(),
282+
device=device,
283+
with_numpy=True,
284+
)
282285

283-
self._dummy_run()
286+
self._dummy_run()
284287

285288
def _dummy_run(self):
286-
with set_current_vllm_config(self.vllm_config, check_compile=False):
287-
token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data(
288-
batch_size=self.max_num_seqs,
289-
max_seq_len=min(self.max_model_len, 1024),
290-
vocab_size=self.vocab_size,
291-
pattern_len=self.k,
292-
repetition_rate=0.5,
293-
device=self.device
294-
)
289+
token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data(
290+
batch_size=self.max_num_seqs,
291+
max_seq_len=min(self.max_model_len, 1024),
292+
vocab_size=self.vocab_size,
293+
pattern_len=self.k,
294+
repetition_rate=0.5,
295+
device=self.device,
296+
)
295297

296-
for _ in range(3):
297-
with set_forward_context(None, self.vllm_config):
298-
output = self.kernel(num_tokens, token_ids, sampled_flags, valid_mask)
298+
for _ in range(3):
299+
with set_forward_context(None, self.vllm_config):
300+
_ = self.kernel(num_tokens, token_ids, sampled_flags, valid_mask)
299301

300302
def _generate_dummy_data(
301303
self,
@@ -325,15 +327,13 @@ def _generate_dummy_data(
325327
"""
326328
# Generate random token IDs
327329
token_ids = torch.randint(
328-
0, vocab_size, (batch_size, max_seq_len),
329-
dtype=torch.int32, device=device
330+
0, vocab_size, (batch_size, max_seq_len), dtype=torch.int32, device=device
330331
)
331332

332333
# Generate random sequence lengths
333334
min_len = max(pattern_len * 2 + 3, max_seq_len // 2)
334335
num_tokens = torch.randint(
335-
min_len, max_seq_len, (batch_size,),
336-
dtype=torch.int32, device=device
336+
min_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
337337
)
338338

339339
# Inject n-gram repetitions using the tail pattern of each sequence
@@ -349,8 +349,9 @@ def _generate_dummy_data(
349349
if tgt_pos == src_pos:
350350
continue
351351

352-
token_ids[i, tgt_pos:tgt_pos + pattern_len] = \
353-
token_ids[i, src_pos:src_pos + pattern_len].clone()
352+
token_ids[i, tgt_pos : tgt_pos + pattern_len] = token_ids[
353+
i, src_pos : src_pos + pattern_len
354+
].clone()
354355

355356
# All sequences have sampled tokens and are valid
356357
sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device)
@@ -365,9 +366,13 @@ def propose(
365366
sampled_flags: torch.Tensor, # [batch_size] bool on GPU
366367
valid_mask: torch.Tensor, # [batch_size] bool on GPU
367368
) -> torch.Tensor:
368-
with set_current_vllm_config(self.vllm_config, check_compile=False):
369-
with set_forward_context(None, self.vllm_config):
370-
return self.kernel(num_tokens_no_spec, token_ids_gpu, sampled_flags, valid_mask)
369+
with set_forward_context(None, self.vllm_config):
370+
return self.kernel(
371+
num_tokens_no_spec,
372+
token_ids_gpu,
373+
sampled_flags,
374+
valid_mask,
375+
)
371376

372377
def prepare_next_token_ids_cpu(
373378
self,
@@ -471,5 +476,4 @@ def prepare_next_token_ids_padded(
471476
return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu
472477

473478
def load_model(self, *args, **kwargs):
474-
with set_current_vllm_config(self.vllm_config, check_compile=False):
475-
self.kernel.load_model(*args, **kwargs)
479+
self.kernel.load_model(*args, **kwargs)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,9 +1019,11 @@ def _update_ngram_gpu_tensors(self, scheduler_output: "SchedulerOutput") -> None
10191019
else:
10201020
new_req_indices.append((req_id, curr_idx))
10211021

1022-
# Case 2: Only common requests (subset or same set), may need reordering or clearing
1022+
# Case 2: Only common requests (subset or same set),
1023+
# may need reordering or clearing
10231024
if not new_req_indices:
1024-
# If indices haven't changed and it's the exact same set, already handled by Case 1
1025+
# If indices haven't changed and it's the exact same set,
1026+
# already handled by Case 1
10251027
# So here we either have reordering or a subset (some requests finished)
10261028
if not indices_match or len(common_req_indices) < len(prev_req_id_to_index):
10271029
# Need to reorder or clear finished requests
@@ -1032,13 +1034,15 @@ def _update_ngram_gpu_tensors(self, scheduler_output: "SchedulerOutput") -> None
10321034
prev_indices, dtype=torch.long, device=self.device
10331035
)
10341036

1035-
# Create temporary tensors for scatter operation (zeros will clear unused positions)
1037+
# Create temporary tensors for scatter operation
1038+
# (zeros will clear unused positions)
10361039
temp_token_ids = torch.zeros_like(self.input_batch.token_ids_gpu_tensor)
10371040
temp_num_tokens = torch.zeros_like(
10381041
self.input_batch.num_tokens_no_spec_gpu
10391042
)
10401043

1041-
# Scatter token_ids - copy entire rows (already up-to-date from prepare_next_token_ids_padded)
1044+
# Scatter token_ids - copy entire rows
1045+
# (already up-to-date from prepare_next_token_ids_padded)
10421046
temp_token_ids[curr_indices_tensor] = (
10431047
self.input_batch.token_ids_gpu_tensor[prev_indices_tensor]
10441048
)

0 commit comments

Comments
 (0)