
[FEATURE] Implement FlashAttention for Optimized Transformer Performance #22

@Adversing

Feature Description

Implement a memory-efficient and I/O-aware version of the attention mechanism, based on the principles of FlashAttention. This involves creating a fused OpenCL kernel that computes attention without materializing the large, intermediate [sequence_length, sequence_length] attention score matrix. This will significantly improve the performance and reduce the memory footprint of transformer models, especially those with long sequence lengths.
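
For reference, the operation being fused is the standard scaled dot-product attention:

    Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where Q K^T is the [sequence_length, sequence_length] score matrix that the fused kernel avoids ever writing out in full.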

Use Case

Training and running large-scale transformer models, such as GPT, with long context windows is currently a major performance bottleneck. The standard attention implementation is limited by memory bandwidth due to the need to read and write the large attention matrix to and from GPU HBM.
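
For a concrete sense of scale (assuming fp32 scores and a single attention head): at a sequence length of 8,192, the score matrix alone is 8,192 × 8,192 × 4 bytes ≈ 256 MiB per head, per batch element, and the standard implementation writes all of it to HBM and reads it back.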

This feature would:

  • Let users train larger models faster.
  • Support longer sequence lengths without running out of memory.
  • Make Brain4j a more viable high-performance framework for state-of-the-art NLP tasks.

Current Alternatives

The current implementation in AttentionHead.java uses the standard, textbook Scaled Dot-Product Attention. It is functionally correct but memory-inefficient. It works by chaining together several operations (matmul, scale, softmax, matmul), which explicitly creates the large intermediate score matrix. This serves as a valid, but slow, workaround for smaller models or shorter sequences.
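
For reference, that chain looks roughly like the following minimal sketch. The Tensor type and its matmul/transpose/mul/softmax methods are hypothetical stand-ins, not Brain4j's actual API; the point is only that the full [seqLen, seqLen] weights tensor is materialized.

    // Illustrative only: textbook scaled dot-product attention with a
    // hypothetical Tensor API. The full [seqLen, seqLen] score matrix is
    // created, scaled, soft-maxed, and multiplied, with each unfused step
    // reading and writing that matrix in global memory.
    public Tensor scaledDotProductAttention(Tensor q, Tensor k, Tensor v) {
        int dK = k.shape(1);                            // head dimension
        Tensor scores = q.matmul(k.transpose())         // [seqLen, seqLen]
                         .mul(1.0 / Math.sqrt(dK));     // scale by 1/sqrt(d_k)
        Tensor weights = scores.softmax(1);             // row-wise softmax, still [seqLen, seqLen]
        return weights.matmul(v);                       // [seqLen, headDim]
    }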

Implementation Ideas

The core of this feature would be a new fused OpenCL kernel. The implementation should follow these general steps:

  1. Tiling: The Q, K, and V matrices should be partitioned into smaller blocks or tiles.
  2. Kernel Fusion: A single OpenCL kernel should be written to load these tiles into the GPU's fast on-chip memory (SRAM).
  3. Online Softmax: Within the kernel, iterate through the blocks of K and V. The softmax must be computed in a streaming manner, updating the running maximum and normalization constant as new blocks are processed, without ever storing the full score matrix (see the reference sketch after this list).
  4. Integration:
    4.1) Create a new FlashAttentionHead class that inherits from AttentionHead.
    4.2) This new class will call the custom OpenCL kernel instead of the existing sequence of tensor operations.
    4.3) The MultiHeadAttention class can then be updated to optionally use this new, optimized attention head.
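
As a reference for steps 1–3, here is a minimal, plain-Java sketch of the tiled online-softmax loop. It is illustrative only: the method name, array-based tensors, and block size are assumptions rather than Brain4j API, and in the actual feature this logic would live inside the fused OpenCL kernel, not on the host.

    // Plain-Java reference of the tiled, online-softmax attention the fused
    // kernel would implement. q, k, v are [seqLen][headDim] arrays; blockSize
    // is the K/V tile size. The full [seqLen, seqLen] score matrix never exists.
    public static float[][] flashAttentionReference(float[][] q, float[][] k,
                                                    float[][] v, int blockSize) {
        int seqLen = q.length;
        int headDim = q[0].length;
        float scale = (float) (1.0 / Math.sqrt(headDim));
        float[][] out = new float[seqLen][headDim];

        for (int i = 0; i < seqLen; i++) {                // one query row at a time
            float runningMax = Float.NEGATIVE_INFINITY;   // running max of scores
            float runningSum = 0.0f;                      // running softmax denominator
            float[] acc = new float[headDim];             // un-normalized output row

            for (int start = 0; start < seqLen; start += blockSize) {
                int end = Math.min(start + blockSize, seqLen);

                // Scores for this K block only.
                float[] scores = new float[end - start];
                float blockMax = Float.NEGATIVE_INFINITY;
                for (int j = start; j < end; j++) {
                    float s = 0.0f;
                    for (int d = 0; d < headDim; d++) {
                        s += q[i][d] * k[j][d];
                    }
                    scores[j - start] = s * scale;
                    blockMax = Math.max(blockMax, scores[j - start]);
                }

                // Online softmax: rescale previous partial sums when the max grows.
                float newMax = Math.max(runningMax, blockMax);
                float correction = (float) Math.exp(runningMax - newMax);
                runningSum *= correction;
                for (int d = 0; d < headDim; d++) {
                    acc[d] *= correction;
                }

                // Accumulate this block's contribution to the output row.
                for (int j = start; j < end; j++) {
                    float p = (float) Math.exp(scores[j - start] - newMax);
                    runningSum += p;
                    for (int d = 0; d < headDim; d++) {
                        acc[d] += p * v[j][d];
                    }
                }
                runningMax = newMax;
            }

            for (int d = 0; d < headDim; d++) {
                out[i][d] = acc[d] / runningSum;          // final normalization
            }
        }
        return out;
    }

In the fused kernel, the loop over query rows would map to OpenCL work-groups, the K and V tiles would be staged in local (on-chip) memory, and the per-row state (running max, running sum, accumulator) would stay in registers, so only Q, K, V, and the output touch global memory. The proposed FlashAttentionHead would then be a thin wrapper that launches this kernel in place of the existing chain of tensor operations.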

ML Context

This feature is a standard optimization in all major Python-based ML libraries.

  • PyTorch: torch.nn.functional.scaled_dot_product_attention automatically dispatches to a FlashAttention kernel on supported hardware; the xformers library was an early pioneer of this approach.
  • TensorFlow: similar fused-attention optimizations are available through various means, such as XLA compilation and vendor-provided kernels.

Additional Context

This implementation would be based on the concepts from the paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" by Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra and Christopher Ré. The goal is to bring the core ideas of this paper (tiling, kernel fusion, and online softmax) into the OpenCL environment of Brain4j.

Metadata

Labels

  • enhancement: Feature requests and improvements to existing functionality
  • native-integration: Issues related to C/native code integration
  • performance: Issues related to speed, memory usage, or efficiency
