From 4e3d5b5d4e45519068dcfa468479b81681fc96b2 Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Fri, 7 Nov 2025 14:24:10 +0000 Subject: [PATCH 1/2] [Bugfix] Spec decode + structured output + spec model max len edge case Signed-off-by: Andy Lo --- tests/v1/spec_decode/test_max_len.py | 33 ++++++++++++++++++++++++--- vllm/v1/core/sched/scheduler.py | 12 +++++----- vllm/v1/structured_output/__init__.py | 2 +- 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py index bc779f6bd9c4..fa1d0437f7c7 100644 --- a/tests/v1/spec_decode/test_max_len.py +++ b/tests/v1/spec_decode/test_max_len.py @@ -7,6 +7,7 @@ from tests.utils import get_attn_backend_list_based_on_platform from vllm import LLM, SamplingParams from vllm.platforms import current_platform +from vllm.sampling_params import StructuredOutputsParams _PROMPTS = [ "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1", @@ -56,8 +57,34 @@ def test_eagle_max_len( "method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": num_speculative_tokens, + "max_model_len": 80, }, - max_model_len=100, + max_model_len=200, ) - sampling_params = SamplingParams(max_tokens=100, ignore_eos=True) - llm.generate(_PROMPTS, sampling_params) + sampling_params = SamplingParams(max_tokens=200, ignore_eos=True) + outputs = llm.generate(_PROMPTS, sampling_params) + for o in outputs: + assert o.outputs[0].finish_reason == "length", ( + "This test is only meaningful if the output " + "is truncated due to max length" + ) + + sampling_params = SamplingParams( + max_tokens=200, + structured_outputs=StructuredOutputsParams( + regex="^" + "a b c d e " * 15 + "$" + ), + ) + output = llm.generate(_PROMPTS, sampling_params) + for o in output: + assert o.prompt_token_ids is not None + assert ( + len(o.prompt_token_ids) + < 80 + < len(o.prompt_token_ids) + len(o.outputs[0].token_ids) + < 200 + ), ( + "This test is only meaningful if the output " + "is longer than the eagle max length" + ) + assert o.outputs[0].text == "a b c d e " * 15 diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f558306e3b2f..0a8fd9f196da 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -321,10 +321,13 @@ def schedule(self) -> SchedulerOutput: ) if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. - del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids + request.spec_token_ids[:num_scheduled_spec_tokens] ) + # After this step the speculative tokens will no longer be + # valid because the structured output state would have advanced. + # Clear the list to avoid reusing them. + request.spec_token_ids = [] # Encoder-related. if encoder_inputs_to_schedule: @@ -1149,10 +1152,7 @@ def update_draft_token_ids( continue # Add newly generated spec token ids to the request. - if not spec_token_ids: - # NOTE(woosuk): request.spec_token_ids should be updated. - request.spec_token_ids.clear() - elif self.structured_output_manager.should_advance(request): + if self.structured_output_manager.should_advance(request): metadata = request.structured_output_request request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] spec_token_ids diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index acc00526ee89..4083e86ada42 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -271,7 +271,7 @@ def grammar_bitmask( ): assert structured_output_request.grammar.accept_tokens( req_id, [token] - ) + ), (token, req_id, scheduled_spec_decode_tokens) state_advancements += 1 cumulative_index += 1 if state_advancements > 0: From 8796d01d7e01d8c584bcb833aae77a10a700e8e3 Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Sat, 8 Nov 2025 15:20:23 +0000 Subject: [PATCH 2/2] Comments Signed-off-by: Andy Lo --- vllm/v1/core/sched/scheduler.py | 8 ++++---- vllm/v1/structured_output/__init__.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 0a8fd9f196da..c17b19b58c97 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -321,12 +321,12 @@ def schedule(self) -> SchedulerOutput: ) if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids[:num_scheduled_spec_tokens] + request.spec_token_ids ) - # After this step the speculative tokens will no longer be - # valid because the structured output state would have advanced. - # Clear the list to avoid reusing them. + # New spec tokens will be set in `update_draft_token_ids` before the + # next step when applicable. request.spec_token_ids = [] # Encoder-related. diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 4083e86ada42..029129cf1a47 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -269,9 +269,10 @@ def grammar_bitmask( and token is not None and not structured_output_request.grammar.is_terminated() ): - assert structured_output_request.grammar.accept_tokens( + accepted = structured_output_request.grammar.accept_tokens( req_id, [token] - ), (token, req_id, scheduled_spec_decode_tokens) + ) + assert accepted, (token, req_id, scheduled_spec_decode_tokens) state_advancements += 1 cumulative_index += 1 if state_advancements > 0: