Primers • Model Acceleration
- Training Optimizations
- Overview
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Fast Transformer Decoding: One Write-Head is All You Need
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
- Longformer: The Long-Document Transformer
- Inference Optimizations
- Model Quantization
- Operator Fusion
- Speculative Decoding
- FlashAttention and Efficient Attention Kernels
- Batching and Prefilling
- Prompt Caching
- Early Exit and Token Pruning
- Hardware-Aware Scheduling
- Comparative Analysis
- References
- Citation
Training Optimizations
Overview
-
Training optimizations for large language models (LLMs) focus on reducing computational and memory overhead during the training phase while preserving model quality. As LLMs scale in size and sequence length, traditional attention mechanisms and dense architectures become bottlenecks due to their high compute and memory requirements—most notably the quadratic complexity of self-attention.
-
This section explores innovations aimed at accelerating training through both algorithmic and systems-level enhancements. These include:
-
Memory-aware attention algorithms like FlashAttention and FlashAttention-2 that optimize data movement between GPU memory hierarchies (e.g., from HBM to SRAM), significantly reducing memory bandwidth usage and computation time. These approaches prioritize hardware efficiency through techniques such as tiling, recomputation, and parallelization of attention blocks.
-
Multi-query and grouped-query attention methods, such as those proposed in the Fast Transformer Decoding and GQA papers, which reduce redundancy in attention heads by sharing key/value projections. These techniques are especially valuable for speeding up decoding and inference but also reduce the number of parameters and computational cost during training.
-
Sparse and localized attention schemes like those introduced in Longformer, which replace global self-attention with a combination of local windowed and task-specific global attention. This approach reduces memory consumption and compute time from quadratic to linear with respect to sequence length, enabling efficient training on longer sequences.
-
-
Together, these methods represent a growing body of work that rethinks the Transformer architecture and its memory-compute tradeoffs. They aim to make LLM training more scalable, efficient, and accessible—paving the way for faster iterations and the deployment of increasingly capable models on constrained hardware. Subsequent sections provide a closer look at specific techniques and their empirical results.
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce the compute complexity, but often do not achieve wall-clock speedup. They argue that a missing principle is making attention algorithms IO-aware – accounting for reads and writes between levels of GPU memory.
- This paper by Dao et al. from Stanford in 2022 proposes FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. Specifically, FlashAttention reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length.
- They analyze the IO complexity of FlashAttention, showing that it requires fewer HBM accesses than standard attention, and is optimal for a range of SRAM sizes. They also extend FlashAttention to block-sparse attention, yielding an approximate attention algorithm that is faster than any existing approximate attention method.
- FlashAttention trains Transformers faster than existing baselines: 15% end-to-end wall-clock speedup on BERT-large (seq. length 512) compared to the MLPerf 1.1 training speed record, 3x speedup on GPT-2 (seq. length 1K), and 2.4x speedup on long-range arena (seq. length 1K-4K).
- FlashAttention and block-sparse FlashAttention enable longer context in Transformers, yielding higher quality models (0.7 better perplexity on GPT-2 and 6.4 points of lift on long-document classification) and entirely new capabilities: the first Transformers to achieve better-than-chance performance on the Path-X challenge (seq. length 16K, 61.4% accuracy) and Path-256 (seq. length 64K, 63.1% accuracy).
- The figure below from the paper shows: (Left) FlashAttention uses tiling to prevent materialization of the large \(N \times N\) attention matrix (dotted box) on (relatively) slow GPU HBM. In the outer loop (red arrows), FlashAttention loops through blocks of the \(K\) and \(V\) matrices and loads them to fast on-chip SRAM. In each block, FlashAttention loops over blocks of the \(Q\) matrix (blue arrows), loading them to SRAM, and writing the output of the attention computation back to HBM. (Right) Speedup over the PyTorch implementation of attention on GPT-2. FlashAttention does not read and write the large \(N \times N\) attention matrix to HBM, resulting in a 7.6x speedup on the attention computation.
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Scaling Transformers to longer sequence lengths has been a major problem in the last several years, promising to improve performance in language modeling and high-resolution image understanding, as well as to unlock new applications in code, audio, and video generation. The attention layer is the main bottleneck in scaling to longer sequences, as its runtime and memory increase quadratically in the sequence length.
- FlashAttention exploits the asymmetric GPU memory hierarchy to bring significant memory saving (linear instead of quadratic) and runtime speedup (2-4x compared to optimized baselines), with no approximation. However, FlashAttention is still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s.
- They observe that the inefficiency is due to suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low-occupancy or unnecessary shared memory reads/writes.
- This paper by Dao from Princeton and Stanford proposes FlashAttention-2, with better work partitioning to address these issues. In particular, they (1) tweak the algorithm to reduce the number of non-matmul FLOPs, (2) parallelize the attention computation, even for a single head, across different thread blocks to increase occupancy, and (3) within each thread block, distribute the work between warps to reduce communication through shared memory. These yield around 2x speedup compared to FlashAttention, reaching 50-73% of the theoretical maximum FLOPs/s on A100 and getting close to the efficiency of GEMM operations.
- They empirically validate that when used end-to-end to train GPT-style models, FlashAttention-2 reaches training speed of up to 225 TFLOPs/s per A100 GPU (72% model FLOPs utilization).
- The following figure from Sebastian Raschka summarizes FlashAttention-2:
Fast Transformer Decoding: One Write-Head is All You Need
- Multi-head attention layers, as used in the Transformer neural sequence model, are a powerful alternative to RNNs for moving information across and between sequences. While training these layers is generally fast and simple, due to parallelizability across the length of the sequence, incremental inference (where such parallelization is impossible) is often slow, due to the memory-bandwidth cost of repeatedly loading the large “keys” and “values” tensors.
- This paper by Shazeer from Google in 2019 proposes a variant called multi-query attention, where the keys and values are shared across all of the different attention “heads”, greatly reducing the size of these tensors and hence the memory bandwidth requirements of incremental decoding.
- They verify experimentally that the resulting models can indeed be much faster to decode, and incur only minor quality degradation from the baseline.
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
- Multi-query attention (MQA), which only uses a single key-value head, drastically speeds up decoder inference. However, MQA can lead to quality degradation, and moreover it may not be desirable to train a separate model just for faster inference.
- This paper by Ainslie et al. from Google Research (1) proposes a recipe for uptraining existing multi-head language model checkpoints into models with MQA using 5% of original pre-training compute, and (2) introduces grouped-query attention (GQA), a generalization of multi-query attention (MQA) which uses an intermediate (more than one, less than number of query heads) number of key-value heads.
- The following figure from the paper presents an overview of grouped-query method. Multi-head attention has \(H\) query, key, and value heads. Multi-query attention shares single key and value heads across all query heads. Grouped-query attention instead shares single key and value heads for each group of query heads, interpolating between multi-head and multi-query attention.
- MQA uses a single key-value head to speed up decoder inference but can lead to quality degradation. The authors propose a novel method to transform existing multi-head attention (MHA) language model checkpoints into models with MQA, requiring only 5% of the original pre-training compute.
- The paper presents Grouped-Query Attention (GQA), an intermediate approach between multi-head and multi-query attention. In GQA, query heads are divided into groups, each sharing a single key and value head. This method allows uptrained GQA models to achieve near MHA quality with speeds comparable to MQA.
- Experiments conducted on the T5.1.1 architecture across various datasets (including CNN/Daily Mail, arXiv, PubMed, MediaSum, Multi-News, WMT, and TriviaQA) show that GQA models offer a balance between inference speed and quality.
- The study includes ablation experiments to evaluate different modeling choices, such as the number of GQA groups and checkpoint conversion methods. These provide insights into the model’s performance under various configurations.
- The paper acknowledges limitations, such as evaluation challenges for longer sequences and the absence of comparisons with models trained from scratch. It also notes that the findings are particularly applicable to encoder-decoder models and suggests GQA might have a stronger advantage in decoder-only models.
- They show that uptrained GQA achieves quality close to multi-head attention with comparable speed to MQA.
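- To make the head-sharing described above concrete, the following is a minimal PyTorch sketch of grouped-query attention (not the paper's implementation; shapes and names are illustrative), where `n_kv_heads` divides `n_heads` and each key/value head is shared by a group of query heads:

```python
import torch

def grouped_query_attention(x, w_q, w_k, w_v, n_heads, n_kv_heads):
    # x: (batch, seq, d_model); weights are plain projection matrices.
    B, T, d = x.shape
    d_head = d // n_heads
    group = n_heads // n_kv_heads  # query heads per KV head

    q = (x @ w_q).view(B, T, n_heads, d_head).transpose(1, 2)     # (B, H, T, d_head)
    k = (x @ w_k).view(B, T, n_kv_heads, d_head).transpose(1, 2)  # (B, H_kv, T, d_head)
    v = (x @ w_v).view(B, T, n_kv_heads, d_head).transpose(1, 2)

    # Repeat each KV head so that `group` query heads share it.
    k = k.repeat_interleave(group, dim=1)                          # (B, H, T, d_head)
    v = v.repeat_interleave(group, dim=1)

    scores = q @ k.transpose(-2, -1) / d_head**0.5
    out = torch.softmax(scores, dim=-1) @ v                        # (B, H, T, d_head)
    return out.transpose(1, 2).reshape(B, T, d)

# Example: 8 query heads sharing 2 KV heads (MHA would use 8 KV heads, MQA would use 1).
x = torch.randn(1, 5, 64)
w_q = torch.randn(64, 64)
w_k = torch.randn(64, 16)   # n_kv_heads * d_head = 2 * 8
w_v = torch.randn(64, 16)
print(grouped_query_attention(x, w_q, w_k, w_v, n_heads=8, n_kv_heads=2).shape)
```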
Longformer: The Long-Document Transformer
- Transformer-based models are unable to process long sequences due to their self-attention operation, which scales quadratically with the sequence length.
- This paper by Beltagy et al. from Allen AI in 2020 seeks to address this limitation, by introducing the Longformer with an attention mechanism that scales linearly with sequence length (commonly called Sliding Window Attention in the field), making it easy to process documents of thousands of tokens or longer.
- Longformer’s attention mechanism is a drop-in replacement for the standard self-attention and combines a local windowed attention with a task motivated global attention.
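- As an illustration of this pattern (not the paper's optimized chunked/CUDA implementation), the sliding-window plus task-specific global attention can be expressed as a boolean attention mask, as in this sketch:

```python
import torch

def longformer_style_mask(seq_len, window, global_idx=()):
    # True = attention allowed. Local band of width `window` on each side,
    # plus full rows/columns for the designated global tokens.
    i = torch.arange(seq_len)
    mask = (i[:, None] - i[None, :]).abs() <= window   # local sliding window
    for g in global_idx:                               # task-specific global attention
        mask[g, :] = True
        mask[:, g] = True
    return mask

mask = longformer_style_mask(seq_len=8, window=1, global_idx=(0,))
print(mask.int())
```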
- The figure below from the paper compares the full self-attention pattern and the configuration of attention patterns in Longformer.
- Following prior work on long-sequence transformers, they evaluate Longformer on character-level language modeling and achieve state-of-the-art results on text8 and enwik8.
- In contrast to most prior work, they also pretrain Longformer and finetune it on a variety of downstream tasks.
- Their pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new state-of-the-art results on WikiHop and TriviaQA. They finally introduce the Longformer-Encoder-Decoder (LED), a Longformer variant for supporting long document generative sequence-to-sequence tasks, and demonstrate its effectiveness on the arXiv summarization dataset.
- The figure below from the paper illustrates the runtime and memory of full self-attention and different implementations of Longformer’s self-attention: `Longformer-loop` is non-vectorized, `Longformer-chunk` is vectorized, and `Longformer-cuda` is a custom CUDA kernel implementation. Longformer’s memory usage scales linearly with the sequence length, unlike the full self-attention mechanism that runs out of memory for long sequences on current GPUs. Different implementations vary in speed, with the vectorized `Longformer-chunk` being the fastest.
Inference Optimizations
Overview
-
Inference optimizations are a crucial area of research and engineering in the deployment of transformer models, particularly for real-time and resource-constrained environments. The goal is to minimize the computational cost and latency of running large language models (LLMs) without compromising their predictive accuracy. Optimizations during inference directly affect the responsiveness, scalability, and feasibility of these models in production systems.
-
One of the central challenges in inference is the autoregressive nature of many LLMs, where each token depends on the previously generated sequence. This leads to sequential dependencies that make naive inference expensive, especially for long sequences. To address this, a suite of optimization techniques has been developed to enhance the performance of transformer-based models during inference:
-
KV Caching: The KV cache stores the key and value projections computed by the self-attention layers for previously processed tokens so they can be reused at subsequent decoding steps, avoiding redundant computation. This dramatically reduces per-token inference time beyond the first token, supports long-sequence generation, and is essential for achieving low-latency, high-throughput serving in applications like chat, streaming, and interactive agents, making it a key component for deploying transformer models in real-world systems.
-
Model Quantization: Model quantization reduces the precision of weights and activations from 32-bit floating-point (FP32) to lower-bit formats such as INT8, FP8, or even 4-bit representations like INT4. This significantly cuts memory footprint and bandwidth usage, enabling deployment on smaller hardware and increasing throughput. Post-training quantization (PTQ) and quantization-aware training (QAT) are two common approaches. Quantized models benefit from faster matrix multiplications and lower energy consumption, and modern toolchains (e.g., NVIDIA TensorRT, Intel Neural Compressor) support hardware acceleration for quantized ops with minimal accuracy degradation.
-
Operator Fusion: Operator fusion consolidates multiple sequential operations—such as linear projections, bias addition, layer normalization, and activation functions—into a single computational kernel. This reduces the number of memory read/write operations and kernel launch overhead on GPUs or TPUs, improving execution efficiency. For example, fusing a dense layer and a ReLU activation into a single fused kernel reduces latency and allows for more effective use of SIMD or CUDA cores, which are otherwise underutilized with fragmented ops.
-
Speculative Decoding: Speculative decoding accelerates autoregressive generation by using a lightweight draft model to predict multiple future tokens in a single forward pass. These candidate tokens are then validated in parallel by the full, slower model. If validated, they are accepted en masse; otherwise, the generation rolls back. This pipeline reduces the number of expensive full-model invocations while maintaining generation fidelity. Approaches like Medusa, FastRAG, and NVIDIA’s Speculative Decoding with Prefill leverage this technique to boost throughput while preserving model output quality.
-
FlashAttention and Efficient Attention Kernels: FlashAttention is a memory-efficient attention algorithm that computes attention outputs in a tiled, fused, and GPU-friendly way, avoiding the need to materialize large intermediate attention matrices. It exploits GPU SRAM to keep frequently accessed blocks in high-speed memory and streams partial results to minimize memory bandwidth pressure. This approach scales better with sequence length and batch size than traditional softmax-based attention implementations. FlashAttention-2 and similar kernels (e.g., xFormers, Triton) are now standard in high-performance transformer inference stacks.
-
Batching and Prefilling: Dynamic and static batching combine multiple inference requests into a single execution batch, maximizing hardware utilization and amortizing kernel overhead across requests. Prefilling involves precomputing the KV cache for the input prompt before entering the autoregressive loop, especially in long-context use cases. Batching also facilitates padding and masking alignment across requests of varying lengths. Efficient batching strategies, like token-level (as in vLLM) or sequence-level, are crucial for low-latency, high-throughput inference in serving systems.
-
Prompt Caching: Caches the KV states of frequently used or repeated prompts—such as system instructions, few-shot exemplars, or user-defined templates—so they don’t need to be recomputed for each request. Particularly effective in chat or API-driven systems where the same initial context (e.g., “You are a helpful assistant…”) is used across sessions. By reusing prompt KV states, servers can skip prompt processing entirely and begin generation with the cache already initialized, significantly reducing time to first token and overall compute.
-
Early Exit and Token Pruning: Early exit allows transformer layers to terminate inference for specific tokens when confidence thresholds or entropy-based stopping criteria are met, saving computation on later layers. Token pruning dynamically removes tokens or attention paths deemed irrelevant during inference, based on learned importance scores or gating functions. These techniques reduce compute costs without heavily sacrificing model output quality, and are especially useful for deployment scenarios where speed is prioritized over full precision.
-
Hardware-Aware Scheduling: This optimization involves aligning inference workloads with the specifics of the underlying hardware—e.g., GPU memory hierarchy, tensor core availability, or pipeline concurrency. Scheduling strategies include operator placement, memory prefetching, stream prioritization, and load balancing across multi-GPU setups. For example, on NVIDIA GPUs, frameworks may utilize CUDA streams, shared memory, and kernel fusion to maximize throughput, while TPU inference may leverage XLA compilation for graph-level optimizations. Fine-tuned scheduling reduces contention, increases parallelism, and maximizes total inference throughput per watt.
-
KV Cache
Motivation
-
In the context of serving transformer models, the Key-Value (KV) cache is a core optimization technique for speeding up autoregressive decoding. It stores intermediate attention computations from previous decoding steps—specifically, the key and value tensors computed in the self-attention mechanism—so that they do not need to be recomputed at every new generation step. This dramatically reduces inference time and memory access overhead in LLMs, especially when generating long outputs.
-
Inference on transformers can be expensive because inference latency scales linearly with the number of layers \(l\) in the model. A naive transformer implementation recomputes the \(K\) and \(V\) representations for all tokens in the sequence at each decoding step, across all layers, which is computationally heavy. Specifically, for a single attention head, the total cost per predicted token is:
\[O(l \, n \, d^2)\]-
where:
- \(n\) = number of tokens seen so far (sequence length)
- \(l\) = number of layers (depth)
- \(d\) = model (embedding) dimension
-
-
Without caching, predicting each new token involves:
- Computing the key and value matrices for all past tokens and for every layer.
- Performing matrix multiplications of the form:
\[K = X W_K, \quad V = X W_V\]
- where \(X\) is the layer input and \(W_K\), \(W_V\) are fixed weight matrices.
-
The KV cache optimization addresses this by reusing previously computed \(K\) and \(V\) representations, thus removing redundant computations for already processed tokens. This approach offers an \(n\)-times speedup in the sequence dimension, particularly impactful when \(n\) is large.
-
The following figure (source) illustrates a typical self-attention block in transformers:
Overview
Self-Attention Overview
-
In transformer models, each token attends to all previous tokens using a self-attention mechanism. For a sequence of input token embeddings \(X \in \mathbb{R}^{T \times d}\), the transformer computes:
-
Queries:
\[Q = X W_Q\] -
Keys:
\[K = X W_K\] -
Values:
\[V = X W_V\] -
where \(W_Q, W_K, W_V \in \mathbb{R}^{d \times d_k}\) are learned projection matrices, and \(d_k\) is the head dimension.
-
-
The attention output is given by:
\[\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V\] -
In a naive implementation, for each decoding step we must compute \(K\) and \(V\) for all tokens in the current sequence, across all layers. If \(n\) is the number of tokens so far and \(l\) is the number of layers, this requires \(l \times (n−1)\) matrix multiplications per step, each of cost \(O(d^2)\), leading to:
\[\text{Cost per token} = O(l \, n \, d^2)\]
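- For concreteness, here is a minimal single-head PyTorch sketch of this naive computation (recomputing \(K\) and \(V\) for the entire sequence at every step; names and sizes are illustrative); the caching discussion that follows removes exactly this redundancy:

```python
import math
import torch

def naive_step(x, w_q, w_k, w_v):
    # x: (T, d) -- the full sequence so far, reprojected at every decoding step.
    q = x @ w_q                      # (T, d_k), recomputed
    k = x @ w_k                      # (T, d_k), recomputed even for old tokens
    v = x @ w_v                      # (T, d_k), recomputed even for old tokens
    d_k = q.shape[-1]
    attn = torch.softmax(q @ k.T / math.sqrt(d_k), dim=-1)
    return attn @ v                  # (T, d_k)

d, d_k = 16, 8
w_q, w_k, w_v = (torch.randn(d, d_k) for _ in range(3))
x = torch.randn(5, d)                # 5 tokens seen so far
print(naive_step(x, w_q, w_k, w_v).shape)
```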
Caching Self-Attention Values
-
KV caching exploits two key properties:
- The model weights (\(W_K\) and \(W_V\)) are fixed during inference.
- The \(K\) and \(V\) representations for a given token depend only on that token and the fixed weights — they do not change in future steps.
-
Therefore, once we compute the \(K\) and \(V\) representations for a given (token, layer, head) tuple, we can store them and reuse them in all subsequent decoding steps.
-
At decoding step \(t\):
- Without caching: recompute \(K_{1:t}\) and \(V_{1:t}\) from scratch for all \(t\) tokens.
- With caching: reuse \(K_{1:(t-1)}\) and \(V_{1:(t-1)}\), compute only the new \(k_t\) and \(v_t\), and append them to the cache.
-
The following figure (source) illustrates the KV caching process, showing how only the new token’s \(K\) and \(V\) are computed while the rest are reused:
-
This optimization changes the cost from \(O(l \, n \, d^2)\) to \(O(l \, d^2)\) per decoding step — an \(n\)-times speedup in the sequence dimension.
-
This improvement is particularly significant for long sequences, where \(n\) can be in the hundreds or thousands.
Why Not Cache the Query?
- The query tensor \(Q\) is recomputed at every step because it depends on the current token’s embedding, which changes with each generation step. Unlike \(K\) and \(V\), which remain constant for already processed tokens, only the most recent \(Q\) is used in the self-attention operation, so caching it offers no benefit.
Autoregressive Decoding Process with Caching
-
Initial Sequence (Prefill Phase):
-
Given a prompt sequence \(S = [x_1, x_2, \dots, x_t]\) the model computes \(K\) and \(V\) tensors for all prompt tokens in all layers and stores them in the KV cache.
-
This step still incurs the full cost \(O(l \, t \, d^2)\) because we have no cached values yet.
-
After this prefill step, the model transitions to the decode phase, where we process one token per step.
-
-
Predict Next Token:
-
At decoding step \(t+1\):
-
Compute the query vector \(q_{t+1} = x_{t+1} W_Q\) for the new token.
-
Retrieve all previous keys and values from the cache: \(K_{1:t}\) and \(V_{1:t}\).
- Compute the attention output for the new token using the cached keys and values along with the new token’s key and value (computed in the next step):
\[\text{softmax}\left(\frac{q_{t+1} K_{1:t+1}^\top}{\sqrt{d_k}}\right) V_{1:t+1}\]
-
-
-
Update Cache:
-
Compute the new key and value vectors for the current token:
\[k_{t+1} = x_{t+1} W_K, \quad v_{t+1} = x_{t+1} W_V\] -
Append these to the KV cache so they can be reused in future decoding steps:
\[K_{\text{cache}} \leftarrow [K_{1:t}, k_{t+1}]\] \[V_{\text{cache}} \leftarrow [V_{1:t}, v_{t+1}]\]
-
-
Repeat:
- Continue until the end-of-sequence (EOS) token is generated or the maximum token limit is reached.
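- A minimal single-head, single-layer PyTorch sketch of this prefill-then-decode loop is shown below (illustrative only; names and shapes are assumptions, not a production implementation):

```python
import math
import torch

def attend(q, K, V):
    # q: (1, d_k); K, V: (t, d_k)
    w = torch.softmax(q @ K.T / math.sqrt(K.shape[-1]), dim=-1)
    return w @ V

d, d_k = 16, 8
w_q, w_k, w_v = (torch.randn(d, d_k) for _ in range(3))

# Prefill: compute K/V for all prompt tokens once and store them.
prompt = torch.randn(4, d)
K_cache, V_cache = prompt @ w_k, prompt @ w_v

# Decode: per new token, compute only q_t, k_t, v_t and append to the cache.
for _ in range(3):
    x_t = torch.randn(1, d)                      # embedding of the newest token
    q_t, k_t, v_t = x_t @ w_q, x_t @ w_k, x_t @ w_v
    K_cache = torch.cat([K_cache, k_t], dim=0)   # K_{1:t} -> K_{1:t+1}
    V_cache = torch.cat([V_cache, v_t], dim=0)
    out = attend(q_t, K_cache, V_cache)          # only O(d^2) projection work per step (per layer)

print(K_cache.shape)  # (7, d_k): 4 prompt tokens + 3 generated tokens
```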
Complexity Impact
-
Without caching: For each new token, recompute \(n\) key and value pairs for each layer, costing \(O(l \, n \, d^2)\) per step.
-
With caching: Only compute 1 new key and value per layer, reducing the per-step cost to \(O(l \, d^2)\).
-
This is an \(n\)-times speedup in the sequence dimension — for example, if \(n = 1000\), the reduction in compute is enormous.
Implementation Details
Cache Tensor Shape
-
Assuming:
- Batch size \(B\)
- Max sequence length \(T\)
- Number of heads \(H\)
- Head dimension \(d_k\)
- Model (embedding) dimension \(d\)
- Number of layers \(l\)
-
The KV cache is structured to store \(K\) and \(V\) for each (token, layer, head) tuple. In practice, the cache for a given layer and head has shapes:
- Key cache:
\[\text{K\_cache} \in \mathbb{R}^{B \times H \times T \times d_k}\] -
Value cache:
\[\text{V\_cache} \in \mathbb{R}^{B \times H \times T \times d_k}\] - Since autoregressive generation processes one token per step in the decode phase, only 1 new key and value vector is appended at each step for each layer and head.
- Efficient memory layout is crucial — contiguous buffers enable fast appends and reduce memory copy overhead.
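- As a sketch of such a layout (shapes as assumed above; not any particular framework's implementation), one can preallocate a contiguous buffer per layer and track a write position so that each decoding step is a cheap in-place write rather than a reallocation:

```python
import torch

B, H, T_max, d_k, n_layers = 2, 8, 512, 64, 4

# One contiguous K and V buffer per layer; `pos` tracks how many tokens are filled.
k_cache = torch.zeros(n_layers, B, H, T_max, d_k)
v_cache = torch.zeros(n_layers, B, H, T_max, d_k)
pos = 0

def append_kv(layer, k_new, v_new):
    # k_new, v_new: (B, H, 1, d_k) -- projections of the single newest token.
    k_cache[layer, :, :, pos] = k_new.squeeze(2)
    v_cache[layer, :, :, pos] = v_new.squeeze(2)

for layer in range(n_layers):
    append_kv(layer, torch.randn(B, H, 1, d_k), torch.randn(B, H, 1, d_k))
pos += 1  # advance once per decoding step, after all layers have written
```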
Prefill Phase
-
When the prompt is first processed:
- The model computes \(K\) and \(V\) for all prompt tokens in every layer, filling the cache.
- This initial step has the same cost as the naive approach: \(O(l \, t \, d^2)\) for a prompt of \(t\) tokens.
-
After this, we move into the decode phase, where caching delivers the performance benefits.
Caching Cost Analysis from the Blog
- Without caching:
- For a single head, computing \(K\) and \(V\) for all tokens so far costs \(O(l \, n \, d^2)\) per predicted token.
- With caching:
- We only compute \(K\) and \(V\) for the newest token, costing \(O(l \, d^2)\) per predicted token.
- Speedup:
- An \(n\)-times reduction in the sequence dimension cost, where \(n\) is the number of tokens so far.
Why Not Cache the Query?
- As explained earlier, we recompute \(Q\) every step because it depends on the most recent token’s embedding. Only this latest \(Q\) is used in the attention calculation, so caching older queries provides no benefit and would only waste memory.
Updates to the KV Cache
- During autoregressive decoding, the \(K\) and \(V\) projections are cached for every processed token, across all layers and heads.
-
Each time a new token is generated:
- The model computes \(k_t\) and \(v_t\) for that token in each layer.
- These vectors are appended to the existing \(K_{cache}\) and \(V_{cache}\).
- The updated cache is then used to compute the attention output for the next token.
Latency Optimization
-
Without caching:
- At decoding step \(t\), the self-attention module recomputes keys and values for all \(t\) tokens, across all \(l\) layers.
- Computational cost per new token: \(O(l \, n \, d^2)\).
-
With caching:
- Only the key and value for the new token are computed, while the rest are reused from the cache.
- Computational cost per new token: \(O(l \, d^2)\).
-
This represents an \(n\)-times speedup in the sequence dimension, where \(n\) is the number of tokens processed so far.
-
Latency growth:
- Without caching: cumulative generation latency grows quadratically with sequence length (\(O(T^2)\)), since each step reprocesses all previous tokens.
- With caching: cumulative latency grows linearly (\(O(T)\)), since we only compute and append one new key/value vector per step.
-
In the blog, this was described as a major win for long sequences — the performance improvement is proportional to sequence length, so for large \(n\) (e.g., hundreds or thousands of tokens), the cost reduction is dramatic.
Practical Deployment Considerations
- Memory Management:
-
Large caches for long sequences consume substantial GPU memory. Practical systems may:
- Apply a sliding window to drop the oldest tokens.
- Truncate caches for very long contexts to fit hardware limits.
-
- Dynamic Batching:
- Batches often contain requests at different decoding stages. Each request must maintain its own KV cache state, and systems must efficiently manage per-request lookup and append operations.
- Cache Parallelism:
-
In multi-GPU serving:
- KV caches may be partitioned across devices.
- Synchronization strategies are needed to ensure all GPUs have the correct K and V for cross-device attention computations.
-
Multi-Head Attention and KV Cache
-
In practice, self-attention is implemented with multiple attention heads, each operating in a subspace of the embedding dimension. For head \(h \in \{1, \dots, H\}\), we have:
\[Q^{(h)} = X W_Q^{(h)}, \quad K^{(h)} = X W_K^{(h)}, \quad V^{(h)} = X W_V^{(h)}\] -
The per-head projections can be viewed as a concatenation across heads:
\[Q = \text{concat}(Q^{(1)}, Q^{(2)}, \dots, Q^{(H)})\]- and similarly for \(K\) and \(V\).
-
Caching in multi-head attention:
- The KV cache stores keys and values for every head and every layer.
- Shape for the key cache:
\[\text{K\_cache} \in \mathbb{R}^{B \times H \times T \times d_k}\] -
Shape for the value cache:
\[\text{V\_cache} \in \mathbb{R}^{B \times H \times T \times d_k}\] -
Performance implications:
- Since each head’s KV cache is independent, the caching logic operates head-wise, but the storage is typically implemented as a unified tensor for efficiency.
- This unified tensor is arranged to be friendly to GPU tensor cores, enabling very fast read and write operations during decoding.
-
While KV caching greatly reduces the sequence dimension cost, the depth dimension (number of layers \(l\)) is still a significant contributor to compute. This leads to the KV Sharing idea, covered in detail in the section on KV Sharing — reusing \(K\) and \(V\) representations across the last half (or fraction) of layers to further cut computation. KV Sharing builds on KV caching, but attacks the problem from the layer dimension rather than the token dimension.
Summary of KV Cache Benefits
- Reduces repeated computation by storing and reusing \(K\), \(V\) tensors instead of recomputing them at every step.
- Enables efficient decoding in autoregressive generation by cutting per-step cost from \(O(l \, n \, d^2)\) to \(O(l \, d^2)\) — an \(n\)-times speedup in the sequence dimension.
- Optimized for hardware acceleration via unified tensor layouts that are friendly to GPU tensor cores.
- Scales well to large models and long contexts, with latency growing linearly rather than quadratically with sequence length.
- Maintains accuracy because cached \(K\) and \(V\) are identical to recomputed values, given fixed weights.
KV Sharing
- KV caching optimizes the sequence dimension (\(n\)) cost, but the depth dimension (\(l\)) — the number of layers — still incurs full computation for each layer’s \(K\) and \(V\).
- KV Sharing addresses this by reducing the cost of computing \(K\) and \(V\) along the depth dimension.
How KV Sharing Works
-
The core idea: share actual \(K\) and \(V\) representations (not just weight matrices) across the last fraction of layers.
-
For example, if we share across the last half of the layers (\(\frac{l}{2}\) layers):
- The final layer before the shared region computes \(K\) and \(V\) normally.
- All subsequent layers in the shared region reuse these \(K\) and \(V\) without recomputation, regardless of their inputs.
- Other parameters (e.g., \(W_Q\), MLP weights) remain distinct per layer.
-
Mathematically:
- Let \(L_{share}\) be the index of the first shared layer.
- For any layer \(j \geq L_{share}\):
\[K^{(j)} = K^{(L_{share})}, \quad V^{(j)} = V^{(L_{share})}\]
-
The following figure (source) illustrates KV Sharing across the last half of the layers, showing how a single computed \(K\) and \(V\) set is reused instead of recalculated:
FLOP Savings
- If the last \(\frac{l}{k}\) layers share \(K\) and \(V\), we avoid computing them in \(\frac{l}{k}\) layers entirely.
- FLOP reduction: roughly a \(\frac{1}{k}\) fraction of the total K/V projection computation is eliminated.
-
Combined with KV caching:
- KV caching cuts cost in \(n\) (sequence) dimension.
- KV sharing cuts cost in \(l\) (layer) dimension.
Why KV Sharing Can Work
- Empirical studies (referenced in the blog and the You Only Cache Once paper) show that in deep transformer models, the last few layers often produce correlated outputs.
- This suggests that later layers are mostly fine-tuning rather than introducing fundamentally new information.
- Reusing \(K\) and \(V\) in these layers therefore has minimal impact on output quality while significantly reducing compute and memory usage.
Memory Benefits
- No need to store K/V for the shared layers at all.
- Reduces memory footprint in both inference and training.
- Particularly valuable when serving long sequences, where cache size is dominated by \(B \times H \times T \times d_k \times l\) scaling.
Deployment Notes
- KV sharing must be considered at training time for best results, since models not trained with this constraint may suffer quality drops if sharing is applied post hoc.
- Works alongside KV caching — the former tackles depth, the latter tackles sequence length.
Model Quantization
- Model quantization is a technique used to reduce the precision of numerical values (typically weights and activations) in a neural network from high-precision formats like 32-bit floating point (FP32) to lower-precision formats such as INT8, FP8, or even INT4. This allows for faster inference, reduced memory usage, and lower power consumption, particularly on hardware that supports low-precision arithmetic.
- A detailed discourse on this topic is available in our Model Compression primer.
Why Quantize?
-
Quantization can lead to significant improvements in efficiency:
- Reduced Memory Footprint: An INT8 model consumes 75% less memory than its FP32 counterpart.
- Faster Arithmetic: Lower-precision operations (like INT8 or INT4 matmuls) are natively supported and highly optimized on modern accelerators (e.g., NVIDIA Tensor Cores, Intel AVX-512 VNNI).
- Lower Latency: With less data to move and faster compute kernels, quantized models offer reduced end-to-end inference time.
Types of Quantization
Post-Training Quantization (PTQ)
-
PTQ involves converting a pre-trained FP32 model to a lower-precision model without retraining. It works by calibrating the ranges of tensors using a small sample of data.
-
Key steps in PTQ:
-
Range Calibration: Identify the min/max values of weights and activations from a calibration dataset.
-
Scale and Zero-Point Calculation: For each quantized tensor, calculate:
\[q = \text{round}\left(\frac{r}{s}\right) + z\]-
where:
- \(r\) is the real-valued number
- \(s\) is the scale (i.e., step size)
- \(z\) is the zero-point to preserve zero mapping in the quantized domain
- \(q\) is the quantized value (e.g., 8-bit integer)
-
-
-
Weight and Activation Clipping: Clip values to fit within the representable range of the target bit-width (e.g., [-128, 127] for signed INT8).
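- A minimal sketch of an asymmetric INT8 quantizer following these steps (illustrative only; real toolchains such as TensorRT or Neural Compressor handle calibration data, per-channel scales, and operator coverage):

```python
import torch

def quantize_int8(r: torch.Tensor):
    # Range calibration: min/max of the tensor (in PTQ these come from calibration data).
    r_min, r_max = r.min(), r.max()
    # Scale and zero-point for an asymmetric mapping onto [-128, 127].
    scale = (r_max - r_min) / 255.0
    zero_point = torch.round(-128 - r_min / scale)
    # Quantize with clipping to the representable range.
    q = torch.clamp(torch.round(r / scale) + zero_point, -128, 127).to(torch.int8)
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    return (q.float() - zero_point) * scale

w = torch.randn(4, 4)
q, s, z = quantize_int8(w)
print((w - dequantize(q, s, z)).abs().max())  # quantization error
```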
Quantization-Aware Training (QAT)
-
QAT simulates quantization during training. Fake quantization layers are added to mimic low-precision computation while maintaining gradients in high precision.
-
Advantages:
- More accurate than PTQ for sensitive models (e.g., GPT, BERT).
- Allows the model to adapt to quantization errors during fine-tuning.
-
Implementation Details:
- Frameworks like PyTorch and TensorFlow include fake quantization modules (e.g., `torch.quantization.FakeQuantize`).
- Quant-dequant pairs are inserted in the model graph to simulate the behavior of actual quantized operations.
Static vs. Dynamic Quantization
- Static Quantization: Activations are quantized ahead of time using calibration. Requires representative input data and is more performant but less flexible.
- Dynamic Quantization: Weights are quantized ahead of time, but activations are quantized at runtime based on actual values. More flexible and easier to integrate but slightly slower.
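- As a concrete example, PyTorch ships a dynamic quantization API that can be applied post-training to a model's Linear layers (a minimal sketch; accuracy should still be validated after conversion):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 64)).eval()

# Dynamic quantization: weights stored in INT8, activations quantized at runtime.
quantized = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(1, 256)
print(quantized(x).shape)
```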
Quantization in Transformers
-
In transformer models like GPT or BERT, quantization is applied to:
- Linear layers: Including query, key, value, and output projections in attention layers.
- GEMM-heavy blocks: MLP (feed-forward) layers.
- Embedding layers: Often quantized with special handling to preserve lookup efficiency.
-
Special Considerations:
- LayerNorm and Softmax are sensitive to quantization and often kept in FP32.
- Attention scores may require FP16 or FP32 to avoid instability.
- Mixed-precision quantization (e.g., FP8 weights with INT8 activations) is sometimes used.
Tooling and Frameworks
- NVIDIA TensorRT / FasterTransformer
- Intel Neural Compressor (INC)
- PyTorch Quantization Toolkit
- ONNX Runtime Quantization
-
BitsAndBytes (for 8-bit and 4-bit LLMs)
- These tools offer end-to-end pipelines for quantizing, validating, and deploying models.
Operator Fusion
-
Operator fusion is an inference optimization technique that combines multiple adjacent operations in a neural network computation graph into a single composite operation. This is done to reduce overhead from memory reads/writes, kernel launches, and inter-operation communication, especially on GPU- or TPU-based systems.
-
Fusion reduces latency and increases compute efficiency by keeping data in faster registers or shared memory, rather than flushing it out to slower global memory between every small operation.
Motivation
- Modern deep learning workloads often involve many small operations executed sequentially—e.g., matrix multiplications followed by bias addition, normalization, and non-linear activations:
-
Each of these operations might otherwise be implemented as a separate kernel. This leads to:
- Increased kernel launch overhead.
- Inefficient use of GPU parallelism.
- Repeated memory access and latency.
- Limited optimization opportunities for compilers.
-
By fusing them, the computation becomes more compact, minimizing overhead and maximizing performance.
Common Fusion Patterns
-
Some of the most commonly fused sequences in transformer inference include:
-
GEMM + Bias Add + Activation
- Example: \(Y = \text{ReLU}(X @ W + b)\)
- Typically fused in MLP layers.
-
LayerNorm + Add + Dropout
- Used in transformer residual blocks.
-
Query/Key/Value Linear Projections
- Three `Linear` ops fused into a single matmul followed by splitting heads.
-
Softmax + Masking
- In attention, softmax is often fused with masking logic to avoid branch divergence on GPUs.
-
Fusion in Transformers
-
In transformer architectures, operator fusion is especially valuable in:
-
Multi-Head Attention Blocks:
- Combine Q/K/V projections and reshape + transpose logic into a single kernel.
- Fuse attention score computation, masking, and softmax into one efficient operation.
-
Feed-Forward Networks (FFNs):
- Fuse two linear layers with intermediate activation (e.g., GELU or ReLU).
-
Implementation Details
- Fusion can be implemented in several ways:
Graph-Level Fusion (Ahead-of-Time)
-
High-level compilers like XLA (for TensorFlow) or TorchScript (for PyTorch) can analyze the computational graph and fuse operations during compilation.
-
Example in PyTorch:
```python
import torch
import torch.nn.functional as F

@torch.jit.script
def fused_layer(x, w1, b1, w2, b2):
    return F.relu(F.linear(x, w1, b1)) @ w2.t() + b2
```
- TorchScript may fuse `linear + relu` into a single kernel.
Kernel-Level Fusion (Runtime)
-
Frameworks like NVIDIA’s TensorRT and FasterTransformer include hand-written CUDA kernels that combine multiple operations (e.g., QKV projection + transpose + scale + matmul) in one pass.
-
Example: A fused transformer kernel might compute:
```
qkv = fused_linear_bias_act(x);   // one call
q, k, v = split_heads(qkv);       // internal fused transpose and reshape
```
- This reduces global memory traffic and utilizes registers/shared memory for intermediate results.
Custom Kernel Generation
-
Libraries like TVM or Triton enable defining custom fused kernels in a hardware-optimized DSL. These can be compiled just-in-time for maximum throughput.
-
Example in Triton:
```python
@triton.jit
def fused_gemm_relu(...):
    # Define fused matmul + bias + relu logic using GPU thread blocks
```
Performance Impact
Operator fusion can lead to:
- 30–50% improvement in latency for attention blocks.
- Higher hardware utilization, especially on GPUs with tensor cores or vectorized ALUs.
- Reduced memory bandwidth pressure, which is often the bottleneck in LLM inference.
Tooling and Ecosystem
- TensorRT: Extensive fusion for transformer blocks.
- FasterTransformer: Fused QKV and FFN kernels.
- ONNX Runtime with Graph Optimizer: Automatic fusion passes.
- TorchScript + FBGEMM: Fusion of linear + activation ops.
- TVM / Triton: Customizable and tunable fusion kernels.
Speculative Decoding
-
Speculative decoding is an inference-time optimization technique designed to reduce the latency of autoregressive sequence generation in large language models (LLMs). Instead of generating one token at a time using the full, expensive model, speculative decoding uses a smaller, faster “draft” model to guess multiple tokens in parallel, then validates these guesses with the full model. If the guesses are correct, they are accepted as part of the output. Otherwise, they are partially or fully discarded and recomputed.
-
This method maintains the output quality of the original model while significantly improving throughput.
Motivation
-
Autoregressive decoding is inherently sequential. In a naive setup, the model generates one token, then feeds it back as input to generate the next. This sequential loop introduces latency and becomes a bottleneck during long-form generation.
-
Let:
- \(f\) be the full model (large, accurate but slow)
- \(g\) be the draft model (smaller, less accurate but fast)
-
Naively, generation requires \(T\) forward passes of \(f\) for a sequence of \(T\) tokens. Speculative decoding aims to reduce the number of times \(f\) is called.
Basic Algorithm
- Initialize Context: Use a prompt or previous tokens \(x\).
-
Draft Generation: Use the draft model \(g\) to generate a sequence of \(k\) speculative tokens:
\[y_1, y_2, ..., y_k = g(x)\] - Validation: Use the full model \(f\) to compute the log-probabilities \(p_f(y_t \mid x, y_1, ..., y_{t-1})\).
-
Accept or Reject Tokens:
- Accept as many tokens as \(f\) agrees with (within a confidence threshold or by matching top-1 outputs).
- Rewind to the last agreed-upon token and resume with the draft model from there.
Pseudocode
```python
x = initial_prompt
while not done:
    draft_tokens = g.generate_next_k(x)
    probs_f = f.get_probs(x + draft_tokens)
    accepted_prefix = match(draft_tokens, probs_f)
    x = x + accepted_prefix
```
Key Parameters
- Draft Model Quality: Must be fast enough to justify speculative overhead but good enough to match the full model reasonably often.
- Block Size \(k\): Number of speculative tokens generated per iteration. Larger blocks = fewer full model calls, but higher risk of rejection.
- Matching Strategy: Usually uses top-1 match or a log-prob threshold.
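- A sketch of the top-1 matching step is shown below; the `match_top1` helper is hypothetical and stands in for the `match` call in the pseudocode above, assuming the full model returns a distribution for each drafted position:

```python
import torch

def match_top1(draft_tokens, full_probs):
    # draft_tokens: (k,) token ids proposed by the draft model g.
    # full_probs:   (k, vocab) probabilities from the full model f at those positions.
    # Accept the longest prefix on which f's argmax agrees with the draft.
    agree = (full_probs.argmax(dim=-1) == draft_tokens).long()
    n_accept = int(agree.cumprod(dim=0).sum())
    return draft_tokens[:n_accept]

draft = torch.tensor([5, 9, 9, 2])
probs = torch.softmax(torch.randn(4, 50), dim=-1)
probs[0, 5] = 1.0  # f agrees on the first drafted token only
print(match_top1(draft, probs))
```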
Mathematical View
- Let the probability of accepting each drafted token be \(\alpha\). Then the expected number of tokens produced per full-model call is approximately:
\[\frac{1 - \alpha^{k+1}}{1 - \alpha}\]
- If \(\alpha \approx 0.7\) and \(k = 4\), this is roughly 2.8, so we reduce full-model calls by nearly 3\(\times\).
Implementation Details
- Parallel Calls: \(f\) can validate all \(k\) tokens in one forward pass by using cached KV states and batched logits.
- KV Cache Management: Efficient speculative decoding updates the cache only after validation.
- Multimodel Serving: Systems like NVIDIA’s FasterTransformer or Hugging Face’s `transformers` can host both \(f\) and \(g\) concurrently with shared memory or GPU residency.
Notable Variants
- Medusa: Uses a tree-structured decoding scheme to validate multiple candidates at once.
- FastRAG: Combines speculative decoding with retrieval-based models.
- Draft & Verify (Google): A formalized framework for plug-and-play speculative decoding with checkpointing.
Benefits
- Latency Reduction: 2\(\times\)–4\(\times\) speedup in decoding for long sequences.
- Full-Model Accuracy: Final output matches the output of the full model \(f\), so there’s no accuracy loss.
- Compatibility: Can be layered on top of existing decoding strategies (e.g., greedy, top-k, nucleus).
Limitations
- Requires additional memory and compute for the draft model.
- Effectiveness depends on alignment between the draft and full model distributions.
- Complex cache management and integration overhead.
FlashAttention and Efficient Attention Kernels
- In transformer models, self-attention is a core operation that enables the model to learn relationships between tokens. However, traditional attention implementations scale poorly with sequence length due to quadratic memory and compute complexity. FlashAttention and other efficient attention kernels address these bottlenecks by optimizing the attention computation to reduce memory overhead and improve performance.
Motivation
-
The standard attention computation involves the following operations for a sequence of length \(L\) and hidden dimension \(d\):
\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V\] -
This requires:
- Computing a full \(L \times L\) attention matrix (expensive for long sequences).
- Storing intermediate results like logits and softmax scores in global memory.
- Limited reuse of on-chip memory (registers, shared memory), resulting in bandwidth-bound performance.
-
FlashAttention addresses these inefficiencies by restructuring the attention algorithm to use memory-efficient block-wise computation.
FlashAttention: Key Concepts
-
Originally proposed in Dao et al., 2022, FlashAttention is a fused, tiled implementation of scaled dot-product attention that:
- Eliminates materialization of the full attention matrix.
- Uses tiling to partition queries, keys, and values into small blocks that fit in GPU shared memory.
- Fuses softmax, scaling, masking, and matmul into a single kernel.
High-Level Algorithm
- Load a block of queries \(Q_{i}\) and keys \(K_{j}\) into shared memory.
- Compute attention logits \(\frac{Q_i K_j^T}{\sqrt{d}}\) for the block.
- Apply mask and softmax in-place, updating the running sum of exponents and maximums for numerical stability.
- Accumulate partial outputs \(A_{i,j} = \text{softmax}(Q_i K_j^T / \sqrt{d}) V_j\) without storing intermediate attention matrices.
- Repeat across blocks until the full result is computed.
Numerical Stability
-
To avoid overflows when computing softmax across blocks, FlashAttention maintains:
- \(m_i = \max_{j} z_{ij}\) for each query row.
- \[s_i = \sum_{j} \exp(z_{ij} - m_i)\]
- and updates these across blocks using associative operators.
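- A small PyTorch sketch of this block-wise computation with running max/sum renormalization is shown below; it reproduces the math (looping over blocks of \(K\)/\(V\) only), not the fused CUDA kernel, and the block size is arbitrary:

```python
import math
import torch

def blockwise_attention(Q, K, V, block=64):
    # Q, K, V: (L, d). Processes K/V in blocks without ever materializing
    # the full L x L score matrix.
    L, d = Q.shape
    out = torch.zeros_like(Q)
    m = torch.full((L, 1), float("-inf"))   # running row-wise max of the logits
    s = torch.zeros(L, 1)                   # running softmax denominator

    for j in range(0, K.shape[0], block):
        z = Q @ K[j:j + block].T / math.sqrt(d)            # logits for this block
        m_new = torch.maximum(m, z.max(dim=-1, keepdim=True).values)
        correction = torch.exp(m - m_new)                  # rescale previous partials
        p = torch.exp(z - m_new)
        s = s * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ V[j:j + block]
        m = m_new
    return out / s

Q, K, V = (torch.randn(128, 32) for _ in range(3))
ref = torch.softmax(Q @ K.T / math.sqrt(32), dim=-1) @ V
print(torch.allclose(blockwise_attention(Q, K, V), ref, atol=1e-5))
```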
Implementation Details
- Written as a custom CUDA kernel.
- Uses shared memory to hold Q/K/V tiles and compute locally.
- Optimized to run in mixed precision (e.g., FP16 or BF16) for speed and memory efficiency.
- Compatible with dropout, masking, and rotary embeddings.
FlashAttention-2 Improvements
- Adds support for non-causal attention, variable-length sequences, and better warp-level parallelism.
- Removes redundant memory loads through more aggressive caching and loop unrolling.
- Enables backward pass efficiency, making it useful not only for inference but also for training.
Other Efficient Kernels
- xFormers (Meta): Modular attention implementations that support Flash, sparse, and memory-efficient variants.
- Triton-based Attention: Enables easy definition of fused attention kernels using Triton’s GPU DSL.
- PagedAttention (vLLM): Optimizes KV cache access for batch inference, reducing memory fragmentation and improving latency.
Performance Gains
FlashAttention reduces attention memory complexity from:
- \(\mathcal{O}(L^2)\) to \(\mathcal{O}(L)\) for memory consumption.
- Achieves 1.7–2.7\(\times\) speedup on A100 GPUs for long sequence lengths (> 1k tokens).
- Maintains exact attention output (within floating-point precision), unlike approximate methods.
Use in Inference
FlashAttention is especially beneficial for:
- Long-context models (e.g., 4k to 128k tokens).
- Multi-head attention, where per-head memory use adds up quickly.
- Deployment on GPUs with large shared memory (e.g., NVIDIA A100, H100).
Integration
Supported in:
- Hugging Face Transformers via `use_flash_attention_2=True`
- PyTorch through custom CUDA extensions or Triton kernels
- DeepSpeed, FasterTransformer, and xFormers
Batching and Prefilling
- Batching and prefilling are inference-time optimization techniques that improve efficiency and throughput by better utilizing hardware and avoiding redundant computations. These are especially critical when serving large language models (LLMs) in real-time or at high concurrency.
Batching
- Batching refers to the process of grouping multiple inference requests into a single forward pass through the model. This increases hardware utilization, amortizes overhead, and reduces latency per request (on average), particularly on GPUs that are optimized for matrix-heavy workloads.
Motivation
-
Without batching, each request results in an under-utilized forward pass:
- Small input tensor \(\rightarrow\) Poor occupancy of GPU cores
- High overhead per kernel launch
- Wasted memory bandwidth
-
Batching solves this by aligning multiple requests into a tensor of shape:
\[\text{Batch Tensor: } (B, L, d)\]- where:
- \(B\) is batch size
- \(L\) is sequence length
- \(d\) is hidden dimension
- where:
Types of Batching
- Static Batching: Requests are grouped together at fixed time intervals. Simple but less flexible.
- Dynamic Batching: Requests are buffered and grouped at runtime based on heuristics like request arrival time, sequence length, or prompt similarity.
- Token-Level Batching: Pioneered by vLLM, this groups sequences by shared decoding step instead of sequence. Supports long-running generation jobs without blocking new ones.
- Asynchronous Batching: Uses request queues and a scheduler to decide when to batch based on hardware load.
Padding and Masking
-
Since sequences may vary in length, shorter ones are padded and masked accordingly. Padding increases memory cost but enables unified matrix operations.
-
Example:
- Sequence A: `[Hello, how, are, you]` \(\rightarrow\) length 4
- Sequence B: `[Hi]` \(\rightarrow\) length 1
- Batched input: `[[Hello, how, are, you], [Hi, PAD, PAD, PAD]]`
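- A minimal sketch of padding and masking such a batch with PyTorch (token ids are illustrative):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD = 0
seq_a = torch.tensor([11, 12, 13, 14])   # "Hello, how, are, you"
seq_b = torch.tensor([21])               # "Hi"

batch = pad_sequence([seq_a, seq_b], batch_first=True, padding_value=PAD)
attention_mask = (batch != PAD).long()   # 1 = real token, 0 = padding

print(batch)            # shape (2, 4) with seq_b padded to length 4
print(attention_mask)
```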
Performance Benefits
- Higher throughput: GPUs can process large matrices in parallel.
- Lower kernel launch overhead.
- Amortized use of KV cache and memory bandwidth.
Prefilling
- Prefilling refers to the one-time computation of model activations (primarily KV cache) for the prompt or context tokens before autoregressive decoding begins.
Motivation
-
Autoregressive decoding separates the inference process into:
- Prompt Phase (Prefill): Process entire prompt to initialize the KV cache.
- Generation Phase (Decode): Generate one token at a time using cached keys and values.
-
The prompt phase is significantly more expensive because it processes multiple tokens without caching, while the decode phase uses KV caching for each new token.
Prefilling Logic
-
Given a prompt of \(n\) tokens:
- The model performs a full forward pass to compute attention outputs for all \(n\) positions.
-
During this, it initializes the KV cache tensors:
\[K_{1:n}, V_{1:n}\] - These are used in all subsequent generation steps to avoid recomputation.
Optimizations
- Fused Prefill Kernels: Libraries like FasterTransformer use specialized kernels to batch and prefill KV caches in a single efficient pass.
- Prompt Sharing: If multiple requests use the same prompt (e.g., “You are a helpful assistant…”), cache the prefilled results and reuse them across requests.
- Layer-Wise Streaming: Some implementations stream KV cache population layer-by-layer to overlap computation and memory operations.
Real-World Use
In production systems:
- Prompt prefill is often the dominant source of latency, especially with long prompts (e.g., 1k+ tokens).
- Prefilling is not cacheable unless the prompt is reused. That’s where prompt caching (next section) comes in.
- Systems may delay decoding until all requests in a batch complete their prefill phase.
Performance Benefits
- Avoids redundant computation across decoding steps.
- Enables efficient reuse of memory and attention context.
- Critical for long-context inference and multi-user serving.
Prompt Caching
-
Prompt caching is an inference-time optimization that reuses the computed key-value (KV) attention states for frequently occurring or repeated prompt tokens. It eliminates the need to recompute the prompt phase of autoregressive decoding, which is typically the most computationally expensive part of the inference pipeline for long prompts.
-
This technique is especially effective in systems with repeated system messages, user templates, or static few-shot examples.
Motivation
- During autoregressive generation, transformer models process the prompt (or context) once to initialize the attention cache. For a prompt of length \(n\), this involves a full forward pass through all transformer layers to compute the KV tensors:
\[K_{1:n}, \quad V_{1:n}\]
-
This prefill step is expensive and must be repeated for every new request — even if the prompt is the same.
-
Observation: Many applications use identical or highly similar prompts repeatedly. For example:
- Instructional prompts like: “You are a helpful assistant.”
- Few-shot templates in customer support bots.
- System prompts in chat APIs.
-
Prompt caching avoids repeated prefill for these common contexts.
Basic Mechanism
-
Cache Initialization:
-
Compute and store KV tensors for a given prompt:
\[\text{KV}_{\text{prompt}} = f(\text{prompt})\] -
Store in memory or disk with a unique key (e.g., hash of token IDs).
-
-
Cache Lookup:
- For each incoming request, compute a cache key from its prompt.
- If a match is found, retrieve KV tensors instead of recomputing them.
-
Continue Decoding:
-
Begin token-by-token generation using the cached KV state:
\[\text{Generate}(x_{n+1} \mid \text{KV}_{\text{prompt}})\]
-
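- A minimal in-memory sketch of this lookup (the `compute_prompt_kv` callable below is a hypothetical stand-in for the model's prefill pass):

```python
import hashlib

prompt_cache = {}  # cache key -> prefilled KV state

def cache_key(token_ids):
    # Hash the tokenized prompt to form a stable cache key.
    return hashlib.sha256(str(token_ids).encode("utf-8")).hexdigest()

def get_prompt_kv(token_ids, compute_prompt_kv):
    key = cache_key(token_ids)
    if key not in prompt_cache:                 # cache miss: run the prefill pass once
        prompt_cache[key] = compute_prompt_kv(token_ids)
    return prompt_cache[key]                    # cache hit: skip prefill entirely

# Usage: kv = get_prompt_kv(system_prompt_ids, model_prefill_fn); then decode from kv.
```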
Implementation Details
Cache Granularity
- Full Prompt Cache: Caches the entire KV cache of a prompt. Simple and effective but can use a lot of memory.
- Prefix Sharing: If prompts differ by suffix (e.g., `Prompt A + User 1` and `Prompt A + User 2`), share the KV prefix and compute only the delta.
- Subgraph Caching: In more advanced systems, only the first few layers or tokens may be cached.
Cache Storage
- In-Memory KV Cache: For maximum performance, use GPU or CPU memory with LRU eviction.
- On-Disk Cache: Slower but scalable for cold-start scenarios.
- Keyed by Hash: Tokenized input is hashed using SHA or CRC to form a cache key. Some systems normalize prompts before hashing.
Integration with Serving Systems
- Requires cache-aware batch scheduling.
- Works best when integrated with dynamic batching and token-level schedulers (e.g., vLLM).
- May include cache warming: preloading common prompts at system startup.
Performance Impact
-
Let:
- \(T_p\) = time to prefill prompt
- \(T_d\) = time per token for decode
-
For long prompts (e.g., 1000+ tokens), \(T_p \gg T_d\), so caching the prefill can save 80–95% of per-request compute for repeated prompts.
Applications
- Chat APIs: System messages or few-shot exemplars remain fixed across turns.
- Agent Frameworks: Tools like LangChain often replay the same template structure.
- Batch Inference: Multi-user prompts often share context headers (e.g., “Summarize the following…”).
Limitations
- Prompt cache is only useful for identical or prefix-matching prompts.
- Memory usage scales with prompt length and cache size.
- May add overhead for hash computation or miss rate handling.
- Not helpful for fully dynamic, unique user inputs.
Early Exit and Token Pruning
-
Early exit and token pruning are inference-time optimizations designed to reduce computation in large transformer models by selectively skipping or trimming parts of the computation graph that have diminishing contribution to the final output. These methods exploit redundancy in token representations and layer-wise stability in transformer models.
-
Both techniques aim to speed up inference without significantly affecting model output quality, making them valuable in latency-sensitive or resource-constrained applications.
Early Exit
- Early exit allows the model to stop processing certain tokens or even entire sequences at intermediate layers if the model’s confidence in the prediction is already high.
Motivation
- Transformer models use a fixed number of layers (e.g., 24 or 96), but not all tokens require the full depth to make a confident prediction. For example, easily classifiable tokens (like punctuation or common stopwords) may converge earlier than rare or ambiguous tokens.
Mechanism
- At each transformer layer \(l\), evaluate a confidence metric based on the current token representation:
-
Entropy-Based Confidence:
- Compute the softmax output \(p^{(l)}\) from the current logits.
-
Compute the entropy:
\[H(p^{(l)}) = - \sum_i p_i^{(l)} \log p_i^{(l)}\]
- If the entropy falls below a threshold, consider the prediction confident enough to exit (a sketch of this check appears after this list).
-
Cosine Similarity to Previous Layer:
- If the representation at layer \(l\) is similar to that at layer \(l-1\), the token may have converged.
-
Learned Gates:
- Add a small classification head to each layer to learn exit decisions during training (as in BranchyNet or LayerDrop approaches).
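- To tie the entropy criterion and the learned-gate idea together, here is a minimal PyTorch-style sketch. The `exit_head` stands in for a per-layer classifier head, and the 0.5 threshold is an arbitrary placeholder that would need tuning.

```python
import torch
import torch.nn.functional as F

def should_exit(hidden_state: torch.Tensor,
                exit_head: torch.nn.Linear,
                entropy_threshold: float = 0.5) -> bool:
    """Return True if the layer-l prediction is confident enough to stop early.

    hidden_state: (hidden_dim,) token representation at layer l.
    exit_head:    small per-layer classifier head (assumed trained jointly with the model).
    """
    logits = exit_head(hidden_state)                      # (num_classes,)
    probs = F.softmax(logits, dim=-1)
    entropy = -(probs * torch.log(probs + 1e-12)).sum()   # H(p^(l))
    return entropy.item() < entropy_threshold

# In the forward pass, call should_exit(...) after each layer and return the
# intermediate prediction as soon as it is True; otherwise continue to the next layer.
```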
Implementation
- Models like DeeBERT implement classifier heads at multiple depths to enable early exit in BERT.
- Hugging Face `transformers` has prototype support for early exit in sequence classification.
- Requires threshold tuning to balance accuracy and latency.
Benefits
- Reduces average inference depth (e.g., from 24 layers to 12–16 for many tokens).
- Saves computation for simpler or high-confidence examples.
- Ideal for classification or QA tasks where tokenwise prediction is not necessary.
Limitations
- Adds overhead from confidence computation at intermediate layers.
- Not widely adopted in generation tasks due to sequential dependencies between tokens.
Token Pruning
- Token pruning reduces the number of tokens that are propagated through the deeper layers of a transformer by identifying and removing tokens with low contextual importance.
Motivation
-
In many attention-based computations, some tokens contribute very little to the output, for example padding tokens or tokens with low attention weights to and from the rest of the sequence.
-
Pruning these tokens saves compute in later layers, especially in long-context models or batch scenarios.
Mechanism
-
Attention-Based Pruning:
-
Compute the attention score variance or the total attention mass a token receives from all query positions \(j\):
\[\alpha_i = \sum_{j} \text{Attention}(x_j, x_i)\]
- Prune tokens with low total attention received or given (the attention-based variant is sketched after this list).
-
-
Top-k Token Selection:
- Keep only the top-k most important tokens per head or per sequence based on learned importance scores.
-
Dynamic Thresholding:
- Use learned or rule-based thresholds to drop tokens whose impact is below a tunable cutoff.
-
Progressive Pruning:
- Start with full tokens, and prune more aggressively as layers go deeper.
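- Here is the sketch of the attention-based variant referenced above. It is illustrative only and assumes the attention probabilities for the current layer are available as a `(num_heads, seq_len, seq_len)` tensor.

```python
import torch

def prune_tokens(hidden: torch.Tensor, attn_weights: torch.Tensor, keep: int):
    """Keep the `keep` tokens that receive the most attention mass.

    hidden:       (seq_len, hidden_dim) token representations entering the next layer.
    attn_weights: (num_heads, seq_len, seq_len) attention probabilities, where
                  attn_weights[h, j, i] is how much query j attends to key i.
    Returns the pruned hidden states plus the kept indices (needed to map outputs back).
    """
    received = attn_weights.sum(dim=(0, 1))        # total attention mass received by each token i
    kept_idx = torch.topk(received, k=keep).indices
    kept_idx, _ = torch.sort(kept_idx)             # preserve the original token order
    return hidden[kept_idx], kept_idx
```

- Keeping `kept_idx` around is what allows outputs to be mapped back to the original positions, as noted under Implementation below.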
Implementation
- Typically done at attention module boundaries.
- Can be combined with sparse attention mechanisms.
- Token indices need to be tracked to reconstruct output or map back to the original sequence.
Benefits
- Reduces computation in deeper layers, especially for long sequences.
- Improves throughput with minimal impact on quality in summarization, QA, and retrieval tasks.
- Can be applied during training for alignment with inference.
Limitations
- May degrade quality if pruning is too aggressive or incorrectly calibrated.
- Requires complex index tracking and masking logic.
- Harder to apply in autoregressive settings where all tokens are sequentially dependent.
Tools and Research
- DeLighT, LayerDrop, and EarlyBERT for early exit variants.
- SparseFormer, Synthesizer, and Longformer introduce related token reduction ideas.
- Hugging Face and NVIDIA’s Megatron support token pruning hooks in research branches.
Hardware-Aware Scheduling
-
Hardware-aware scheduling refers to a set of optimization strategies that tailor the execution of neural network inference to the specific architecture and performance characteristics of the underlying hardware—such as GPUs, TPUs, or specialized accelerators. These optimizations aim to improve compute throughput, memory utilization, and latency by orchestrating how and when operations are executed.
-
This is especially important for transformer inference, where workloads are large, heterogeneous (e.g., KV cache lookups, matrix multiplies, normalization), and sensitive to memory bandwidth and parallelism.
Motivation
-
Transformer inference involves many stages of computation and memory access:
- Matrix multiplications (GEMMs) in attention and feed-forward blocks.
- Data movement between layers and devices.
- KV cache management and resizing.
- Softmax, activation, and normalization operations.
-
Without careful scheduling, bottlenecks can emerge due to:
- Underutilized compute units (e.g., Tensor Cores).
- Memory stalls and cache thrashing.
- Synchronization overhead between layers or streams.
-
Hardware-aware scheduling optimizes these execution flows to keep the pipeline full and latency low.
Core Techniques
Stream Parallelism
-
Modern GPUs support multiple concurrent execution streams (e.g., via CUDA). In transformer inference:
- Use separate CUDA streams for different model stages (e.g., one for KV cache update, one for GEMM).
- Overlap memory copies (e.g., `cudaMemcpyAsync`) with compute to hide latency.
-
Example:

```cpp
// Enqueue an async copy on stream1 and a GEMM on stream2; the two overlap.
cudaMemcpyAsync(..., stream1);
cublasSetStream(handle, stream2);  // GEMMs issued on this handle now run on stream2
cublasGemmEx(handle, ...);         // runs concurrently with the copy on stream1
```
Tensor Core Utilization
-
Tensor cores are specialized units in NVIDIA GPUs for low-precision matrix ops (e.g., FP16, BF16, INT8). To maximize their usage:
- Ensure all matrix multiplications are aligned to multiple-of-8 dimensions.
- Use fused kernels to eliminate intermediate FP32 conversions.
- Prefer mixed-precision pipelines (AMP / FP16) for higher throughput.
-
Libraries like cuBLAS, FlashAttention, and TensorRT handle these optimizations automatically when configured correctly.
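- As a concrete, simplified illustration of the mixed-precision point above, PyTorch's `torch.autocast` routes eligible matmuls to FP16 and hence to tensor cores on supported GPUs; the linear layer below is just a stand-in for a transformer sub-block.

```python
import torch

model = torch.nn.Linear(4096, 4096).cuda().eval()   # stand-in for a transformer sub-block
x = torch.randn(8, 4096, device="cuda")

with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.float16):
    # Eligible matmuls inside this context run in FP16 on tensor cores where supported;
    # dimensions that are multiples of 8 (here 4096) help tensor-core utilization.
    y = model(x)
```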
Operator Placement and Reordering
-
Efficient inference scheduling may involve reordering or co-locating operations based on:
- Memory locality: Fuse or group operations that share data.
- Execution time: Prioritize long-running ops earlier in the pipeline.
- Device affinity: Keep frequently accessed data on the same GPU or chip.
-
Example: Run attention blocks first in a multi-layer transformer if they dominate compute time, allowing FFN weights to be prefetched concurrently.
KV Cache Management
-
Efficient KV cache handling is essential in decoder models:
- Paged KV Cache: Used in systems like vLLM; stores KV tensors in fixed-size pages (blocks) that need not be contiguous, enabling fine-grained allocation, reuse, and random-access updates.
- Memory Pools: Preallocate KV buffers for each request and reuse them to avoid memory fragmentation.
- Lazy Allocation: Delay cache instantiation until first generation step to save memory for short prompts.
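- A minimal sketch of the memory-pool idea: preallocate KV buffers up to a maximum sequence length and hand slots out per request, instead of reallocating tensors at every step. Shapes, dtype, and the slot bookkeeping are illustrative assumptions; paged designs such as vLLM's are considerably more involved.

```python
import torch

class KVPool:
    """Preallocated per-request KV buffers to avoid per-step reallocation and fragmentation."""

    def __init__(self, num_slots, num_layers, num_heads, head_dim, max_len, dtype=torch.float16):
        # One large buffer: (slots, layers, 2 [K and V], heads, max_len, head_dim).
        self.buf = torch.empty(num_slots, num_layers, 2, num_heads, max_len, head_dim,
                               device="cuda", dtype=dtype)
        self.free = list(range(num_slots))
        self.lengths = [0] * num_slots

    def acquire(self) -> int:
        slot = self.free.pop()               # reuse a preallocated slot for a new request
        self.lengths[slot] = 0
        return slot

    def append(self, slot: int, layer: int, k: torch.Tensor, v: torch.Tensor) -> None:
        t = self.lengths[slot]               # k, v: (num_heads, head_dim) for one new token
        self.buf[slot, layer, 0, :, t] = k
        self.buf[slot, layer, 1, :, t] = v
        if layer == self.buf.shape[1] - 1:   # advance the position after the last layer writes
            self.lengths[slot] += 1

    def release(self, slot: int) -> None:
        self.free.append(slot)               # return the slot to the pool when the request ends
```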
Pipeline and Model Parallelism
-
In large-model deployments:
- Pipeline Parallelism: Distribute transformer layers across devices. Stage execution overlaps compute and communication.
- Tensor Parallelism: Split individual tensor dimensions (e.g., weights) across devices for large GEMMs.
-
Combined, these allow serving models with billions of parameters across multiple GPUs efficiently.
Custom Kernel Scheduling
-
Frameworks like Triton and TVM allow defining and tuning custom kernels:
- Auto-tune tiling sizes and shared memory usage.
- Schedule GPU threads based on warp/block-level parallelism.
- Implement custom token-wise or layer-wise scheduling logic.
Cache and Memory Prefetching
- Use `__prefetch` instructions or async loads to bring data into shared memory before it is needed.
- Overlap KV fetches with matmul execution to hide memory latency.
Deployment-Aware Strategies
- Load Balancing: Use dynamic batching queues with GPU-aware request routing (e.g., based on latency or memory pressure).
- Thread Affinity: Bind computation to specific CPU cores or NUMA zones in CPU-bound systems.
- Execution Profiling: Use profilers like NVIDIA Nsight Systems or PyTorch Profiler to tune for bottlenecks.
Ecosystem Support
- NVIDIA TensorRT and FasterTransformer: Hardware-aware fused kernels and scheduling policies.
- ONNX Runtime (ORT): Execution providers tuned for different hardware (CUDA, DirectML, TensorRT).
- DeepSpeed, vLLM, Triton, and TVM: Offer fine-grained control over scheduling and memory layout.
Performance Impact
-
Hardware-aware scheduling can yield:
- 1.5\(\times\)–4\(\times\) speedup over naive scheduling for long sequences or large batches.
- Better multi-GPU scaling for high-throughput inference.
- Lower latency variability in real-time serving environments.
Comparative Analysis
Technique | Purpose | Key Benefits | Primary Use Cases | Implementation Notes |
---|---|---|---|---|
KV Caching | Reuse attention keys/values from previous tokens | Reduces per-token latency after first step | Autoregressive decoding (GPT, LLaMA) | Requires careful cache management; starts from second token onward |
Model Quantization | Use lower-precision weights/activations | Reduces memory and compute cost | Edge inference, high-throughput serving | INT8/PTQ for speed; QAT for better accuracy; needs hardware with quantization support |
Operator Fusion | Combine adjacent ops into single kernel | Reduces memory access and kernel launch overhead | Attention blocks, FFNs, LayerNorm + activation | Use graph compilers (XLA, TorchScript), or fused CUDA kernels (TensorRT, FasterTransformer) |
Speculative Decoding | Use draft model to guess multiple tokens | Reduces number of full-model forward passes | Long-form generation, chatbots | Needs a lightweight auxiliary model; uses top-1 match or log-prob threshold for validation |
FlashAttention & Kernels | Memory-efficient attention computation | Reduces memory usage and boosts speed | Long-sequence LLMs, multi-head attention | Implemented with CUDA (FlashAttention), or Triton/xFormers; avoids storing full attention matrix |
Batching | Process multiple requests together | Increases throughput and GPU utilization | High-concurrency inference (API servers, batch jobs) | Dynamic and token-level batching supported in vLLM, DeepSpeed, TensorRT |
Prefilling | Precompute KV cache from prompt tokens | Avoids recomputation in autoregressive models | Chat and generation tasks with long prompts | Often paired with batching; prompt KV cache initialized before decoding begins |
Prompt Caching | Cache KV states of repeated prompts | Saves time and compute on repeated static contexts | Chat APIs, few-shot prompt templates | Requires hashing/tokenizing prompt and storing cache; memory usage grows with cache diversity |
Early Exit | Stop processing tokens/layers early based on confidence | Reduces per-token compute in deep models | Classification, QA tasks | Needs entropy or learned gating logic; difficult to apply in token-dependent generation |
Token Pruning | Discard low-importance tokens during inference | Reduces sequence length in deeper layers | Long-sequence summarization, QA | Attention-based importance scoring; careful masking and index tracking required |
Hardware-Aware Scheduling | Optimize kernel execution for specific hardware | Maximizes throughput and minimizes latency | All transformer-based workloads | Includes stream parallelism, memory prefetch, cache layout, tensor core tuning, and multi-GPU distribution |
References
- KV Caching Explained: Optimizing Transformer Inference Efficiency
- Gaurav’s Blog – Efficient AI: KV Caching and KV Sharing
- Let’s build GPT: from scratch, in code, spelled out
Citation
If you found our work useful, please cite it as:
@article{Chadha2020DistilledModelAcceleration,
title = {Model Acceleration},
author = {Chadha, Aman},
journal = {Distilled AI},
year = {2020},
note = {\url{https://aman.ai}}
}