[Optimization] Implement Triton kernels for MoBA (~2.8x speedup) #39
This PR implements the missing Triton kernel optimizations mentioned in the comments of `moba_attn_varlen`, providing a fully functional, high-performance Triton-based implementation in `moba_triton.py`. It serves as a drop-in replacement for `moba_efficient.py` with significant speedups.

Related Issue:
Fixes #37
Description:
As noted in Issue #37, the original `moba_efficient.py` implementation relied on standard PyTorch operations despite comments referencing Triton kernels. This resulted in suboptimal performance due to high memory I/O overhead from materializing large intermediate tensors (e.g., the gating score matrix) and frequent kernel launches.

This PR introduces `moba_triton.py`, which replaces the key bottlenecks with custom fused Triton kernels. Specifically, it fuses the gating score calculation, causal masking, and Top-K selection into a single kernel, and implements a fused merge-softmax kernel that combines the self-attention and MoBA-attention outputs in one pass. This significantly reduces training time and memory footprint while keeping the exact same interface and autograd compatibility as the original code.
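For reference, the merge step that the fused merge-softmax kernel performs can be written out in plain PyTorch as a rescale-and-sum using each branch's log-sum-exp statistics. This is only an illustrative sketch of the math, not the kernel's actual API; the function and argument names below are hypothetical.

```python
import torch

def merge_attention_outputs(o_self, lse_self, o_moba, lse_moba):
    # o_*:   [seq, heads, dim]  partial attention outputs from each branch
    # lse_*: [seq, heads]       log-sum-exp of the attention logits for each branch
    # Hypothetical reference helper illustrating the merge the fused kernel computes.
    lse = torch.logaddexp(lse_self, lse_moba)          # combined softmax normalizer
    w_self = torch.exp(lse_self - lse).unsqueeze(-1)   # rescale factor for the self-attn branch
    w_moba = torch.exp(lse_moba - lse).unsqueeze(-1)   # rescale factor for the MoBA branch
    return w_self * o_self + w_moba * o_moba
```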
Changes:
New File: `moba_triton.py`
- Chunk Mean Kernel (`_chunk_mean_kernel`): replaces the `view().mean()` operation used to compute per-chunk key means (a minimal sketch follows this list).
- Top-K Gating Kernel (`_chunk_topk_kernel`): fuses the `torch.einsum` + `torch.topk` pipeline, avoiding materialization of the full `[Batch, Head, Chunk, Seq]` score matrix and reducing peak memory usage.
- Fused Merge Softmax Kernel (`_fused_merge_softmax_kernel`): combines the self-attention and MoBA-attention outputs in a single pass (see the merge sketch above).
- Backward Gather Kernel (`_gather_moba_backward_inputs_kernel`): gathers the backward-pass inputs directly in `[Seq, Head, Dim]` layout, removing the expensive `rearrange` and transpose operations found in the original code.
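As an illustration of the first item, a chunk-mean kernel could look roughly like the following. This is a minimal sketch, not the PR's actual implementation: the kernel and wrapper names are made up, it assumes the sequence length is divisible by the chunk size, that the head dimension is a power of two (required by `tl.arange`), and that the last dimension of `k` is contiguous; varlen boundary masking is omitted.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def chunk_mean_kernel(
    k_ptr, out_ptr,
    stride_ks, stride_kh,   # strides of K over (seq, head); head_dim assumed contiguous
    stride_oc, stride_oh,   # strides of the output over (chunk, head)
    chunk_size,
    HEAD_DIM: tl.constexpr,
):
    # One program per (chunk, head): average `chunk_size` key vectors in fp32.
    chunk_id = tl.program_id(0)
    head_id = tl.program_id(1)
    offs_d = tl.arange(0, HEAD_DIM)
    acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
    base = k_ptr + chunk_id * chunk_size * stride_ks + head_id * stride_kh
    for i in range(chunk_size):
        acc += tl.load(base + i * stride_ks + offs_d).to(tl.float32)
    tl.store(out_ptr + chunk_id * stride_oc + head_id * stride_oh + offs_d,
             acc / chunk_size)

def chunk_mean(k: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # k: [seq, heads, head_dim]; seq assumed divisible by chunk_size for brevity.
    seq, heads, head_dim = k.shape
    num_chunks = seq // chunk_size
    out = torch.empty((num_chunks, heads, head_dim), device=k.device, dtype=torch.float32)
    chunk_mean_kernel[(num_chunks, heads)](
        k, out,
        k.stride(0), k.stride(1),
        out.stride(0), out.stride(1),
        chunk_size,
        HEAD_DIM=head_dim,
    )
    return out
```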
Environment / Testing:

Performance Benchmark:
Training a 50M-parameter model on a 500M dataset for 1000 steps on a single RTX A6000, the Triton implementation (`moba_triton.py`) delivers roughly a 2.8x speedup over the baseline (`moba_efficient.py`).

Verified that the new implementation produces correct gradients and that loss convergence matches the baseline.
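A gradient-parity check of this kind could be scripted along the following lines. This is a hedged sketch rather than the PR's test code: the `check_parity` helper and its pass-through argument handling are hypothetical, and the actual `moba_attn_varlen` signatures in `moba_efficient.py` / `moba_triton.py` may differ.

```python
import torch

def check_parity(attn_ref, attn_new, q, k, v, *args, atol=1e-2, rtol=1e-2):
    # attn_ref / attn_new: the baseline and Triton attention callables; any extra
    # positional arguments (chunk size, top-k, cu_seqlens, ...) are passed through *args.
    inputs_ref = [t.detach().clone().requires_grad_(True) for t in (q, k, v)]
    inputs_new = [t.detach().clone().requires_grad_(True) for t in (q, k, v)]

    out_ref = attn_ref(*inputs_ref, *args)   # baseline (e.g. moba_efficient.py)
    out_new = attn_new(*inputs_new, *args)   # Triton version (e.g. moba_triton.py)
    assert torch.allclose(out_ref, out_new, atol=atol, rtol=rtol), "forward mismatch"

    # Backprop the same upstream gradient through both and compare q/k/v grads.
    grad_out = torch.randn_like(out_ref)
    out_ref.backward(grad_out)
    out_new.backward(grad_out)
    for g_ref, g_new in zip((t.grad for t in inputs_ref), (t.grad for t in inputs_new)):
        assert torch.allclose(g_ref, g_new, atol=atol, rtol=rtol), "gradient mismatch"
```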