Conversation

@kahyunnam (Contributor) commented Oct 13, 2025

📌 Description

Generalize the existing MLA RoPE+Q fused kernels to support GQA/MHA problem shapes.
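
For context, a minimal sketch of how the three problem shapes differ (head counts and rope/nope head dims taken from the benchmark configuration summary below; tensor names are illustrative and not the kernel's actual API):

    import torch

    # Head counts and (rope, nope) head dims from the benchmark configurations
    # below; tensor names are illustrative, not the kernel's argument names.
    configs = {
        "MLA": (128, 1, 64, 512),   # (q_heads, k_heads, rope_dim, nope_dim)
        "GQA": (32, 8, 64, 64),
        "MHA": (32, 32, 64, 64),
    }

    num_tokens = 16
    for name, (q_heads, k_heads, rope_dim, nope_dim) in configs.items():
        q_rope = torch.randn(num_tokens, q_heads, rope_dim, dtype=torch.bfloat16)
        k_rope = torch.randn(num_tokens, k_heads, rope_dim, dtype=torch.bfloat16)
        q_nope = torch.randn(num_tokens, q_heads, nope_dim, dtype=torch.bfloat16)
        k_nope = torch.randn(num_tokens, k_heads, nope_dim, dtype=torch.bfloat16)
        print(name, tuple(q_rope.shape), tuple(k_rope.shape),
              tuple(q_nope.shape), tuple(k_nope.shape))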

🔍 Test Results

pytest -v tests/attention/test_rope.py::test_generalized_rope_quantize

============================================================================= 312 passed in 3.93s ==============================================================================

python benchmarks/bench_rope_quantize_fp8.py

Running MLA benchmark...
Running GQA benchmark...
Running MHA benchmark...

=== Summary Table ===
Tokens   MLA (ms)   GQA (ms)   MHA (ms)
----------------------------------------
1        0.00225    0.00174    0.00164
2        0.00225    0.00164    0.00174
4        0.00236    0.00164    0.00174
8        0.00256    0.00174    0.00184
16       0.00287    0.00184    0.00195
32       0.00440    0.00195    0.00215
64       0.00748    0.00215    0.00246
128      0.01270    0.00256    0.00307
256      0.02304    0.00348    0.00461
384      0.03799    0.00440    0.00594
512      0.17039    0.00532    0.00748
768      0.28539    0.00748    0.00993

Configuration details:
  MLA: 128 Q heads, 1 K head, 64+512 dims
  GQA: 32 Q heads, 8 K heads, 64+64 dims
  MHA: 32 Q heads, 32 K heads, 64+64 dims

Plot files saved to current directory:
  mla-rope-benchmark.png
  gqa-rope-benchmark.png
  mha-rope-benchmark.png

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

@kahyunnam kahyunnam changed the title [not ready for review! draft.] MLA RoPE + quantization kernel generalization for MHA / GQA MLA RoPE + quantization fused kernel: shape generalization for MHA / GQA Oct 14, 2025
@kahyunnam kahyunnam marked this pull request as ready for review October 14, 2025 04:48
@kahyunnam kahyunnam enabled auto-merge (squash) October 14, 2025 17:37
@kahyunnam kahyunnam requested a review from yzh119 October 14, 2025 17:40
@pavanimajety (Contributor) commented:
@nvpohanh for another set of eyes.

Review comment (Contributor) on the rope_quantize entry point in the C++ bindings (signature excerpt; the arguments elided in the diff snippet are not shown):

    void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in,
                       /* ... */ TensorView q_nope_out, TensorView k_nope_out, TensorView cos_sin_cache,
                       TensorView pos_ids, double quant_scale_q, double quant_scale_kv,
                       bool interleave) {

A doc string explaining what the method does would be helpful here.

Review comment (Contributor) on benchmarks/bench_rope_quantize_fp8.py:

    if __name__ == "__main__":
        # Run all benchmarks and generate individual plots
        print("Running MLA benchmark...")
        benchmark_mla.run(print_data=False, show_plots=True, save_path=".")

Time permitting, it would be nice to compare against the existing FlashInfer RoPE + quant composed in native torch as a baseline for these measurements.
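
For what it's worth, a rough sketch of what a torch-native RoPE + FP8 quantization baseline could look like (non-interleaved rotation, a precomputed cos/sin table, and a per-tensor scale applied by multiplication are all assumptions; names are illustrative, not from the benchmark):

    import torch

    def rope_quantize_reference(x_rope, cos, sin, quant_scale):
        # Standard non-interleaved rotation over the two halves of the rope dim.
        d = x_rope.shape[-1] // 2
        x1, x2 = x_rope[..., :d], x_rope[..., d:]
        rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
        # Quantize to FP8 e4m3; assumes the scale is applied by multiplication.
        return (rotated * quant_scale).to(torch.float8_e4m3fn)

    # Example with the GQA-like query shape from the table above.
    tokens, heads, rope_dim = 16, 32, 64
    q_rope = torch.randn(tokens, heads, rope_dim)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
    angles = torch.outer(torch.arange(tokens, dtype=torch.float32), inv_freq)
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]  # broadcast over heads
    q_fp8 = rope_quantize_reference(q_rope, cos, sin, quant_scale=1.0)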

Review comment (Contributor) on the Python wrapper (renamed from _mla_rope_quantize to _rope_quantize):

        mutates_args=("q_rope_out", "k_rope_out", "q_nope_out", "k_nope_out"),
    )
    def _rope_quantize(

Nit: Please add a doc string here as well.
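
In case it is useful, one possible shape for such a docstring (the argument list is reconstructed from the snippets in this PR, so the names, ordering, and the k_nope_in parameter are assumptions; the body is a stub, not the PR's implementation):

    # Hypothetical sketch only: argument list reconstructed from the snippets
    # in this PR (k_nope_in and the exact ordering are guesses), and the body
    # is a stub rather than the PR's implementation.
    def _rope_quantize(
        q_rope_in, k_rope_in, q_nope_in, k_nope_in,
        q_rope_out, k_rope_out, q_nope_out, k_nope_out,
        cos_sin_cache, pos_ids, quant_scale_q, quant_scale_kv, interleave,
    ):
        """Fused RoPE + FP8 quantization supporting MLA, GQA, and MHA shapes.

        Applies rotary embeddings (cos_sin_cache indexed by pos_ids; layout
        selected by interleave) to q_rope_in/k_rope_in, scales by
        quant_scale_q/quant_scale_kv, quantizes the rotated parts together
        with the pass-through q_nope_in/k_nope_in, and writes the results
        into the preallocated *_out tensors in place.
        """
        raise NotImplementedError("illustrative docstring sketch only")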

@pavanimajety (Contributor) left a review:

LGTM, mostly just nits for documentation and benchmark. Thanks for the effort!
