Feature Description
Implement a memory-efficient and I/O-aware version of the attention mechanism, based on the principles of FlashAttention. This involves creating a fused OpenCL kernel that computes attention without materializing the large intermediate [sequence_length, sequence_length] attention score matrix. This will significantly improve performance and reduce the memory footprint of transformer models, especially those with long sequence lengths.
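To give a rough sense of scale (assuming fp32 scores), the intermediate matrix alone costs sequence_length² × 4 bytes per head and batch element: at a sequence length of 8192 that is 8192² × 4 B ≈ 256 MiB, whereas the fused kernel only ever keeps small tiles of it in on-chip memory.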
Use Case
Training and running large-scale transformer models, such as GPT, with long context windows is currently a major performance bottleneck. The standard attention implementation is limited by memory bandwidth due to the need to read and write the large attention matrix to and from GPU HBM.
This feature would allow users to:
- Train larger models faster.
- Use longer sequence lengths without running out of memory.
- Make Brain4j a more viable high-performance framework for state-of-the-art NLP tasks.
Current Alternatives
The current implementation in AttentionHead.java uses the standard, textbook Scaled Dot-Product Attention. It is functionally correct but memory-inefficient: it works by chaining together several operations (`matmul`, `scale`, `softmax`, `matmul`), which explicitly creates the large intermediate score matrix. This serves as a valid, but slow, workaround for smaller models or shorter sequences.
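For reference, this chain computes the standard formulation, which materializes the full n × n score matrix S for Q, K, V of shape [n, d_k]:

$$
S = \frac{QK^{\top}}{\sqrt{d_k}}, \qquad \mathrm{Attention}(Q, K, V) = \mathrm{softmax}(S)\,V
$$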
Implementation Ideas
The core of this feature would be a new fused OpenCL kernel. The implementation should follow these general steps:
- Tiling: The Q, K, and V matrices should be partitioned into smaller blocks or tiles.
- Kernel Fusion: A single OpenCL kernel should be written to load these tiles into the GPU's fast on-chip memory (OpenCL local memory, i.e. SRAM).
- Online Softmax: Within the kernel, iterate through the blocks of K and V. The softmax must be computed in a streaming manner, updating the running maximum and normalization constant as new blocks are processed, without ever storing the full score matrix (see the host-side sketch after this list).
- Integration (see the integration sketch after this list):
  4.1) Create a new `FlashAttentionHead` class that inherits from `AttentionHead`.
  4.2) This new class will call the custom OpenCL kernel instead of the existing sequence of tensor operations.
  4.3) The `MultiHeadAttention` class can then be updated to optionally use this new, optimized attention head.
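As a concrete illustration of steps 1–3, here is a minimal host-side Java sketch of the tiled, online-softmax pass that the fused kernel's inner loop would perform for each query row. It is meant as a numerical reference for validating the kernel, not as the kernel itself; the class and method names are illustrative and not existing Brain4j APIs.

```java
/**
 * Host-side reference for the tiled, online-softmax attention pass the fused
 * OpenCL kernel would perform. Q, K, V are [seqLen][headDim] arrays; K and V
 * are streamed in tiles of blockSize rows, so the full [seqLen x seqLen]
 * score matrix is never materialized.
 */
public final class OnlineSoftmaxAttention {

    public static double[][] attend(double[][] q, double[][] k, double[][] v, int blockSize) {
        int seqLen = q.length;
        int headDim = q[0].length;
        double scale = 1.0 / Math.sqrt(headDim);
        double[][] out = new double[seqLen][headDim];

        for (int i = 0; i < seqLen; i++) {
            double m = Double.NEGATIVE_INFINITY; // running max of scores (for stability)
            double l = 0.0;                      // running softmax denominator
            double[] acc = new double[headDim];  // running weighted sum of V rows

            // Stream over K/V in tiles; this is the loop the kernel keeps on-chip.
            for (int start = 0; start < seqLen; start += blockSize) {
                int end = Math.min(start + blockSize, seqLen);

                // Scores for this tile only.
                double[] scores = new double[end - start];
                double blockMax = Double.NEGATIVE_INFINITY;
                for (int j = start; j < end; j++) {
                    double s = 0.0;
                    for (int d = 0; d < headDim; d++) {
                        s += q[i][d] * k[j][d];
                    }
                    scores[j - start] = s * scale;
                    blockMax = Math.max(blockMax, scores[j - start]);
                }

                // Rescale previous partial results to the new running max.
                double newMax = Math.max(m, blockMax);
                double correction = Math.exp(m - newMax);
                l *= correction;
                for (int d = 0; d < headDim; d++) {
                    acc[d] *= correction;
                }

                // Accumulate this tile's contribution to numerator and denominator.
                for (int j = start; j < end; j++) {
                    double p = Math.exp(scores[j - start] - newMax);
                    l += p;
                    for (int d = 0; d < headDim; d++) {
                        acc[d] += p * v[j][d];
                    }
                }
                m = newMax;
            }

            // Final normalization: out_i = acc / l, equal to softmax(QK^T / sqrt(d)) V.
            for (int d = 0; d < headDim; d++) {
                out[i][d] = acc[d] / l;
            }
        }
        return out;
    }
}
```

The result matches the chained matmul/scale/softmax/matmul path up to floating-point rounding, which makes it useful as a correctness oracle for the OpenCL kernel.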
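Regarding step 4, a rough integration outline might look like the following. The base class, method signature, and tensor representation are hypothetical stand-ins (the real Brain4j AttentionHead and tensor APIs are not reproduced here), and the host-side reference above stands in for the actual OpenCL kernel launch.

```java
// Hypothetical sketch only: Brain4j's real AttentionHead and tensor types are
// stubbed so the example is self-contained.
class AttentionHead {
    protected double[][] attend(double[][] q, double[][] k, double[][] v) {
        // Existing chained path: matmul -> scale -> softmax -> matmul (omitted here).
        throw new UnsupportedOperationException("reference path not shown");
    }
}

class FlashAttentionHead extends AttentionHead {

    private final int blockSize;

    FlashAttentionHead(int blockSize) {
        this.blockSize = blockSize;
    }

    @Override
    protected double[][] attend(double[][] q, double[][] k, double[][] v) {
        // In the real implementation this would enqueue the fused OpenCL kernel;
        // the host-side reference from the previous sketch stands in for it here.
        return OnlineSoftmaxAttention.attend(q, k, v, blockSize);
    }
}
```

MultiHeadAttention could then accept a configuration option to construct FlashAttentionHead instances instead of the default heads.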
ML Context
This feature is a standard optimization in all major Python-based ML libraries.
- PyTorch: `torch.nn.functional.scaled_dot_product_attention` automatically uses a FlashAttention implementation on supported hardware. The `xformers` library was an early pioneer of this.
- TensorFlow: similar optimizations are available through various means.
Additional Context
This implementation would be based on the concepts from the paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" by Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra and Christopher Ré. The goal is to bring the core ideas of this paper (tiling, kernel fusion, and online softmax) into the OpenCL environment of Brain4j.