@ReyJerry

This PR implements the missing Triton kernel optimizations mentioned in the comments of moba_attn_varlen, providing a fully functional, high-performance Triton-based implementation in moba_triton.py. This serves as a drop-in replacement for moba_efficient.py with significant speedups.

Related Issue:

Fixes #37

Description:

As noted in Issue #37, the original moba_efficient.py implementation relied on standard PyTorch operations despite comments referencing Triton kernels. This resulted in suboptimal performance due to high memory I/O overhead from materializing large intermediate tensors (e.g., the gating score matrix) and from frequent kernel launches.

This PR introduces moba_triton.py, which replaces key bottlenecks with custom fused Triton kernels. Specifically, it fuses the gating score calculation, causal masking, and Top-K selection into a single kernel, and implements a fused merge-softmax kernel to combine Self-Attention and MoBA-Attention outputs in one pass.

This implementation significantly reduces training time and memory footprint while maintaining the exact same interface and autograd compatibility as the original code.
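For reference, the merge that the fused merge-softmax kernel performs in one pass can be written at the PyTorch level as in the sketch below. This is a minimal sketch, not the PR's actual API: it assumes each branch returns its attention output together with the per-query log-sum-exp (LSE) of its softmax logits, as flash-attention-style kernels do, and the tensor names and shapes are illustrative.

```python
import torch

def merge_attention_outputs(o_self, lse_self, o_moba, lse_moba):
    """Combine two attention branches computed over disjoint key sets.

    o_*:   [Seq, Head, Dim]  partial attention outputs
    lse_*: [Seq, Head]       log-sum-exp of each branch's softmax logits
    """
    # Joint normalizer over the union of both key sets.
    lse = torch.logaddexp(lse_self, lse_moba)          # [Seq, Head]
    # Each branch's output is rescaled by its share of the joint softmax mass.
    w_self = torch.exp(lse_self - lse).unsqueeze(-1)   # [Seq, Head, 1]
    w_moba = torch.exp(lse_moba - lse).unsqueeze(-1)
    return w_self * o_self + w_moba * o_moba           # [Seq, Head, Dim]
```

Doing this inside a single Triton program means the partial outputs and LSEs are read from global memory once and the merged result is written once, which is where the memory-traffic saving described above comes from.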

Changes:

New File: moba_triton.py

  • Fused Chunk Mean Kernel (_chunk_mean_kernel):
    • Replaced the memory-intensive view().mean() operation.
    • Computes the mean of Key vectors within each chunk without materializing intermediate reshaped tensors (a minimal Triton sketch of the idea appears after this list).
  • Fused Gating & Top-K Kernel (_chunk_topk_kernel):
    • Replaced the torch.einsum + torch.topk pipeline (an unfused PyTorch reference is sketched after this list).
    • Fuses the dot product (gating score), causal masking (masking future chunks and cross-batch positions), and Top-K selection.
    • Avoids instantiating the massive [Batch, Head, Chunk, Seq] score matrix, reducing peak memory usage.
  • Fused Merge Softmax Kernel (_fused_merge_softmax_kernel):
    • Replaced the multi-step PyTorch LogSumExp (LSE) reduction and output merging (see the merge sketch above).
    • Combines the Self-Attention and MoBA-Attention outputs and their LSEs in a single pass using on-chip SRAM, drastically reducing global memory reads/writes.
  • Optimized Backward Pass (_gather_moba_backward_inputs_kernel):
    • Added a custom kernel to efficiently gather gradients, outputs, and LSE scores needed for the sparse MoBA branch during backpropagation.
  • Layout Optimization:
    • Refactored index calculations to work directly in the [Seq, Head, Dim] layout, removing the expensive rearrange and transpose operations found in the original code.
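
The following is a minimal Triton sketch of the chunk-mean idea, not the PR's actual _chunk_mean_kernel. It assumes a contiguous [Seq, Head, Dim] key tensor, a sequence length that is a multiple of a power-of-two chunk size, and a power-of-two head dimension (a tl.arange constraint); the real kernel additionally has to handle variable-length batches via cu_seqlens and partial trailing chunks.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def chunk_mean_kernel(
    k_ptr, out_ptr,
    stride_ks, stride_kh, stride_kd,
    stride_oc, stride_oh, stride_od,
    CHUNK: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK: tl.constexpr,
):
    # One program per (chunk, head): average CHUNK key rows of width HEAD_DIM,
    # streaming the chunk through registers BLOCK rows at a time.
    chunk_id = tl.program_id(0)
    head_id = tl.program_id(1)

    cols = tl.arange(0, HEAD_DIM)
    acc = tl.zeros([HEAD_DIM], dtype=tl.float32)

    for start in range(0, CHUNK, BLOCK):
        rows = start + tl.arange(0, BLOCK)
        offs = (chunk_id * CHUNK + rows)[:, None] * stride_ks \
            + head_id * stride_kh + cols[None, :] * stride_kd
        tile = tl.load(k_ptr + offs).to(tl.float32)   # [BLOCK, HEAD_DIM] tile
        acc += tl.sum(tile, axis=0)                   # on-chip reduction

    mean = acc / CHUNK
    out_offs = chunk_id * stride_oc + head_id * stride_oh + cols * stride_od
    tl.store(out_ptr + out_offs, mean.to(out_ptr.dtype.element_ty))


def chunk_mean(k: torch.Tensor, chunk: int) -> torch.Tensor:
    """k: [Seq, Head, Dim] with Seq % chunk == 0 -> [Seq // chunk, Head, Dim]."""
    seq, heads, dim = k.shape
    assert seq % chunk == 0, "sketch assumes full chunks only"
    out = torch.empty(seq // chunk, heads, dim, device=k.device, dtype=k.dtype)
    grid = (seq // chunk, heads)
    chunk_mean_kernel[grid](
        k, out,
        k.stride(0), k.stride(1), k.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        CHUNK=chunk, HEAD_DIM=dim, BLOCK=min(chunk, 64),
    )
    return out
```

For the gating path, a simplified, unfused reference for the einsum + mask + topk pipeline that _chunk_topk_kernel fuses is shown below. It is illustrative only: it handles a single packed sequence, leaves the query's own chunk to the self-attention branch, omits cross-batch masking, and does not special-case early queries that have fewer than topk valid chunks.

```python
import torch

def gate_topk_reference(q, k_chunk_mean, chunk_size, topk):
    """q: [Seq, Head, Dim], k_chunk_mean: [NumChunk, Head, Dim] -> [Head, Seq, topk]."""
    seq, n_chunk = q.shape[0], k_chunk_mean.shape[0]
    # Gating score of every query against every chunk's mean key.  This
    # [Head, Seq, NumChunk] tensor is exactly what the fused kernel never builds.
    score = torch.einsum("shd,chd->hsc", q.float(), k_chunk_mean.float())
    # Mask chunks a query must not gate to: its own chunk (handled by the
    # self-attention branch) and all future chunks.
    q_pos = torch.arange(seq, device=q.device)
    chunk_end = (torch.arange(n_chunk, device=q.device) + 1) * chunk_size
    invalid = chunk_end[None, :] > q_pos[:, None]            # [Seq, NumChunk]
    score = score.masked_fill(invalid.unsqueeze(0), float("-inf"))
    # Top-K chunk indices per (head, query); the fused kernel emits these directly.
    return score.topk(min(topk, n_chunk), dim=-1).indices
```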

Environment / Testing:

  • GPU: NVIDIA RTX A6000
  • Torch: 2.7.0+cu128
  • Triton: 3.3.0

Performance Benchmark:
Training a 50M-parameter model on a 500M dataset for 1000 steps on a single RTX A6000:

Implementation                    | Training Time | Speedup
----------------------------------|---------------|--------
Original (moba_efficient.py)      | ~85 minutes   | 1.0x
Triton Optimized (moba_triton.py) | ~30 minutes   | ~2.8x

Verified that the new implementation produces correct gradients and that its loss convergence matches the baseline.
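
A parity check in this spirit looks roughly like the sketch below. It is illustrative only: it assumes both files expose a moba_attn_varlen entry point with the original signature (q, k, v, cu_seqlens, max_seqlen, moba_chunk_size, moba_topk) and the flash-attn varlen layout [total_tokens, heads, head_dim]; the import path for the new file, the shapes, and the hyperparameters are placeholders.

```python
import torch
from moba.moba_efficient import moba_attn_varlen as moba_ref     # original implementation
from moba.moba_triton import moba_attn_varlen as moba_triton     # this PR (path assumed)

torch.manual_seed(0)
S, H, D = 4096, 8, 128
cu_seqlens = torch.tensor([0, 2048, 4096], device="cuda", dtype=torch.int32)  # two sequences

q1 = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k1 = torch.randn_like(q1, requires_grad=True)
v1 = torch.randn_like(q1, requires_grad=True)
q2, k2, v2 = (t.detach().clone().requires_grad_() for t in (q1, k1, v1))

out1 = moba_ref(q1, k1, v1, cu_seqlens, 2048, 512, 3)
out2 = moba_triton(q2, k2, v2, cu_seqlens, 2048, 512, 3)
out1.sum().backward()
out2.sum().backward()

print("forward max |diff|:", (out1 - out2).abs().max().item())
print("dQ      max |diff|:", (q1.grad - q2.grad).abs().max().item())
print("dK      max |diff|:", (k1.grad - k2.grad).abs().max().item())
print("dV      max |diff|:", (v1.grad - v2.grad).abs().max().item())
```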

Signed-off-by: BeiSheng <132643639+ReyJerry@users.noreply.github.com>