-
Notifications
You must be signed in to change notification settings - Fork 529
MLA RoPE + quantization fused kernel: shape generalization for MHA / GQA #1924
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
fa658e2
to
bd2a338
Compare
bd2a338
to
0fc5d11
Compare
@nvpohanh for another set of eyes. |
TensorView q_nope_out, TensorView k_nope_out, TensorView cos_sin_cache, | ||
TensorView pos_ids, double quant_scale_q, double quant_scale_kv, | ||
bool interleave) { | ||
void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A doc string explaining what the method does would be helpful here.
if __name__ == "__main__": | ||
# Run all benchmarks and generate individual plots | ||
print("Running MLA benchmark...") | ||
benchmark_mla.run(print_data=False, show_plots=True, save_path=".") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Time permitting it would be nice to compare against existing Flashinfer Rope + Quant in torch native as baseline for these measurements
mutates_args=("q_rope_out", "k_rope_out", "q_nope_out", "k_nope_out"), | ||
) | ||
def _mla_rope_quantize( | ||
def _rope_quantize( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Please add doc string here as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, mostly just nits for documentation and benchmark. Thanks for the effort!
📌 Description
Generalize the existing MLA RoPE+Q fused kernels to support GQA/MHA problem shapes.
🔍 Test Results
pytest -v tests/attention/test_rope.py::test_generalized_rope_quantize
python benchmarks/bench_rope_quantize_fp8.py
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commit
by runningpip install pre-commit
(or used your preferred method).pre-commit install
.pre-commit run --all-files
and fixed any reported issues.🧪 Tests
unittest
, etc.).