Flash Attention is an algorithm that speeds up the training and inference of transformer models.
How does Flash Attention work?
Many modern transformer models use a mechanism called “attention” to focus on important parts of their input. It’s like how humans pay attention to key words in a sentence. The problem, though, is that traditional attention computations are slow and memory-hungry, especially for long sequences of data (like long documents or high-resolution images).
Flash Attention rethinks how attention is computed on GPUs. It uses smart memory management techniques to do the same calculations much faster and with less memory. In particular, it carefully manages how data moves between different levels of memory on a GPU.
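To make the memory-management idea concrete, here is a minimal, deliberately unoptimized PyTorch sketch of the core trick: attention is computed over blocks of keys and values with a running ("online") softmax, so the full attention score matrix is never materialized at once. The function name and block size are illustrative only; the real Flash Attention is a fused GPU kernel, not Python code.

```python
import torch

def blockwise_attention(q, k, v, block_size=128):
    """Illustrative (slow) sketch of the tiling idea behind Flash Attention:
    process K/V in blocks and keep a running softmax, so the full
    (seq_len x seq_len) score matrix is never stored."""
    seq_len, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((seq_len, 1), float("-inf"))
    row_sum = torch.zeros(seq_len, 1)

    for start in range(0, seq_len, block_size):
        k_blk = k[start:start + block_size]        # load one block of keys
        v_blk = v[start:start + block_size]        # load one block of values
        scores = (q @ k_blk.T) * scale             # scores for this block only

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)  # rescale old accumulators
        p = torch.exp(scores - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_blk
        row_max = new_max

    return out / row_sum
```

The result is mathematically identical to standard softmax attention; only the order of computation and the memory traffic change.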
When should you use Flash Attention?
You should consider using Flash Attention if:
- You’re working with large language models or any AI that uses attention mechanisms (like transformers) and you want to speed up training or inference.
- You have very long input sequences (thousands or tens of thousands of tokens) or large batch sizes.
- GPU memory is a bottleneck for your workload.
By using Flash Attention in these contexts, you can expect:
- Faster training and inference times
- Ability to handle longer sequences without running out of memory
- Potential to increase model size or batch size within the same memory constraints
Flash Attention Versions
There have been several versions of Flash Attention. After the original Flash Attention, released in 2022, Flash Attention 2 followed in 2023. It optimized memory access patterns and causal attention handling, achieving up to a 2x speedup over its predecessor.
The latest iteration, Flash Attention 3, incorporates enhancements specifically designed for NVIDIA’s Hopper GPU architecture (e.g. H100 GPUs), allowing for even greater efficiency and performance. This version leverages advanced techniques to maximize GPU utilization and further improve speed and memory efficiency.
How to use Flash Attention
The easiest way to use Flash Attention is to use a training or inference framework that has it integrated already. Below, we cover the most popular frameworks and the status of their integration with Flash Attention.
PyTorch
PyTorch has native support for Flash Attention 2 as of version 2.2. You can use it directly in your PyTorch models. To enable Flash Attention in PyTorch, you typically select the Flash Attention backend for scaled dot product attention (SDPA).
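As a rough sketch, restricting SDPA to the Flash Attention backend can be done with the `sdpa_kernel` context manager available in recent PyTorch releases (older versions expose a similar `torch.backends.cuda.sdp_kernel` API); the tensor shapes below are placeholders:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative tensors: (batch, num_heads, seq_len, head_dim), half precision on GPU
q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Only allow the Flash Attention backend within this block
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```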
Hugging Face Transformers
The Transformers library supports Flash Attention for certain models. You can often enable it by setting the `attn_implementation="flash_attention_2"` parameter when initializing a model. However, support may vary depending on the specific model architecture.
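For example (the model id below is only a placeholder for any architecture with Flash Attention 2 support; the flag also requires the flash-attn package and a half-precision dtype):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # placeholder; use any model with Flash Attention 2 support

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,                # Flash Attention kernels need fp16 or bf16
    attn_implementation="flash_attention_2",  # request the Flash Attention 2 code path
    device_map="auto",
)
```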
vLLM
vLLM natively takes advantage of Flash Attention 2 as of v0.1.4. You don’t need to enable it separately.
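In other words, an ordinary vLLM script benefits from it automatically; a minimal sketch (the model id is a placeholder) looks like this:

```python
from vllm import LLM, SamplingParams

# vLLM selects its attention backend (including Flash Attention) on its own
llm = LLM(model="mistralai/Mistral-7B-v0.1")  # placeholder model id
outputs = llm.generate(["Flash Attention is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```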
Text Generation Inference (TGI)
Flash Attention is enabled by default in TGI. However, whether it is actually used can depend on the specific model.
TGI aims to use Flash Attention whenever possible because of its advantages, and falls back to alternative attention implementations if it is unavailable or unsupported for a given model.
Separate Implementation
While these frameworks often include Flash Attention or similar optimizations, you can also install it using pip:
```bash
pip install flash-attn
```
or clone the repo and install it from source.
Make sure that you have its dependencies installed, including:
- PyTorch: Ensure you have PyTorch version 1.12 or above installed.
- CUDA: A compatible version of the CUDA toolkit is necessary for GPU support.
- NVIDIA cuDNN: This library is recommended for optimized performance on NVIDIA GPUs. For more information about the CUDA toolkit, refer to our CUDA guide.
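Once installed, the package can also be called directly. Below is a minimal sketch using `flash_attn_func`; the shapes, dtype, and causal flag follow the interface documented in the flash-attn repository, and the tensors here are random placeholders:

```python
import torch
from flash_attn import flash_attn_func

# flash-attn expects (batch, seq_len, num_heads, head_dim) tensors in fp16/bf16 on a CUDA device
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)  # output has the same shape as q
```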
For a full example of how to run a transformers model on cloud compute with Flash Attention 3, you can refer to our Flux tutorial here.