diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py index bc779f6bd9c4..fa1d0437f7c7 100644 --- a/tests/v1/spec_decode/test_max_len.py +++ b/tests/v1/spec_decode/test_max_len.py @@ -7,6 +7,7 @@ from tests.utils import get_attn_backend_list_based_on_platform from vllm import LLM, SamplingParams from vllm.platforms import current_platform +from vllm.sampling_params import StructuredOutputsParams _PROMPTS = [ "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1", @@ -56,8 +57,34 @@ def test_eagle_max_len( "method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": num_speculative_tokens, + "max_model_len": 80, }, - max_model_len=100, + max_model_len=200, ) - sampling_params = SamplingParams(max_tokens=100, ignore_eos=True) - llm.generate(_PROMPTS, sampling_params) + sampling_params = SamplingParams(max_tokens=200, ignore_eos=True) + outputs = llm.generate(_PROMPTS, sampling_params) + for o in outputs: + assert o.outputs[0].finish_reason == "length", ( + "This test is only meaningful if the output " + "is truncated due to max length" + ) + + sampling_params = SamplingParams( + max_tokens=200, + structured_outputs=StructuredOutputsParams( + regex="^" + "a b c d e " * 15 + "$" + ), + ) + output = llm.generate(_PROMPTS, sampling_params) + for o in output: + assert o.prompt_token_ids is not None + assert ( + len(o.prompt_token_ids) + < 80 + < len(o.prompt_token_ids) + len(o.outputs[0].token_ids) + < 200 + ), ( + "This test is only meaningful if the output " + "is longer than the eagle max length" + ) + assert o.outputs[0].text == "a b c d e " * 15 diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f558306e3b2f..c17b19b58c97 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -325,6 +325,9 @@ def schedule(self) -> SchedulerOutput: scheduled_spec_decode_tokens[request.request_id] = ( request.spec_token_ids ) + # New spec tokens will be set in `update_draft_token_ids` before the + # next step when applicable. + request.spec_token_ids = [] # Encoder-related. if encoder_inputs_to_schedule: @@ -1149,10 +1152,7 @@ def update_draft_token_ids( continue # Add newly generated spec token ids to the request. - if not spec_token_ids: - # NOTE(woosuk): request.spec_token_ids should be updated. - request.spec_token_ids.clear() - elif self.structured_output_manager.should_advance(request): + if self.structured_output_manager.should_advance(request): metadata = request.structured_output_request request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] spec_token_ids diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index acc00526ee89..029129cf1a47 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -269,9 +269,10 @@ def grammar_bitmask( and token is not None and not structured_output_request.grammar.is_terminated() ): - assert structured_output_request.grammar.accept_tokens( + accepted = structured_output_request.grammar.accept_tokens( req_id, [token] ) + assert accepted, (token, req_id, scheduled_spec_decode_tokens) state_advancements += 1 cumulative_index += 1 if state_advancements > 0: