1515from ...models .utils import check_outputs_equal
1616
1717MODEL = "Qwen/Qwen3-0.6B"
18- MTP_MODEL = "XiaomiMiMo/MiMo-7B-Base "
18+ MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct "
1919
2020
2121first_prompt = (
2929
3030default_params = dict (
3131 temperature = 0.0 , # greedy
32- max_tokens = 20 ,
32+ max_tokens = 23 ,
33+ min_tokens = 18 ,
3334)
3435
3536
@@ -69,24 +70,19 @@ def test_without_spec_decoding(
6970 (True , "uni" , True , None , True ),
7071 ]
7172
72- run_tests (
73- monkeypatch ,
74- MODEL ,
75- test_configs ,
76- test_sampling_params ,
77- )
73+ run_tests (monkeypatch , MODEL , test_configs , test_sampling_params )
7874
7975
80- @pytest .mark .skip ("MTP model too big to run in fp32 in CI" )
8176def test_with_spec_decoding (monkeypatch : pytest .MonkeyPatch ):
8277 """Test consistency and acceptance rates with some different combos of
8378 preemption, executor, async scheduling, prefill chunking,
8479 spec decoding model length.
8580 """
8681
8782 spec_config = {
88- "method" : "mtp " ,
83+ "method" : "eagle3 " ,
8984 "num_speculative_tokens" : 2 ,
85+ "model" : "nm-testing/Llama3_2_1B_speculator.eagle3" ,
9086 }
9187 spec_config_short = spec_config | {"max_model_len" : 50 }
9288
@@ -106,12 +102,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
106102 (True , "uni" , True , spec_config_short , True ),
107103 ]
108104
109- run_tests (
110- monkeypatch ,
111- MTP_MODEL ,
112- test_configs ,
113- [{}],
114- )
105+ run_tests (monkeypatch , MTP_MODEL , test_configs , [{}])
115106
116107
117108@dynamo_config .patch (cache_size_limit = 16 )
@@ -182,15 +173,13 @@ def run_tests(
182173 and test_acceptance_rate is not None
183174 ):
184175 if "spec_mml=None" in test_config :
185- # because the acceptance rate can vary, we use a looser
186- # tolerance here.
187176 assert (
188177 pytest .approx (test_acceptance_rate , rel = 5e-2 )
189178 == base_acceptance_rate
190179 )
191180 else :
192181 # Currently the reported acceptance rate is expected to be
193- # lower when we skip drafting altogether.
182+ # lower when we sometimes skip drafting altogether.
194183 assert test_acceptance_rate > 0.05
195184 print (
196185 f"PASSED: config=[{ test_config } ], params={ params } "
@@ -220,6 +209,7 @@ def run_test(
220209):
221210 spec_decoding = spec_config is not None
222211 cache_arg : dict [str , Any ] = (
212+ # Force preemptions
223213 dict (num_gpu_blocks_override = 32 )
224214 if test_preemption
225215 else dict (gpu_memory_utilization = 0.9 )
@@ -238,6 +228,7 @@ def run_test(
238228 model ,
239229 max_model_len = 512 ,
240230 enable_chunked_prefill = test_prefill_chunking ,
231+ # Force prefill chunking
241232 max_num_batched_tokens = 48 if test_prefill_chunking else None ,
242233 # enforce_eager=True,
243234 async_scheduling = async_scheduling ,
@@ -255,10 +246,7 @@ def run_test(
255246 results .append (
256247 vllm_model .generate (
257248 example_prompts ,
258- sampling_params = SamplingParams (
259- ** default_params ,
260- ** override_params ,
261- ),
249+ sampling_params = SamplingParams (** default_params , ** override_params ),
262250 return_logprobs = True ,
263251 )
264252 )
@@ -270,9 +258,7 @@ def run_test(
270258
271259 if test_preemption :
272260 preemptions = _get_count (
273- metrics_before ,
274- metrics_after ,
275- "vllm:num_preemptions" ,
261+ metrics_before , metrics_after , "vllm:num_preemptions"
276262 )
277263 assert preemptions > 0 , "preemption test had no preemptions"
278264
0 commit comments