
[Performance]: ROPE + KV-Cache-Write + pre-attn prepare-ops fusion #24678

@alexm-redhat

Description


Proposal to improve performance

Performance analysis of DeepSeekR1 and GPTOSS showed a non-trivial amount of overhead in the operations that run before the attention kernel. Specifically, these are ROPE + KV-Cache-Write + the attention elementwise/copy prepare ops (and sometimes small matrix multiplies for MLA).

Here is a concrete breakdown for DeepSeekR1-FP8 batch-size 32 on 8xB200 GPUs:

[Image: profiler trace]

Before the attention kernel call (fmha::kernel), we can see 4 Triton ops (ROPE), then concat_and_cache_mla, and then a torch BMM (matrix_multiply - K-Proj) followed by a series of elementwise and copy ops. Together these ops take about as long as the attention kernel itself, so we propose to fuse all of them to reduce HBM traffic, since these ops are memory-bandwidth bound by nature.
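
As a rough illustration of the memory-traffic argument, here is a back-of-the-envelope sketch (plain Python; all sizes and kernel counts are illustrative assumptions, not measured values) comparing the HBM bytes moved when each pre-attention op is its own read-and-write kernel versus a single fused kernel that reads the inputs once and writes the cache/attention inputs once:

```python
# Back-of-the-envelope sketch: every unfused, memory-bound kernel roughly
# re-reads and re-writes the same activations from HBM, so traffic scales
# with the number of kernels; a fused kernel pays for one read + one write.
# All numbers below are illustrative assumptions.

def unfused_traffic_gb(tokens: int, bytes_per_token: int, num_kernels: int) -> float:
    """HBM bytes moved if each kernel does one full read and one full write."""
    return num_kernels * 2 * tokens * bytes_per_token / 1e9

def fused_traffic_gb(tokens: int, bytes_per_token: int) -> float:
    """HBM bytes moved by a single fused kernel (one read + one write)."""
    return 2 * tokens * bytes_per_token / 1e9

tokens = 32 * 1024                 # illustrative token count
bytes_per_token = (512 + 64) * 2   # MLA latent + rope dims at 2 bytes/elem (illustrative)
num_kernels = 7                    # e.g. 4x RoPE + cache write + BMM + copies

unfused = unfused_traffic_gb(tokens, bytes_per_token, num_kernels)
fused = fused_traffic_gb(tokens, bytes_per_token)
print(f"unfused ~ {unfused:.2f} GB, fused ~ {fused:.2f} GB, "
      f"~{unfused / fused:.0f}x less traffic when fused")
```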

A similar fusion is needed for GPTOSS as well; here is a breakdown:

[Image: profiler trace]

This case is a bit simpler, since the attention kernel does not expose elementwise ops (or a matmul) like the DeepSeekR1 model above. However, there is still the series of Triton ops (ROPE) followed by reshape_and_cache_kernel.
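
For reference, here is a minimal PyTorch sketch (not the actual vLLM Triton/CUDA kernels; shapes and the slot-mapping layout are illustrative assumptions) of the two steps that would be fused on this path: rotary embedding on q/k, followed by a reshape_and_cache-style scatter of k/v into the paged KV cache. In the fused version, k would be rotated and written to the cache in one pass instead of round-tripping through HBM in between:

```python
import torch

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotary embedding over the last dim, rotate-half style."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

num_tokens, num_heads, head_dim = 8, 4, 64
num_blocks, block_size = 16, 4

q = torch.randn(num_tokens, num_heads, head_dim)
k = torch.randn(num_tokens, num_heads, head_dim)
v = torch.randn(num_tokens, num_heads, head_dim)
cos = torch.randn(num_tokens, 1, head_dim // 2)
sin = torch.randn(num_tokens, 1, head_dim // 2)

# Paged KV cache: [num_blocks, block_size, num_heads, head_dim]
k_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim)
v_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim)
slot_mapping = torch.randperm(num_blocks * block_size)[:num_tokens]

# Step 1: RoPE (today: a handful of Triton kernels) -- reads and writes q and k.
q = apply_rope(q, cos, sin)
k = apply_rope(k, cos, sin)

# Step 2: reshape_and_cache-style scatter -- re-reads k/v and writes the cache.
block_idx, block_off = slot_mapping // block_size, slot_mapping % block_size
k_cache[block_idx, block_off] = k
v_cache[block_idx, block_off] = v
```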

NVIDIA has a fused ROPE+KV-Cache-write+attn prepare ops kernel in TRT (@pavanimajety is part of the related team), and it would be good to understand the following:

  1. Does the current work support both reshape_and_cache and concat_and_cache_mla?
  2. What about quantization types? Can we do static per-tensor, dynamic per-token, and dynamic per-group (block quant)?
  3. @Lucas Wilkinson and @ProExpertProg discussed how to expose the cache and quant ops; there is ongoing work to refactor the attention layer into multiple layers so that MLA/SDPA/etc. don't all have to go through the same unified_attention op.
  • Issue here: [Refactor]: Make an common MLAAttention Layer and custom OP #24620
  • That means concat_and_cache_mla will be exposed in the layer and will appear in the graph.
  • For quants, instead of adding more methods and layers (my original proposal), we should implement an approach similar to output quant fusion: query the attention backend whether input quant de-fusion is supported and, if so, add a quant op to the graph outside unified_attention. This is now handled in the attention layer: [torch.compile] Make Query Quantization Fusable #24914. This doesn't change the semantics of unified_attention much, since we are only changing the dtype of its input. We may still need to figure out how to expose the scale for the quant layer to the graph; if we can't make it visible in the traced graph, we can add a new custom op that extracts it from the attention layer (see the sketch after this list).
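
A hypothetical sketch of this flow is below; the names (AttentionBackendInfo, supports_input_quant, quantize_fp8, pre_attention) are assumptions for illustration, not vLLM's actual API. The idea is that when the backend reports it can consume a pre-quantized query, the quant runs as its own fusable op ahead of the attention custom op, and only the input dtype seen by unified_attention changes:

```python
from dataclasses import dataclass
import torch

@dataclass
class AttentionBackendInfo:
    supports_input_quant: bool   # can the attention kernel take an fp8 query directly?
    q_scale: torch.Tensor        # static per-tensor scale (illustrative)

def quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Static per-tensor fp8 quantization (one of the schemes from question 2)."""
    return (x / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

def pre_attention(q: torch.Tensor, backend: AttentionBackendInfo) -> torch.Tensor:
    # If de-fusion is supported, the quant runs as its own (fusable) op outside
    # the attention custom op; otherwise q passes through unchanged.
    if backend.supports_input_quant:
        return quantize_fp8(q, backend.q_scale)
    return q

backend = AttentionBackendInfo(supports_input_quant=True, q_scale=torch.tensor(0.05))
q = torch.randn(16, 8, 128)
q_for_attn = pre_attention(q, backend)
print(q_for_attn.dtype)   # torch.float8_e4m3fn when de-fusion is supported
```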

PR tracking:

  • MLAAttention extraction: #25103
  • Separate kvcache from unified_attention: #25954
  • Add concat_cache_mla_rope CUDA kernel: #25774
