fix: Fix test and benchmark for trtllm-gen prefill batch size 1 #1912
base: main
Conversation
benchmarks/routines/attention.py
```python
kv_cache = torch.cat([k_fp8, v_fp8], dim=1)

if batch_size == 1:
    # trtllm kernel requires max_q_len to be the same as the seqlen of the query when batch_size=1
```
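The hunk above is cut off after the comment. Purely as a sketch of the kind of adjustment it introduces (the variable names `qo_indptr` and `s_qo` are assumptions taken from the discussion below, not necessarily the benchmark's actual names):

```python
import torch

batch_size = 1
s_qo = 1024                                            # padded maximum query length
qo_indptr = torch.tensor([0, 100], dtype=torch.int32)  # cumulative query lengths

if batch_size == 1:
    # Old kernel requirement: max_q_len must equal the actual query length, which
    # for a single sequence is the last entry of the cumulative-length array.
    max_q_len = int(qo_indptr[-1])  # -> 100
else:
    max_q_len = s_qo                # a padded upper bound is fine for batch_size > 1
```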
Why could `qo_indptr[-1]` be different from `s_qo`? Is it because we want to be compatible with CUDA graphs and `s_qo` will always be the maximum length?
Short answer is yes.

Longer answer: in a batch_size > 1 situation, the CUDA graph containing `prefill.trtllm_batch_context_with_kv_cache()` can be reused with multiple sequence lengths, but not when batch_size == 1. For example:

- If batch_size is 3 and we have two batches with query lengths `[100, 200, 300]` and `[16, 500, 1024]`, we can set `s_qo=1024` when we construct the CUDA graph and use the same CUDA graph for the two batches.
- However, for batch_size = 1, where we have batches of query lengths `[100]` and `[1024]`, a CUDA graph must be constructed each time -- first with `s_qo=100` and second with `s_qo=1024`.

Not sure whether the above is a real concern at the framework level. Nevertheless, `s_qo` goes in as the `max_q_len` input argument, where it is the maximum sequence length for the query. We may at least want to consider whether the wording in the documentation is clear 😄
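To make the reuse pattern concrete, here is a minimal capture/replay sketch in plain PyTorch. The `run_prefill` body is only a stand-in for the trtllm prefill call, and the tensor names and shapes (`static_q`, `static_qo_indptr`) are made up for illustration:

```python
import torch

s_qo = 1024  # padded maximum query length; baked into the captured graph as max_q_len
static_q = torch.zeros(s_qo, 8, 64, device="cuda", dtype=torch.bfloat16)
static_qo_indptr = torch.zeros(4, dtype=torch.int32, device="cuda")  # cumulative query lengths

def run_prefill():
    # Stand-in for the real prefill call. The real call takes max_q_len as a Python
    # int, so its value (s_qo here) is fixed at capture time; per-sequence lengths
    # would be read on-device from static_qo_indptr by the kernel itself.
    return static_q.float().sum(dim=-1)

# Warm up once, then capture.
run_prefill()
torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    out = run_prefill()

# Replay for a batch with query lengths [100, 200, 300]: only device-side inputs
# change, so the same graph also serves [16, 500, 1024]. With batch_size == 1 under
# the old behavior, a graph baked with max_q_len = 1024 could not serve a length-100 query.
static_qo_indptr.copy_(torch.tensor([0, 100, 300, 600], dtype=torch.int32, device="cuda"))
graph.replay()
```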
Hi @bkryu, does upgrading to the latest trtllm-gen fix the issue?

/bot run

[FAILED] Pipeline #36750562: 1/17 passed
📌 Description
The current PR fixes the test and benchmark code that hit IMAs (illegal memory accesses) when running trtllm-gen paged & ragged prefill with batch size 1; the issue is described in #1898.
Root cause of the issue:
`flashinfer.prefill.trtllm_ragged_attention_deepseek` and `flashinfer.prefill.trtllm_batch_context_with_kv_cache` both require `max_q_len` to match the length of the query when batch size is 1.

Updated PR:
The issue has been addressed on the kernel side, so the "`max_q_len` must match the length of the query when batch size is 1" requirement no longer applies. The current PR updates the trtllm-gen FMHA cubins to the latest version.
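For context on why the wrapper-side alternative described next was rejected, here is a minimal PyTorch-only illustration (no FlashInfer code) of the host synchronization involved in reading a sequence length off the device:

```python
import torch

# For batch_size == 1 the last entry of the cumulative query-length array is the
# query length itself, but it lives on the GPU.
cum_seq_lens_q = torch.tensor([0, 100], dtype=torch.int32, device="cuda")

# .item() copies the scalar to the host and blocks until the GPU work producing it
# has finished. Such a device-to-host sync is not permitted while a CUDA graph is
# being captured, which is why deriving max_q_len this way inside the API is not viable.
max_q_len = cum_seq_lens_q[-1].item()
print(max_q_len)  # 100
```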
Description of previous solution:
Updating `max_q_len` to `cum_seq_lens_q[-1].item()` within the `trtllm_ragged_attention_deepseek` or `trtllm_batch_context_with_kv_cache` functions is not a viable option, because the CPU-side synchronization breaks the deterministic, fully device-side execution required during CUDA graph capture. The workaround was therefore to update the test & benchmark code that calls the trtllm prefill functions, and to state clearly in the docstrings that when batch_size == 1, max_q_len must match the query size.

🔍 Related Issues
#1898
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes