Skip to content

Commit 09b61a6

Browse files
Ronald1995, njhill, and benchislett
authored and committed
[Core] Async Scheduling X Spec Decoding Compatibility (vllm-project#24799)
Signed-off-by: Ronald1995 <ronaldautomobile@163.com> Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
1 parent 7af4bba commit 09b61a6

File tree

11 files changed

+314
-98
lines changed

11 files changed

+314
-98
lines changed

tests/v1/e2e/test_async_scheduling.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ...models.utils import check_outputs_equal
1616

1717
MODEL = "Qwen/Qwen3-0.6B"
18-
MTP_MODEL = "XiaomiMiMo/MiMo-7B-Base"
18+
MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
1919

2020

2121
first_prompt = (
@@ -29,7 +29,8 @@
2929

3030
default_params = dict(
3131
temperature=0.0, # greedy
32-
max_tokens=20,
32+
max_tokens=23,
33+
min_tokens=18,
3334
)
3435

3536

@@ -69,24 +70,19 @@ def test_without_spec_decoding(
6970
(True, "uni", True, None, True),
7071
]
7172

72-
run_tests(
73-
monkeypatch,
74-
MODEL,
75-
test_configs,
76-
test_sampling_params,
77-
)
73+
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
7874

7975

80-
@pytest.mark.skip("MTP model too big to run in fp32 in CI")
8176
def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
8277
"""Test consistency and acceptance rates with some different combos of
8378
preemption, executor, async scheduling, prefill chunking,
8479
spec decoding model length.
8580
"""
8681

8782
spec_config = {
88-
"method": "mtp",
83+
"method": "eagle3",
8984
"num_speculative_tokens": 2,
85+
"model": "nm-testing/Llama3_2_1B_speculator.eagle3",
9086
}
9187
spec_config_short = spec_config | {"max_model_len": 50}
9288

@@ -106,12 +102,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
106102
(True, "uni", True, spec_config_short, True),
107103
]
108104

109-
run_tests(
110-
monkeypatch,
111-
MTP_MODEL,
112-
test_configs,
113-
[{}],
114-
)
105+
run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])
115106

116107

117108
@dynamo_config.patch(cache_size_limit=16)
@@ -182,15 +173,13 @@ def run_tests(
182173
and test_acceptance_rate is not None
183174
):
184175
if "spec_mml=None" in test_config:
185-
# because the acceptance rate can vary, we use a looser
186-
# tolerance here.
187176
assert (
188177
pytest.approx(test_acceptance_rate, rel=5e-2)
189178
== base_acceptance_rate
190179
)
191180
else:
192181
# Currently the reported acceptance rate is expected to be
193-
# lower when we skip drafting altogether.
182+
# lower when we sometimes skip drafting altogether.
194183
assert test_acceptance_rate > 0.05
195184
print(
196185
f"PASSED: config=[{test_config}], params={params}"
@@ -220,6 +209,7 @@ def run_test(
220209
):
221210
spec_decoding = spec_config is not None
222211
cache_arg: dict[str, Any] = (
212+
# Force preemptions
223213
dict(num_gpu_blocks_override=32)
224214
if test_preemption
225215
else dict(gpu_memory_utilization=0.9)
@@ -238,6 +228,7 @@ def run_test(
238228
model,
239229
max_model_len=512,
240230
enable_chunked_prefill=test_prefill_chunking,
231+
# Force prefill chunking
241232
max_num_batched_tokens=48 if test_prefill_chunking else None,
242233
# enforce_eager=True,
243234
async_scheduling=async_scheduling,
@@ -255,10 +246,7 @@ def run_test(
255246
results.append(
256247
vllm_model.generate(
257248
example_prompts,
258-
sampling_params=SamplingParams(
259-
**default_params,
260-
**override_params,
261-
),
249+
sampling_params=SamplingParams(**default_params, **override_params),
262250
return_logprobs=True,
263251
)
264252
)
@@ -270,9 +258,7 @@ def run_test(
270258

271259
if test_preemption:
272260
preemptions = _get_count(
273-
metrics_before,
274-
metrics_after,
275-
"vllm:num_preemptions",
261+
metrics_before, metrics_after, "vllm:num_preemptions"
276262
)
277263
assert preemptions > 0, "preemption test had no preemptions"
278264

vllm/config/speculative.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import ast
55
import hashlib
6-
from typing import TYPE_CHECKING, Any, Literal
6+
from typing import TYPE_CHECKING, Any, Literal, get_args
77

88
from pydantic import Field, SkipValidation, model_validator
99
from pydantic.dataclasses import dataclass
@@ -29,31 +29,25 @@
2929

3030
logger = init_logger(__name__)
3131

32-
SpeculativeMethod = Literal[
33-
"ngram",
34-
"eagle",
35-
"eagle3",
36-
"medusa",
37-
"mlp_speculator",
38-
"draft_model",
39-
"deepseek_mtp",
40-
"ernie_mtp",
41-
"qwen3_next_mtp",
42-
"mimo_mtp",
43-
"longcat_flash_mtp",
44-
"pangu_ultra_moe_mtp",
45-
"mtp",
46-
"suffix",
47-
]
48-
MTP_MODEL_TYPES = (
32+
MTPModelTypes = Literal[
4933
"deepseek_mtp",
5034
"mimo_mtp",
5135
"glm4_moe_mtp",
5236
"ernie_mtp",
5337
"qwen3_next_mtp",
5438
"longcat_flash_mtp",
39+
"mtp",
5540
"pangu_ultra_moe_mtp",
56-
)
41+
]
42+
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
43+
SpeculativeMethod = Literal[
44+
"ngram",
45+
"medusa",
46+
"mlp_speculator",
47+
"draft_model",
48+
"suffix",
49+
EagleModelTypes,
50+
]
5751

5852

5953
@config
@@ -244,7 +238,7 @@ def __post_init__(self):
244238
# can not be detected, it will be considered as the "draft_model" by
245239
# default.
246240

247-
if self.method in MTP_MODEL_TYPES:
241+
if self.method in get_args(MTPModelTypes) and self.method != "mtp":
248242
logger.warning(
249243
"method `%s` is deprecated and replaced with mtp.", self.method
250244
)
@@ -361,7 +355,9 @@ def __post_init__(self):
361355
self.method = "medusa"
362356
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
363357
self.method = "mlp_speculator"
364-
elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
358+
elif self.draft_model_config.hf_config.model_type in get_args(
359+
MTPModelTypes
360+
):
365361
self.method = "mtp"
366362
if self.num_speculative_tokens > 1:
367363
logger.warning(

vllm/config/vllm.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
from datetime import datetime
1515
from functools import lru_cache
1616
from pathlib import Path
17-
from typing import TYPE_CHECKING, Any, TypeVar
17+
from typing import TYPE_CHECKING, Any, TypeVar, get_args
1818

1919
import torch
2020
from pydantic import ConfigDict, Field, model_validator
2121
from pydantic.dataclasses import dataclass
2222

2323
import vllm.envs as envs
24+
from vllm.config.speculative import EagleModelTypes
2425
from vllm.logger import enable_trace_function_call, init_logger
2526
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
2627
from vllm.utils import random_uuid
@@ -374,10 +375,22 @@ def __post_init__(self):
374375
"Async scheduling is not yet compatible with "
375376
"pipeline_parallel_size > 1."
376377
)
378+
# Currently, async scheduling only supports eagle speculative
379+
# decoding.
377380
if self.speculative_config is not None:
378-
raise ValueError(
379-
"Async scheduling is not yet compatible with speculative decoding."
380-
)
381+
if self.speculative_config.method not in get_args(EagleModelTypes):
382+
raise ValueError(
383+
"Currently, async scheduling is only supported "
384+
"with EAGLE/MTP kind of speculative decoding"
385+
)
386+
if self.speculative_config.disable_padded_drafter_batch:
387+
raise ValueError(
388+
"async scheduling for EAGLE/MTP kind of speculative "
389+
"decoding is enabled, but disable_padded_drafter_batch=True "
390+
"disable_padded_drafter_batch=True is not supported for "
391+
"this situation now. please set "
392+
"disable_padded_drafter_batch=Fasle"
393+
)
381394
if not executor_supports_async_sched:
382395
raise ValueError(
383396
"Currently, async scheduling only supports `mp`, `uni`, or "

vllm/v1/core/sched/async_scheduler.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,25 @@ def _update_after_schedule(
1616
) -> None:
1717
super()._update_after_schedule(scheduler_output)
1818
pending_structured_output_tokens = False
19+
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
1920
for req_id in scheduler_output.num_scheduled_tokens:
2021
request = self.requests[req_id]
2122
pending_structured_output_tokens |= (
2223
request.use_structured_output and request.num_output_placeholders > 0
2324
)
25+
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
2426
if (
2527
request.num_computed_tokens
26-
== request.num_tokens + request.num_output_placeholders
28+
== request.num_tokens
29+
+ request.num_output_placeholders
30+
+ cur_num_spec_tokens
2731
):
28-
# The request will generate a new token in this scheduling step.
29-
# TODO(woosuk): Support speculative decoding.
30-
request.num_output_placeholders += 1
32+
# The request will generate a new token plus num_spec_tokens
33+
# in this scheduling step.
34+
request.num_output_placeholders += 1 + cur_num_spec_tokens
35+
# Add placeholders for the new tokens in spec_token_ids.
36+
# We will update the actual spec token ids in the worker process.
37+
request.spec_token_ids = [-1] * self.num_spec_tokens
3138

3239
scheduler_output.pending_structured_output_tokens = (
3340
pending_structured_output_tokens

vllm/v1/core/sched/scheduler.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,10 @@ def schedule(self) -> SchedulerOutput:
348348
# Speculative decode related.
349349
if request.spec_token_ids:
350350
num_scheduled_spec_tokens = (
351-
num_new_tokens + request.num_computed_tokens - request.num_tokens
351+
num_new_tokens
352+
+ request.num_computed_tokens
353+
- request.num_tokens
354+
- request.num_output_placeholders
352355
)
353356
if num_scheduled_spec_tokens > 0:
354357
# Trim spec_token_ids list to num_scheduled_spec_tokens.
@@ -1024,7 +1027,12 @@ def update_from_output(
10241027
# tokens and rejections. If some tokens are rejected,
10251028
# num_computed_tokens is decreased by the number of rejected
10261029
# tokens.
1027-
request.num_computed_tokens -= num_rejected
1030+
if request.num_computed_tokens > 0:
1031+
request.num_computed_tokens -= num_rejected
1032+
# If async scheduling, num_output_placeholders also includes
1033+
# the scheduled spec tokens count and so is similarly adjusted.
1034+
if request.num_output_placeholders > 0:
1035+
request.num_output_placeholders -= num_rejected
10281036
spec_decoding_stats = self.make_spec_decoding_stats(
10291037
spec_decoding_stats,
10301038
num_draft_tokens=num_draft_tokens,

vllm/v1/engine/core.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def __init__(
198198
self.step_fn = (
199199
self.step if self.batch_queue is None else self.step_with_batch_queue
200200
)
201+
self.async_scheduling = vllm_config.scheduler_config.async_scheduling
201202

202203
# Mark the startup heap as static so that it's ignored by GC.
203204
# Reduces pause times of oldest generation collections.
@@ -341,7 +342,10 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
341342
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
342343

343344
def post_step(self, model_executed: bool) -> None:
344-
if self.use_spec_decode and model_executed:
345+
# When using async scheduling we can't get draft token ids in advance,
346+
# so we update draft token ids in the worker process and don't
347+
# need to update draft token ids here.
348+
if not self.async_scheduling and self.use_spec_decode and model_executed:
345349
# Take the draft token ids.
346350
draft_token_ids = self.model_executor.take_draft_token_ids()
347351
if draft_token_ids is not None:

vllm/v1/engine/processor.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,23 @@ def _validate_supported_sampling_params(
150150
raise ValueError(
151151
"vLLM V1 does not support per request user provided logits processors."
152152
)
153+
# Async scheduling + spec decode currently incompatible with some
154+
# sampling parameters.
155+
if (
156+
self.vllm_config.speculative_config is not None
157+
and self.vllm_config.scheduler_config.async_scheduling
158+
and (
159+
params.frequency_penalty != 0.0
160+
or params.presence_penalty != 0.0
161+
or params.repetition_penalty != 1.0
162+
or params.bad_words_token_ids
163+
or params.structured_outputs
164+
)
165+
):
166+
raise ValueError(
167+
"async scheduling with spec decoding doesn't yet support "
168+
"penalties, bad words or structured outputs in sampling parameters."
169+
)
153170

154171
def _validate_params(
155172
self,

vllm/v1/sample/logits_processor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
# Error message when the user tries to initialize vLLM with a speculative
4242
# decoding enabled and custom logits processors
4343
STR_SPEC_DEC_REJECTS_LOGITSPROCS = (
44-
"Custom logits processors are not supportedwhen speculative decoding is enabled."
44+
"Custom logits processors are not supported when speculative decoding is enabled."
4545
)
4646

4747
LOGITSPROCS_GROUP = "vllm.logits_processors"

vllm/v1/spec_decode/eagle.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,10 +397,13 @@ def propose(
397397
positions += 1
398398
exceeds_max_model_len = positions >= self.max_model_len
399399
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
400-
400+
# For data integrity when async scheduling, we shouldn't use in place
401+
# operations in case they are modified in next step's `prepare_input`
402+
# of main model.
401403
# Increment the sequence lengths.
402404
common_attn_metadata.seq_lens += 1
403-
common_attn_metadata.seq_lens_cpu += 1
405+
# This is an out-of-place operation to avoid modifying the original tensor.
406+
common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
404407
# For the requests that exceed the max model length, we set the
405408
# sequence length to 1 to minimize their overheads in attention.
406409

vllm/v1/worker/gpu_input_batch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class CachedRequestState:
4646
lora_request: LoRARequest | None = None
4747
prompt_embeds: torch.Tensor | None = None
4848

49+
# Used when both async_scheduling and spec_decode are enabled.
50+
prev_num_draft_len: int = 0
51+
4952
def __post_init__(self):
5053
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
5154
self.prompt_token_ids, self.prompt_embeds

0 commit comments

Comments
 (0)