
Conversation

@PatchouliTIS

@PatchouliTIS PatchouliTIS commented Nov 21, 2025

Purpose

This PR builds on PR #24799, aiming to implement a GPU version of n-gram speculative decoding and make it compatible with the async scheduler.

Test Plan

  • Async Scheduler + NGram + Qwen3-1.7B
    Test config:
# dataset is CMU-DoG, which is an input-grounded dataset.
python3.12 -u -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8000 \
--max-num-seqs 128 \
--max-model-len 2048 \
--model Qwen/Qwen3-1.7B \
--tensor-parallel-size 1 \
--trust-remote-code \
--dtype bfloat16  \
--enable-chunked-prefill \
--disable-log-requests \
--async-scheduling \
--speculative_config '{"method": "ngram_gpu", "num_speculative_tokens": 3, "prompt_lookup_max": 2,"prompt_lookup_min": 2}'

Test Device: NVIDIA H20
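
For context, n-gram (prompt-lookup) drafting proposes draft tokens by finding an earlier occurrence of the sequence's trailing n-gram and replaying the tokens that followed it. The sketch below is a minimal CPU reference of that idea, assuming the semantics of the `prompt_lookup_min`/`prompt_lookup_max`/`num_speculative_tokens` options in the config above; it is not the PR's CUDA kernel.

```python
# Minimal CPU sketch of n-gram "prompt lookup" drafting (reference only,
# not the GPU kernel from this PR).
def ngram_propose(token_ids, lookup_min=2, lookup_max=2, num_spec=3):
    """Find the most recent earlier occurrence of the trailing n-gram and
    propose the tokens that followed it."""
    for n in range(lookup_max, lookup_min - 1, -1):
        if len(token_ids) < n + 1:
            continue
        tail = token_ids[-n:]
        # Scan right-to-left over earlier positions for a matching n-gram.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == tail:
                follow = token_ids[start + n:start + n + num_spec]
                if follow:
                    return follow
    return []  # no match: fall back to normal decoding

print(ngram_propose([5, 6, 7, 8, 5, 6]))  # -> [7, 8, 5]
```

The GPU variant in this PR performs the same matching in parallel across the batch, which is why it keeps the token IDs resident in a device buffer.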

Test Result

Performance

num_prompts | async_ngram (tps) | sync_ngram (tps) | speedup
----------- | ----------------- | ---------------- | -------
2           | 466               | 357              | 30.5%
8           | 1378              | 988              | 39.4%
16          | 2082              | 1726             | 20.6%

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

hl475 and others added 4 commits November 24, 2025 10:58
…rs (vllm-project#29111)

Signed-off-by: Huamin Li <3ericli@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: PatchouliTaisa <patchychen@tencent.com>
PatchouliTaisa and others added 7 commits November 24, 2025 11:03
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@ZJY0516
Contributor

ZJY0516 commented Nov 27, 2025

cc @njhill

# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
disable_cache = not is_compile_cache_enabled(self.inductor_config)

# TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors.
Collaborator

Why? Can this be fixed?

Author

When I enabled torch compile for the ngram GPU kernel, the computational graph of the ngram operator would hit a precompiled graph cached for the main model, leading to mismatched results. Therefore, I directly disabled the compile cache here. I tested this locally, and disabling the cache had no impact on performance.

Collaborator

I assume disabling the compile cache would lead to longer startup time? I'm not an expert here but maybe it's possible to add an identifier to the compile cache to avoid extraneous cache hits?
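
The reviewer's suggestion could look roughly like the sketch below: fold an extra identifier into the compile-cache key so the ngram drafter's graph can never hit the main model's cached graph. This is a hypothetical illustration (the function name and factors are invented), not vLLM's actual cache-key code.

```python
import hashlib

# Hypothetical sketch: include a per-graph tag in the compile-cache key so
# the ngram drafter graph and the main model graph get distinct cache entries.
def compile_cache_key(base_factors: list[str], graph_tag: str) -> str:
    h = hashlib.sha256()
    for factor in base_factors:          # e.g. torch version, dtype, config hash
        h.update(factor.encode())
    h.update(graph_tag.encode())         # e.g. "main_model" vs "ngram_gpu_drafter"
    return h.hexdigest()[:16]

key_main = compile_cache_key(["torch-2.x", "bf16"], "main_model")
key_ngram = compile_cache_key(["torch-2.x", "bf16"], "ngram_gpu_drafter")
assert key_main != key_ngram  # no cross-graph cache hits
```

This would keep caching enabled (avoiding the longer startup time the reviewer mentions) while preventing the mismatched-graph cache hits described above.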

pin_memory=False,
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.token_ids_gpu_tensor = torch.zeros(
Collaborator

This is a massive buffer, and can go up to 1GB of VRAM in normal use-cases. Is there anything that can be done about this?

Author

Both the ngram GPU computation and ngram input preparation benefit from this buffer, so I think it is worth maintaining. By the way, I'm confused about the size: with max_model_len set to 128K, it would take approximately 1600 max_num_seqs to reach 1 GB of VRAM. Is that a normal use case? Besides, users can tune max_num_seqs and max_model_len themselves.

Collaborator

  • Some models have larger max context length than 128K (Qwen3 has 256K, Llama4 has 1M+)
  • Deployments often have max_num_seqs between 512 and 2048. I would consider this to be a normal range.
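
The back-of-envelope arithmetic behind both sides of this exchange can be checked directly, assuming an int32 token buffer of shape `[max_num_seqs, max_model_len]` (4 bytes per token; the exact figure depends on the dtype actually used):

```python
# VRAM estimate for a [max_num_seqs, max_model_len] token-ID buffer,
# assuming int32 tokens (4 bytes each).
def buffer_gib(max_num_seqs: int, max_model_len: int, bytes_per_token: int = 4) -> float:
    return max_num_seqs * max_model_len * bytes_per_token / 2**30

print(buffer_gib(1600, 128 * 1024))  # -> 0.78125 GiB, the ~1 GB ballpark above
print(buffer_gib(2048, 256 * 1024))  # -> 2.0 GiB for a 256K-context deployment
```

Under these assumptions the reviewer's "normal range" (max_num_seqs 512-2048, 256K+ contexts) does push the buffer past 1 GiB, which supports the concern.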

all_token_ids = prompt_token_ids + req_state.output_token_ids
num_tokens = len(all_token_ids)
# Copy to GPU tensor
self.input_batch.token_ids_gpu_tensor[idx, :num_tokens].copy_(
Collaborator

It looks like this copy is copying from a device tensor, instead of a standard HtoD copy. Why is that?

),
non_blocking=True,
)
self.input_batch.num_tokens_no_spec_gpu[idx] = num_tokens
Collaborator

Can this logic be integrated into _update_states, where num_tokens_no_spec (cpu) is maintained? That seems like it would be cleaner than recomputing it twice and copying over here. Also, we would not want to maintain two pieces of the same logic.
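
The single-source-of-truth structure the reviewer suggests could be sketched as below. The class and attribute names are hypothetical stand-ins for the runner's state (with plain lists standing in for the CPU array and CUDA tensor); the point is that one update method writes both copies, so the value is never recomputed in a second code path.

```python
# Hypothetical sketch of maintaining num_tokens_no_spec in one place and
# mirroring it to the device copy, instead of recomputing it twice.
class BatchState:
    def __init__(self, max_num_seqs: int):
        self.num_tokens_no_spec_cpu = [0] * max_num_seqs
        self.num_tokens_no_spec_gpu = [0] * max_num_seqs  # stand-in for a CUDA tensor

    def update_num_tokens(self, idx: int, num_tokens: int) -> None:
        # Single source of truth: write once, mirror to the device copy.
        self.num_tokens_no_spec_cpu[idx] = num_tokens
        self.num_tokens_no_spec_gpu[idx] = num_tokens

state = BatchState(4)
state.update_num_tokens(2, 17)
```

In the real runner the mirror write would be an async HtoD copy batched with the other per-step transfers, but the ownership question is the same.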

for i, num_tokens in enumerate(num_accepted_tokens):
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens

def _update_ngram_gpu_tensors(self, scheduler_output: "SchedulerOutput") -> None:
Collaborator

Is there any unique logic in here that is distinct from how we maintain token_ids_cpu_tensor and num_tokens_no_spec_cpu?

@support_torch_compile(
dynamic_arg_dims={
"num_tokens_no_spec": 0,
"token_ids_gpu": [0, 1],
Collaborator

What happens when both dims are marked as dynamic? Does it recompile for every combination of BS and SeqLen? If so, isn't that way too much compilation?

with set_forward_context(None, self.vllm_config):
_ = self.kernel(num_tokens, token_ids, combined_mask)

def _generate_dummy_data(
Collaborator

Why does the dummy run need actual synthetic data? Can't it just use random tokens? Isn't the point of the dummy run just to initialize/record the shapes and buffers?

combined_mask = (
sampled_flags
& valid_mask
& (num_tokens_no_spec < self.max_model_len)
Collaborator

How could this happen?
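
A pure-Python reading of the mask in the diff hunk above: a request participates in drafting only if it actually produced a sampled token this step, occupies a valid batch slot, and still has room left in the context window (the values below are hypothetical).

```python
# Pure-Python sketch of the combined_mask above (illustrative values).
MAX_MODEL_LEN = 2048

def draft_mask(sampled_flags, valid_mask, num_tokens_no_spec):
    return [
        s and v and (n < MAX_MODEL_LEN)
        for s, v, n in zip(sampled_flags, valid_mask, num_tokens_no_spec)
    ]

print(draft_mask([True, True, True], [True, True, False], [100, 2048, 50]))
# -> [True, False, False]: the second request has filled its context window
```

The length check answers the "how could this happen?" question in part: a request that has reached max_model_len has no room for draft tokens, so it must be masked out even if it sampled a token.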

combined_mask,
)

def prepare_next_token_ids_cpu(
Collaborator

Please don't duplicate functions like this. Refactor and reuse the code.

combined_mask,
)

def prepare_next_token_ids_cpu(
Collaborator

This function does not seem to be used.

Collaborator

@benchislett benchislett left a comment

My only major concern with this PR right now is code maintainability. Between this PR and the draft-model support, we are duplicating/reusing many components of the original EAGLE implementation in a naive manner. We should carefully structure the code here to implement these drafters more cleanly.
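
One hypothetical shape for the cleaner structure this review asks for: a shared drafter interface that EAGLE, draft-model, and ngram-GPU drafters all implement, so batch bookkeeping lives in one place instead of being copied per drafter. The names below are invented for illustration and are not vLLM's actual class hierarchy.

```python
from abc import ABC, abstractmethod

# Hypothetical common interface so speculative drafters share one contract
# instead of each duplicating pieces of the EAGLE implementation.
class Drafter(ABC):
    @abstractmethod
    def propose(self, token_ids: list[int], num_spec: int) -> list[int]:
        """Return up to num_spec draft tokens for one request."""

class NgramDrafter(Drafter):
    def propose(self, token_ids, num_spec):
        # Trivial placeholder logic: repeat the last token num_spec times.
        return token_ids[-1:] * num_spec if token_ids else []
```

With such a contract, the runner's prepare/update paths would call `propose` polymorphically and the per-drafter code would shrink to the proposal logic itself.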
