diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index ac73ca9871..7399bd4268 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1715,7 +1715,7 @@ def plan( self._max_q_len = max_token_per_sequence else: qo_indptr_host = qo_indptr.to("cpu") - self._max_q_len = max(qo_indptr_host).item() + self._max_q_len = max(qo_indptr_host[1:] - qo_indptr_host[:-1]).item() total_num_rows = int(qo_indptr_host[-1]) if max_sequence_kv is not None: