Training Optimizations

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

  • Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce the compute complexity, but often do not achieve wall-clock speedup. The FlashAttention authors argue that a missing principle is making attention algorithms IO-aware, that is, accounting for reads and writes between levels of GPU memory.
  • This paper by Dao et al. from Stanford in 2022 proposes FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. Specifically, FlashAttention reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length.
  • They analyze the IO complexity of FlashAttention, showing that it requires fewer HBM accesses than standard attention, and is optimal for a range of SRAM sizes. They also extend FlashAttention to block-sparse attention, yielding an approximate attention algorithm that is faster than any existing approximate attention method.
  • FlashAttention trains Transformers faster than existing baselines: 15% end-to-end wall-clock speedup on BERT-large (seq. length 512) compared to the MLPerf 1.1 training speed record, 3x speedup on GPT-2 (seq. length 1K), and 2.4x speedup on long-range arena (seq. length 1K-4K).
  • FlashAttention and block-sparse FlashAttention enable longer context in Transformers, yielding higher quality models (0.7 better perplexity on GPT-2 and 6.4 points of lift on long-document classification) and entirely new capabilities: the first Transformers to achieve better-than-chance performance on the Path-X challenge (seq. length 16K, 61.4% accuracy) and Path-256 (seq. length 64K, 63.1% accuracy).
  • The figure below from the paper shows: (Left) FlashAttention uses tiling to prevent materialization of the large \(N \times N\) attention matrix (dotted box) on (relatively) slow GPU HBM. In the outer loop (red arrows), FlashAttention loops through blocks of the \(K\) and \(V\) matrices and loads them to fast on-chip SRAM. In each block, FlashAttention loops over blocks of the \(Q\) matrix (blue arrows), loading them to SRAM, and writing the output of the attention computation back to HBM. (Right) Speedup over the PyTorch implementation of attention on GPT-2. FlashAttention does not read and write the large \(N \times N\) attention matrix to HBM, resulting in a 7.6x speedup on the attention computation.
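
The tiling idea described above can be illustrated with a minimal pure-Python sketch. It is not the paper's CUDA kernel: it assumes plain lists of floats, processes one query row at a time rather than blocks of \(Q\), and the names `naive_attention`, `tiled_attention`, and `block` are chosen here for illustration. What it does show is the running-maximum and running-sum bookkeeping that lets the softmax be computed block by block without ever storing the full \(N \times N\) score matrix.

```python
import math


def naive_attention(Q, K, V):
    """Reference version: computes every score row in full before the softmax."""
    d = len(Q[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) / total
                    for j in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block=2):
    """Visit K/V in blocks, keeping only running statistics per query row."""
    d, dv, n = len(Q[0]), len(V[0]), len(Q)
    row_max = [-math.inf] * n               # running maximum score per query
    row_sum = [0.0] * n                     # running softmax denominator
    acc = [[0.0] * dv for _ in range(n)]    # unnormalised output accumulator
    for start in range(0, len(K), block):   # outer loop over K/V blocks
        Kb, Vb = K[start:start + block], V[start:start + block]
        for i, q in enumerate(Q):           # inner loop over queries
            scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in Kb]
            new_max = max(row_max[i], max(scores))
            scale = math.exp(row_max[i] - new_max)   # rescales earlier blocks
            weights = [math.exp(s - new_max) for s in scores]
            row_sum[i] = row_sum[i] * scale + sum(weights)
            acc[i] = [a * scale + sum(w * v[j] for w, v in zip(weights, Vb))
                      for j, a in enumerate(acc[i])]
            row_max[i] = new_max
    return [[a / row_sum[i] for a in acc[i]] for i in range(n)]
```

Because the rescaling is exact, the tiled result matches the reference up to floating-point rounding, which is why FlashAttention can be described as exact rather than approximate attention.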

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

  • Scaling Transformers to longer sequence lengths has been a major problem in the last several years, promising to improve performance in language modeling and high-resolution image understanding, as well as to unlock new applications in code, audio, and video generation. The attention layer is the main bottleneck in scaling to longer sequences, as its runtime and memory increase quadratically in the sequence length.
  • FlashAttention exploits the asymmetric GPU memory hierarchy to bring significant memory saving (linear instead of quadratic) and runtime speedup (2-4x compared to optimized baselines), with no approximation. However, FlashAttention is still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s.
  • They observe that the inefficiency is due to suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low-occupancy or unnecessary shared memory reads/writes.
  • This paper by Dao from Princeton and Stanford in 2023 proposes FlashAttention-2, with better work partitioning to address these issues. In particular, it (1) tweaks the algorithm to reduce the number of non-matmul FLOPs, (2) parallelizes the attention computation, even for a single head, across different thread blocks to increase occupancy, and (3) within each thread block, distributes the work between warps to reduce communication through shared memory. These changes yield around 2x speedup compared to FlashAttention, reaching 50-73% of the theoretical maximum FLOPs/s on A100 and getting close to the efficiency of GEMM operations.
  • They empirically validate that when used end-to-end to train GPT-style models, FlashAttention-2 reaches training speed of up to 225 TFLOPs/s per A100 GPU (72% model FLOPs utilization).
  • The following figure from Sebastian Raschka summarizes FlashAttention-2:

Fast Transformer Decoding: One Write-Head is All You Need

  • Multi-head attention layers, as used in the Transformer neural sequence model, are a powerful alternative to RNNs for moving information across and between sequences. While training these layers is generally fast and simple, due to parallelizability across the length of the sequence, incremental inference (where such parallelization is impossible) is often slow, due to the memory-bandwidth cost of repeatedly loading the large “keys” and “values” tensors.
  • This paper by Shazeer from Google in 2019 proposes a variant called multi-query attention, where the keys and values are shared across all of the different attention “heads”, greatly reducing the size of these tensors and hence the memory bandwidth requirements of incremental decoding.
  • They verify experimentally that the resulting models can indeed be much faster to decode, and incur only minor quality degradation from the baseline.
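
To make the size reduction concrete, here is a small back-of-the-envelope sketch. The helper `kv_cache_bytes` and the model configuration (32 layers, 32 heads, head dimension 128, 4096 tokens, 2 bytes per element) are hypothetical values picked for illustration and do not come from the paper.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   batch=1, bytes_per_elem=2):
    """Bytes needed to hold the keys and values of every layer.

    Each layer stores two tensors (K and V), each of shape
    [batch, n_kv_heads, seq_len, head_dim].
    """
    return 2 * n_layers * batch * n_kv_heads * seq_len * head_dim * bytes_per_elem


# Hypothetical configuration, for illustration only.
mha_bytes = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096)
mqa_bytes = kv_cache_bytes(n_layers=32, n_kv_heads=1, head_dim=128, seq_len=4096)

print(f"multi-head:  {mha_bytes / 2**20:.0f} MiB")
print(f"multi-query: {mqa_bytes / 2**20:.0f} MiB")
```

Under these assumptions the tensors that must be streamed from memory at every decoding step shrink by a factor equal to the number of heads.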

GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

  • Multi-query attention (MQA), which only uses a single key-value head, drastically speeds up decoder inference. However, MQA can lead to quality degradation, and moreover it may not be desirable to train a separate model just for faster inference.
  • This paper by Ainslie et al. from Google Research (1) proposes a recipe for uptraining existing multi-head language model checkpoints into models with MQA using 5% of original pre-training compute, and (2) introduces grouped-query attention (GQA), a generalization of multi-query attention (MQA) which uses an intermediate (more than one, less than number of query heads) number of key-value heads.
  • The following figure from the paper presents an overview of the grouped-query method. Multi-head attention has \(H\) query, key, and value heads. Multi-query attention shares a single key head and a single value head across all query heads. Grouped-query attention instead shares a single key head and value head within each group of query heads, interpolating between multi-head and multi-query attention.

  • MQA uses a single key-value head to speed up decoder inference but can lead to quality degradation. The authors propose a novel method to transform existing multi-head attention (MHA) language model checkpoints into models with MQA, requiring only 5% of the original pre-training compute.
  • The paper presents Grouped-Query Attention (GQA), an intermediate approach between multi-head and multi-query attention. In GQA, query heads are divided into groups, each sharing a single key and value head. This method allows uptrained GQA models to achieve near MHA quality with speeds comparable to MQA.
  • Experiments conducted on the T5.1.1 architecture across various datasets (including CNN/Daily Mail, arXiv, PubMed, MediaSum, Multi-News, WMT, and TriviaQA) show that GQA models offer a balance between inference speed and quality.
  • The study includes ablation experiments to evaluate different modeling choices, such as the number of GQA groups and checkpoint conversion methods. These provide insights into the model’s performance under various configurations.
  • The paper acknowledges limitations, such as evaluation challenges for longer sequences and the absence of comparisons with models trained from scratch. It also notes that the findings are particularly applicable to encoder-decoder models and suggests GQA might have a stronger advantage in decoder-only models.
  • They show that uptrained GQA achieves quality close to multi-head attention with comparable speed to MQA.
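
The grouping can be sketched in a few lines. The function names below are hypothetical, and the conversion step assumes mean pooling of the per-head key (or value) projection matrices within each group, which is one of the checkpoint conversion choices the paper examines; treat it as an illustration rather than the authors' code.

```python
def kv_head_for_query_head(q_head, n_q_heads, n_kv_heads):
    """Index of the key/value head that a given query head reads from."""
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size


def mean_pool_heads(head_weights, n_kv_heads):
    """Average consecutive groups of per-head projection matrices.

    head_weights is a list with one matrix (list of rows) per original head.
    """
    group_size = len(head_weights) // n_kv_heads
    pooled = []
    for g in range(n_kv_heads):
        group = head_weights[g * group_size:(g + 1) * group_size]
        rows, cols = len(group[0]), len(group[0][0])
        pooled.append([[sum(w[r][c] for w in group) / group_size
                        for c in range(cols)]
                       for r in range(rows)])
    return pooled
```

Setting `n_kv_heads` to 1 recovers multi-query attention and setting it to the number of query heads recovers multi-head attention, which is the interpolation described above.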

Longformer: The Long-Document Transformer

  • Transformer-based models are unable to process long sequences due to their self-attention operation, which scales quadratically with the sequence length.
  • This paper by Beltagy et al. from Allen AI in 2020 seeks to address this limitation by introducing the Longformer, whose attention mechanism (commonly called Sliding Window Attention in the field) scales linearly with sequence length, making it easy to process documents of thousands of tokens or longer.
  • Longformer’s attention mechanism is a drop-in replacement for the standard self-attention and combines a local windowed attention with a task motivated global attention.
  • The figure below from the paper compares the full self-attention pattern and the configuration of attention patterns in Longformer.

  • Following prior work on long-sequence transformers, they evaluate Longformer on character-level language modeling and achieve state-of-the-art results on text8 and enwik8.
  • In contrast to most prior work, they also pretrain Longformer and finetune it on a variety of downstream tasks.
  • Their pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new state-of-the-art results on WikiHop and TriviaQA. They finally introduce the Longformer-Encoder-Decoder (LED), a Longformer variant for supporting long document generative sequence-to-sequence tasks, and demonstrate its effectiveness on the arXiv summarization dataset.
  • The figure below from the paper illustrates the runtime and memory of full self-attention and different implementations of Longformer’s self-attention; Longformer-loop is non-vectorized, Longformer-chunk is vectorized, and Longformer-cuda is a custom CUDA kernel implementation. Longformer’s memory usage scales linearly with the sequence length, unlike the full self-attention mechanism, which runs out of memory for long sequences on current GPUs. Different implementations vary in speed, with the vectorized Longformer-chunk being the fastest.
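
As a rough illustration of why the cost grows linearly, the sketch below builds a boolean attention mask that combines a local window with a few global positions and counts the allowed query-key pairs. The function names and the symmetric treatment of global tokens (they attend everywhere and are attended to by every token) are assumptions made for this example; a real implementation would avoid materializing the dense mask at all.

```python
def longformer_mask(seq_len, window, global_positions=()):
    """mask[i][j] is True where token i may attend to token j."""
    half = window // 2
    global_set = set(global_positions)
    return [[abs(i - j) <= half or i in global_set or j in global_set
             for j in range(seq_len)]
            for i in range(seq_len)]


def attended_pairs(mask):
    """Number of (query, key) pairs that the mask allows."""
    return sum(sum(row) for row in mask)


for n in (8, 16, 32):
    local = attended_pairs(longformer_mask(n, window=2))
    print(n, "tokens:", local, "local pairs vs", n * n, "for full attention")
```

With a fixed window, doubling the sequence length roughly doubles the number of allowed pairs, whereas full self-attention quadruples it.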

Inference Optimizations

KV Cache

KV Cache in Transformer Models: A Comprehensive Summary

  • In the context of serving transformer models, the KV (Key-Value) cache is a mechanism used to store and reuse intermediate computations during the generation of a sequence of tokens, particularly in autoregressive models like GPT. This technique is one of the most commonly used tricks for speeding up inference with transformer-based models, especially large language models (LLMs).

KV Cache: Key-Value Cache

  • Key (K) and Value (V) Tensors: In a transformer model, each attention layer derives query (Q), key (K), and value (V) tensors from the input tokens. Attention scores are computed from the queries and keys and determine how much focus each token places on the other tokens in the sequence; the scores are then used to weight the value tensors.
  • Caching Self-Attention Values: During self-attention, the sequence of tokens is projected using three separate, linear projections: key projection, value projection, and query projection. The KV cache stores the results of the key and value projections for future decoding iterations to avoid recomputing them every time.

Autoregressive Decoding Process

  • Step-by-Step Process:
    1. Initial Sequence: Start with a sequence of textual tokens.
    2. Predict Next Token: Run a forward pass over the current sequence and pick the next token from the model's output.
    3. Update Input: Add this token to the input.
    4. Repeat: Repeat until the generation is finished.
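
The four steps above can be sketched as a short loop. The `next_token_fn` callable stands in for a real model's forward pass plus token selection, and `toy_model` is a hypothetical placeholder that simply counts upward; neither comes from any particular library.

```python
def generate(next_token_fn, prompt, max_new_tokens, eos_token=None):
    """Autoregressive decoding: predict, append, repeat."""
    tokens = list(prompt)                  # 1. start from the initial sequence
    for _ in range(max_new_tokens):
        token = next_token_fn(tokens)      # 2. predict the next token
        tokens.append(token)               # 3. add it to the input
        if token == eos_token:             # 4. stop once generation is finished
            break
    return tokens


def toy_model(tokens):
    """Hypothetical stand-in for a model: returns (last token + 1) mod 10."""
    return (tokens[-1] + 1) % 10


print(generate(toy_model, [1, 2, 3], max_new_tokens=4))
```

Note that `next_token_fn` receives the whole sequence on every call; without a cache, a real model would redo the key and value projections for all of those tokens each time.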

Importance of KV Cache

  1. Efficiency:
    • Reduced Computation: By caching the key and value tensors, the model can reuse them in subsequent steps without recalculating them. This significantly reduces the computational overhead, especially for long sequences.
    • Faster Inference: Since the computation for previously generated tokens is bypassed, the overall inference time is reduced, allowing for faster token generation and real-time applications.
  2. Scalability:
    • Handling Long Sequences: For long sequences, recomputing the K and V tensors at each step would be prohibitively expensive. The KV cache allows the model to handle longer sequences more efficiently by storing and reusing past computations.
    • Memory Management: Efficiently managing the KV cache helps in maintaining a balance between memory usage and computational speed, crucial for deploying large transformer models in production environments.
  3. Practical Deployment:
    • Real-Time Applications: In applications like chatbots, real-time text generation, and interactive systems, the latency introduced by recomputing attention scores can be detrimental. The KV cache ensures that responses are generated quickly.
    • Resource Optimization: Efficient use of the KV cache can lead to better resource utilization on the hardware, such as GPUs or TPUs, which is essential for serving large-scale transformer models.

Why Not Cache the Query?

  • Query Matrix: At each decoding step, only the query vector of the newest token is needed, because it is the only position whose output has not yet been computed. That query is compared against the keys of all tokens so far, which (together with their values) are already stored in the KV cache. The queries of earlier tokens were used once to produce those tokens' outputs and are never needed again, so caching the query projections would bring no benefit.

Updates to the KV Cache

  • During Decoding: Throughout autoregressive decoding, the key and value projections are cached. Each time a new token is added to the input, the new rows are computed as part of self-attention and added to the KV cache. The query projection for the new token is then used with the updated key and value projections to perform the rest of the forward pass.
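
A minimal single-head sketch of this update, in pure Python, is shown below. The names `KVCache`, `decode_step`, and `full_recompute`, and the use of plain lists for the projection matrices `Wq`, `Wk`, and `Wv`, are assumptions made for illustration. Each step appends one key row and one value row to the cache, while the query for the new token is used once and discarded.

```python
import math


def matvec(x, W):
    """Row vector x (length d_in) times matrix W (d_in x d_out)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def attend(q, keys, values):
    """Softmax attention of a single query over all keys and values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


class KVCache:
    """Holds one key row and one value row per token seen so far."""

    def __init__(self):
        self.keys, self.values = [], []


def decode_step(x_new, Wq, Wk, Wv, cache):
    """Project only the new token, extend the cache, then attend."""
    cache.keys.append(matvec(x_new, Wk))
    cache.values.append(matvec(x_new, Wv))
    q = matvec(x_new, Wq)                  # used once, never cached
    return attend(q, cache.keys, cache.values)


def full_recompute(xs, Wq, Wk, Wv):
    """Baseline without a cache: re-project every token at every step."""
    keys = [matvec(x, Wk) for x in xs]
    values = [matvec(x, Wv) for x in xs]
    return attend(matvec(xs[-1], Wq), keys, values)
```

Calling `decode_step` once per token gives the same output for the newest position as `full_recompute` on the whole prefix, while performing only one key projection and one value projection per step.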

Latency Optimization

  • Reduction in Latency: KV caching decreases the latency of generating each token after the first in an autoregressive setting. On the first forward pass the cache is empty, so the keys and values for every prompt token must be computed, which makes the time to the first token higher. For subsequent tokens only the newest token has to be processed, so the per-token latency drops, reducing the overall response time.
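
A simple count makes the effect visible. The sketch below tallies how many token positions must go through the key and value projections per layer, with and without a cache; the function names and the example sizes (a 100-token prompt and 50 generated tokens) are hypothetical, and the count ignores every other cost of a forward pass.

```python
def kv_projections_without_cache(prompt_len, new_tokens):
    """Each forward pass re-projects every token currently in the input."""
    return sum(prompt_len + t for t in range(new_tokens))


def kv_projections_with_cache(prompt_len, new_tokens):
    """The first pass projects the whole prompt; later passes project one token."""
    if new_tokens == 0:
        return 0
    return prompt_len + (new_tokens - 1)


print(kv_projections_without_cache(100, 50))
print(kv_projections_with_cache(100, 50))
```

The first pass costs the same in both cases, which is why the time to the first token does not improve, while every later pass handles a single new position instead of the whole sequence.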

Scaling to Multi-Head Self-Attention

  • Multi-Head Attention: While the explanation primarily considers single-head self-attention for simplicity, the same process applies to the multi-head self-attention used by LLMs. This involves performing the process in parallel across multiple attention heads.


  • In summary, the KV cache in transformer models is a critical optimization that enhances the efficiency and speed of sequence generation, making it a key component for deploying these models in real-world applications. The use of KV caching in autoregressive decoding processes, along with its role in latency optimization and scalability, makes it indispensable for serving transformer-based models efficiently.


If you found our work useful, please cite it as:

  title   = {Model Acceleration},
  author  = {Chadha, Aman},
  journal = {Distilled AI},
  year    = {2020},
  note    = {\url{https://aman.ai}}