Overview

  • Deploying machine learning models directly on edge devices such as smartphones, tablets, or embedded systems offers key advantages in privacy, latency, and user experience. On-device processing ensures that data remains local to the device, significantly reducing the risk of exposure from data transmission or centralized storage breaches. This is particularly critical for applications in conversational AI and NLP, where interactions often involve personal or sensitive information.

  • However, the computational and memory demands of modern machine learning models—especially large language models—pose a major barrier to efficient on-device deployment. These models are typically too large and resource-intensive to run in real-time on devices with limited hardware capabilities. As models continue to scale in size and complexity, challenges such as increased latency, power consumption, and hardware constraints become even more pronounced.

  • To address these limitations, a wide array of model compression and optimization techniques has been developed. These include model quantization (static, dynamic, and quantization-aware training), structured and unstructured pruning, knowledge distillation (response-based, feature-based, relation-based), low-rank decomposition, activation-aware quantization, operator fusion, and mixed precision training. Advanced post-training quantization methods such as AWQ, SmoothQuant, SpinQuant, AWEQ, and FPTQuant further push the boundaries of efficient model deployment.

  • Collectively, these techniques aim to reduce model size, lower computational complexity, and accelerate inference—without significantly compromising accuracy. By enabling smaller, faster, and more power-efficient models, these strategies make it increasingly feasible to run advanced AI applications directly on user devices, supporting both better privacy and smoother user experiences.

Model Compression Techniques

  • Modern generative models often contain between 100 billion and 1 trillion parameters. Since each parameter (as a 32-bit float) consumes 4 bytes, the memory footprint ranges from roughly 400 GB to 4 TB. This makes them prohibitively large for deployment on edge devices, where memory and compute resources are highly constrained. To enable on-device AI, a wide range of model compression techniques has been developed. Below, we visually and conceptually summarize the five core strategies widely used in industry and research.

  • Model Pruning
    • Pruning reduces model size by removing weights that have minimal impact on overall performance.

      • Unstructured pruning eliminates individual weights based on their magnitude (e.g., L1/L2 norm) or gradient impact (e.g., first or second derivative of the loss).
      • Structured pruning goes further by removing entire neurons, filters, or layers, which can directly lead to faster inference on hardware accelerators.
      • In practice, pruning is often iterative: prune, retrain, evaluate, and repeat to recover performance loss.
  • Model Quantization
    • Quantization reduces the precision of weights and activations, typically from 32-bit floats to 8-bit integers, yielding up to a 4× reduction in model size and significant speed-ups using optimized kernels.

      • Post-training quantization applies precision reduction directly on a trained model and may include heuristic corrections (e.g., bias correction, per-channel scaling).
      • Quantization-aware training (QAT) simulates quantization noise during training, allowing the model to adapt and maintain higher accuracy under reduced precision.
      • Advanced quantization strategies like AWQ, SmoothQuant, and AWEQ refine post-training quantization by adjusting scaling factors or reweighting attention layers.
  • Low-Rank Decomposition
    • Many neural networks, especially transformer-based architectures, contain large weight matrices that can be approximated as a product of smaller matrices.

      • For example, an \(N \times N\) matrix can often be approximated by the product of an \(N \times k\) and a \(k \times N\) matrix (with \(k \ll N\)), reducing space complexity from \(O(N^2)\) to \(O(Nk)\).
      • Methods like SVD (singular value decomposition) or CP decomposition are used here.
      • In practice, fine-tuning the decomposed model is essential to restore original performance (a minimal SVD sketch is shown below, after this section's figure).
  • Knowledge Distillation
    • Knowledge distillation trains a compact student model to mimic a larger teacher model.

      • Response-based distillation focuses on matching the output logits or probabilities.
      • Feature-based distillation aligns intermediate representations between teacher and student.
      • Relation-based distillation preserves inter-feature dependencies (e.g., attention maps).
      • Distillation is often combined with other techniques (e.g., quantization or pruning) to maximize efficiency.
  • Lightweight Model Design
    • Rather than compressing an existing model, lightweight design focuses on creating efficient architectures from the ground up.

      • Examples include MobileBERT, DistilBERT, TinyBERT, and ConvNeXt-T, which use smaller embedding sizes, depth-wise separable convolutions, fewer transformer blocks, or other architectural efficiencies.
      • Empirical design choices are often guided by NAS (Neural Architecture Search) or latency-aware loss functions.
  • The illustration below (source) summarizes how these different compression methods contribute to reducing model size and enabling efficient deployment across platforms, including on-device scenarios.

Model Compression Techniques
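
  • To make the low-rank decomposition idea concrete, here is a minimal sketch (hypothetical layer sizes and an arbitrarily chosen rank \(k\)) that factorizes a linear layer's weight matrix with truncated SVD and replaces it with two smaller linear layers:

import torch
import torch.nn as nn

# Original layer: a 1024x1024 weight matrix (~1M parameters)
layer = nn.Linear(1024, 1024, bias=False)

# Truncated SVD: W ≈ (U_k · diag(S_k)) @ Vh_k, keeping only the top-k singular values
k = 64
U, S, Vh = torch.linalg.svd(layer.weight.data, full_matrices=False)
A = U[:, :k] * S[:k]    # shape: 1024 x k
B = Vh[:k, :]           # shape: k x 1024

# Replace the single layer with two smaller ones: ~2*N*k parameters instead of N^2
low_rank = nn.Sequential(
    nn.Linear(1024, k, bias=False),  # applies B
    nn.Linear(k, 1024, bias=False),  # applies A
)
low_rank[0].weight.data = B
low_rank[1].weight.data = A

x = torch.randn(8, 1024)
approx = low_rank(x)    # approximates layer(x); fine-tuning then recovers lost accuracy

  • With \(k = 64\), the two factorized layers hold roughly 131K parameters versus about 1M in the original layer—an ~8× reduction before fine-tuning.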

Quantization

Background: Precision

  • Before we talk about quantization, let’s learn about precision. Per the IEEE 754 floating point specification, floating point formats vary by bit-width and how they allocate bits to represent sign, exponent, and mantissa (also called significand, which contains the significant digits or fractional precision). For example:

    • Double-precision uses 64 bits
    • Single-precision uses 32 bits
    • Half-precision uses 16 bits

IEEE 754 Floating Point Standard

  • The IEEE 754 standard defines the binary representation of floating point numbers used in nearly all modern hardware and programming environments. A floating-point number is composed of three parts:

    • Sign bit (\(S\)): 1 bit indicating positive or negative.
    • Exponent (\(E\)): Encodes the range (scale) of the number.
    • Mantissa or significand (\(M\)): Encodes the precision.
  • The general representation for a binary floating-point number is:

\[\text{value} = (-1)^S \times 1.M \times 2^{E - \text{bias}}\]
  • Each format—half, single, and double—allocates a different number of bits to these components, balancing precision and range against memory usage and compute requirements.
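
  • As a concrete illustration, the short Python sketch below (single precision only; the other formats differ just in field widths and bias) unpacks a float32 into its sign, exponent, and mantissa fields and reconstructs the value from them:

import struct

def decode_float32(x: float):
    """Split a float32 into its IEEE 754 sign, exponent, and mantissa fields."""
    bits = struct.unpack('>I', struct.pack('>f', x))[0]
    sign = bits >> 31                  # 1 sign bit
    exponent = (bits >> 23) & 0xFF     # 8 exponent bits, bias = 127
    mantissa = bits & 0x7FFFFF         # 23 mantissa bits
    # Reconstruct a normal number: (-1)^S * 1.M * 2^(E - 127)
    value = (-1) ** sign * (1 + mantissa / 2**23) * 2.0 ** (exponent - 127)
    return sign, exponent, mantissa, value

print(decode_float32(-6.5))  # -6.5 = -1.625 * 2^2, so the exponent field is 127 + 2 = 129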

  • The following figure (source) shows the IEEE 754 standard formats for floating-point numbers, illustrating the bitwise layout of the signed bit, exponent, and significand across double (64-bit), single (32-bit), and half (16-bit) precision representations.

Half-Precision (float16)

  • Bit layout: 1 sign bit, 5 exponent bits, 10 mantissa/significand bits.
  • Total bits: 16
  • Exponent bias: 15
  • Dynamic range: Approximately \(6 \times 10^{-5}\) to \(6.5 \times 10^4\)
\[\text{value}_{\text{float16}} = (-1)^S \times 1.M_{10} \times 2^{E - 15}\]
  • Half-precision is mostly used during inference, especially in low-power or memory-constrained environments such as mobile devices or embedded hardware. It offers limited range and precision, and is generally not suitable for training deep networks unless special care is taken (e.g., using loss scaling or float32 accumulations).

  • GPU Considerations:

    • Many GPUs, such as NVIDIA’s Volta, Turing, Ampere, Hopper, and Blackwell architectures, include specialized hardware units called Tensor Cores optimized for float16 operations.
    • float16 can be processed at higher throughput than float32, enabling significant speedups for matrix multiplications during inference.
    • Often paired with Mixed-Precision Training (MPT), where activations and weights are stored in float16, but gradients are accumulated in float32.

Single-Precision (float32)

  • Bit layout: 1 sign bit, 8 exponent bits, 23 mantissa bits.
  • Total bits: 32
  • Exponent bias: 127
  • Dynamic range: Approximately \(1.4 \times 10^{-45}\) to \(3.4 \times 10^{38}\)
\[\text{value}_{\text{float32}} = (-1)^S \times 1.M_{23} \times 2^{E - 127}\]
  • float32 is the default numerical format for most deep learning frameworks and hardware. It provides a good balance between precision and range, making it robust for both training and inference.

  • GPU Considerations:

    • Supported natively on all modern GPUs.
    • Most general-purpose ALUs (Arithmetic Logic Units) on the GPU are designed to process float32 efficiently.
    • Slower than float16 or bfloat16 in terms of throughput and power usage but more accurate.

Double-Precision (float64)

  • Bit layout: 1 sign bit, 11 exponent bits, 52 mantissa bits.
  • Total bits: 64
  • Exponent bias: 1023
  • Dynamic range: Approximately \(5 \times 10^{-324}\) to \(1.8 \times 10^{308}\)
\[\text{value}_{\text{float64}} = (-1)^S \times 1.M_{52} \times 2^{E - 1023}\]
  • float64 is typically used in scientific computing, numerical simulations, and applications requiring high precision. It is rarely used in deep learning because its benefits are minimal for most ML tasks, while the compute and memory costs are high.

Brain Floating Point (bfloat16)

  • Bit layout: 1 sign bit, 8 exponent bits, 7 mantissa bits
  • Total bits: 16
  • Exponent bias: 127 (same as float32)
  • Dynamic range: Approximately \(1.2 \times 10^{-38}\) to \(3.4 \times 10^{38}\)
\[\text{value}_{\text{bfloat16}} = (-1)^S \times 1.M_{7} \times 2^{E - 127}\]
  • bfloat16 (Brain Floating Point 16) was introduced by Google for training deep neural networks. Unlike float16, which reduces both exponent and mantissa bits, bfloat16 keeps the same exponent width as float32 (8 bits) but reduces the mantissa to 7 bits.

  • This design retains the dynamic range of float32, which makes it far more robust to underflow/overflow issues during training compared to float16.

  • However, the precision is reduced, since fewer mantissa bits mean fewer significant digits are preserved. Despite this, it performs well in practice for training large models.

  • bfloat16 is ideal for training tasks where:

    • High dynamic range is important
    • Some loss of precision can be tolerated (e.g., in early layers or gradients)
    • Lower memory and compute overhead is desired compared to float32
  • It is commonly used in mixed-precision training, often with accumulations in float32 to improve numerical stability.

  • Use cases:

    • Large-scale model training (e.g., LLMs)
    • TPUs (Google Cloud), and newer GPUs from NVIDIA, AMD, and Intel that support native bfloat16 ops

GPU/TPU Considerations

  • float16:

    • Supported by specialized hardware units (Tensor Cores) in NVIDIA Volta, Turing, Ampere, Hopper, and Blackwell architectures.
    • Offers higher throughput and lower memory usage, especially during inference.
    • Typically used in mixed-precision training with float32 accumulations to maintain stability.
  • bfloat16:

    • Natively supported on Google TPUs, NVIDIA Ampere and newer (e.g., A100, H100), Intel Habana Gaudi accelerators, and select AMD GPUs.
    • Enables high dynamic range similar to float32 while halving memory usage.
    • Increasingly adopted in training large models where float16 may encounter stability issues.
    • Like float16, bfloat16 is also used in mixed-precision training pipelines.
  • float32:

    • Universally supported across all GPU architectures.
    • Offers the best balance between range and precision.
    • Slower and more memory-intensive compared to float16 and bfloat16.
  • float64:

    • Rare in deep learning; primarily used in scientific computing.
    • Most GPUs support it at much lower throughput.
    • Often omitted entirely from inference workloads due to cost.

Comparative Summary

| Format | Bits | Exponent Bits | Mantissa Bits | Bias | Range (Approx.) | GPU Usage |
|---|---|---|---|---|---|---|
| float16 | 16 | 5 | 10 | 15 | \(10^{-5}\) to \(10^{4}\) | Fast inference, mixed precision |
| bfloat16 | 16 | 8 | 7 | 127 | \(10^{-38}\) to \(10^{38}\) | Training + inference, high range |
| float32 | 32 | 8 | 23 | 127 | \(10^{-45}\) to \(10^{38}\) | Default for training |
| float64 | 64 | 11 | 52 | 1023 | \(10^{-324}\) to \(10^{308}\) | Rare in ML, slow on GPU |
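
  • As a quick, runnable sanity check on the table above (sizes only, no accuracy claims), the snippet below stores the same tensor in float32, float16, and bfloat16 and reports the memory each copy occupies:

import torch

x32 = torch.randn(1024, 1024)        # float32 by default: 4 bytes per element
x16 = x32.to(torch.float16)          # 2 bytes per element, reduced range and precision
xb16 = x32.to(torch.bfloat16)        # 2 bytes per element, float32-like range

for t in (x32, x16, xb16):
    mb = t.element_size() * t.nelement() / 1e6
    print(f"{str(t.dtype):15s} {mb:.1f} MB")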
  • This foundation in floating point formats prepares us to understand quantization—where bit-widths are reduced even further (e.g., 8-bit, 4-bit, or binary)—to achieve efficient computation with minimal loss in model performance.

Background: Matrix Multiplication in GPUs

  • Efficient matrix multiplication is at the heart of modern deep learning acceleration on GPUs. This section provides a high-level view of how matrix-matrix multiplications are implemented and optimized on GPU hardware, with special focus on tiling, Tensor Cores, and the performance implications of quantization.

  • Matrix multiplications, especially General Matrix Multiplications (GEMMs), are a core computational primitive in deep learning workloads. Whether in fully connected layers, convolutions (via im2col), or attention mechanisms, these operations are executed billions of times during training and inference. As such, optimizing GEMM performance is essential for efficient neural network execution, particularly on GPUs.

  • To execute GEMMs efficiently, GPUs partition the output matrix into tiles. Each tile corresponds to a submatrix of the result and is computed by a thread block. The GPU steps through the input matrices along the shared dimension (K) in tiles, performing multiply-accumulate operations and writing the results into the corresponding tile of the output matrix. The illustration below (source) shows the tiled outer product approach to GEMMs.

Tiled outer product approach to GEMMs

  • Thread blocks are mapped to the GPU’s streaming multiprocessors (SMs), the fundamental compute units that execute instructions in parallel. Each SM can process one or more thread blocks concurrently, depending on the available resources and occupancy.

  • Performance in GPU matrix multiplication is often bounded by one of two factors:

    • Compute (math) bound: When the arithmetic intensity (FLOPs per byte) is high enough that math operations dominate runtime.
    • Memory bound: When the operation requires frequent memory access compared to math operations, limiting throughput.
  • Whether a given GEMM is compute- or memory-bound depends on the matrix dimensions (\(M, N, K\)) and the hardware’s characteristics. For example, matrix-vector products (where either \(M\) = 1 or \(N\) = 1) are typically memory-bound due to their low arithmetic intensity.

  • Modern NVIDIA GPUs include specialized hardware units called Tensor Cores, which are designed to accelerate GEMMs involving low-precision data types such as float16, bfloat16, and int8. Tensor Cores perform small matrix multiplications in parallel and require that the matrices’ dimensions align with certain multiples (e.g., 8 for float16, 16 for int8) to achieve peak performance. For instance, on Ampere and newer architectures like Hopper or Blackwell, aligning dimensions to larger multiples (e.g., 64 or 128 elements) often yields even better throughput.

  • Matrix dimensions that are not aligned to tile sizes lead to tile quantization, where some tiles carry less useful work, reducing efficiency. Similarly, if the total number of tiles is not an even multiple of the number of GPU SMs, wave quantization can cause underutilization. Both effects can significantly degrade performance despite identical algorithmic complexity.

  • To address this, libraries like cuBLAS employ heuristics or benchmarking to select optimal tile sizes, balancing between tile reuse (large tiles) and parallelism (many small tiles). Larger tiles tend to be more efficient due to better data reuse, but may reduce parallel occupancy on smaller problems.

  • In summary, matrix multiplication performance on GPUs is a delicate balance between compute, memory bandwidth, and architecture-aware tiling strategies. Quantization not only affects data representation but also interacts intricately with the underlying matrix multiplication engine and GPU efficiency.
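
  • As a rough, back-of-the-envelope sketch of the compute- vs. memory-bound distinction above (it ignores caching, accumulation precision, and tile effects), the arithmetic intensity of a GEMM can be estimated as FLOPs divided by bytes moved:

def gemm_arithmetic_intensity(M, N, K, bytes_per_element=2):
    """Estimate FLOPs per byte for C = A @ B with A: MxK, B: KxN, C: MxN (float16 by default)."""
    flops = 2 * M * N * K                                    # one multiply + one add per term
    bytes_moved = (M * K + K * N + M * N) * bytes_per_element
    return flops / bytes_moved

# Large square GEMM: high arithmetic intensity, typically compute-bound
print(gemm_arithmetic_intensity(4096, 4096, 4096))  # ~1365 FLOPs/byte

# Matrix-vector product (N = 1): low intensity, typically memory-bound
print(gemm_arithmetic_intensity(4096, 1, 4096))     # ~1 FLOP/byte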

Definition

  • Quantization in the context of deep learning refers to the process of reducing the numerical precision of a model’s parameters (weights) and/or intermediate computations (activations). Typically, models are trained and stored using 32-bit floating-point (float32) precision. Quantization replaces these high-precision values with lower-precision representations—such as 16-bit floating point (float16), 8-bit integer (int8), 4-bit integer, or binary formats in more extreme scenarios. The primary goals are to reduce model size, improve memory and compute efficiency, and accelerate inference—particularly on hardware that supports low-precision arithmetic.
  • Quantization works in part because modern neural networks are highly over-parameterized and often robust to small numerical perturbations. With proper calibration and tooling, lower-precision representations can approximate the full-precision model closely enough for practical deployment.

Types of Quantization

  • There are two general categories of quantization:

    • Floating-point quantization: This reduces the bit-width of floating-point values—for example, converting from float32 to float16 or bfloat16. These formats retain the same general IEEE 754 structure (sign, exponent, mantissa) but use fewer bits, reducing precision and dynamic range. This kind of quantization is primarily used for inference on GPUs and accelerators optimized for low-precision floating-point math (e.g., NVIDIA Tensor Cores).

    • Integer quantization: This maps floating-point values to fixed-point integer representations (e.g., int8 or uint8). This type requires an additional transformation using scale and zero-point to linearly approximate real values using integers, enabling integer-only arithmetic during inference on CPUs and certain edge devices.

Integer Quantization

  • Integer quantization is typically implemented as a learned linear transformation (i.e., linear mapping) defined by two parameters: scale and zero-point.

    • The scale is a floating-point multiplier that determines the resolution or step size between adjacent quantized integer values.

    • The zero-point is an integer offset that aligns a real-valued zero to the corresponding integer value in the quantized (i.e., target) range. This allows for asymmetric distributions, where zero is not necessarily centered.

      • The forward quantization formula (float to int) is:
      q = round(x / scale) + zero_point
      
      • The reverse dequantization formula (int to float) is:
      x = scale * (q - zero_point)
      
    • As an example:

      • Suppose we want to quantize floating-point values in the range \([-1.0, 1.0]\) to 8-bit unsigned integers (uint8, range 0–255). The mapping would look like:
      scale = (max - min) / (quant_max - quant_min) = (1.0 - (-1.0)) / (255 - 0) ≈ 0.00784
      zero_point = round(0 - min / scale) = round(0 - (-1.0 / 0.00784)) = 128
      
      • This means that the floating-point value 0.0 maps to 128, -1.0 maps to 0, and 1.0 maps to 255 (after clamping to the uint8 range). Intermediate values are linearly interpolated. This transformation enables low-bit integer operations that approximate floating-point behavior.
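
      • The same mapping can be sketched in a few lines of plain Python (the clamp step handles values that round outside the uint8 range, as 1.0 does here):

      def quantize(x, scale, zero_point, qmin=0, qmax=255):
          q = round(x / scale) + zero_point
          return max(qmin, min(qmax, q))        # clamp to the quantized range

      def dequantize(q, scale, zero_point):
          return scale * (q - zero_point)

      scale = (1.0 - (-1.0)) / (255 - 0)        # ≈ 0.00784
      zero_point = 128

      for x in (-1.0, -0.5, 0.0, 0.5, 1.0):
          q = quantize(x, scale, zero_point)
          print(x, "->", q, "->", round(dequantize(q, scale, zero_point), 4))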

Non-Linear Integer Quantization Methods

  • While linear quantization using scale and zero‑point is the most common, non‑linear quantization methods are also employed in integer quantization to better match real-world data distributions. These methods typically apply to integer quantization, as they redefine how integer values map to real numbers:

    • Logarithmic quantization uses exponentially spaced quantization levels—e.g. powers of two—providing better representation across a wide dynamic range. This method is non‑linear and particularly used for integer-only inference pipelines. It is not relevant for floating‑point quantization, which already uses a non‑uniform exponent-based scale inherently built into its representation.

    • K-means or cluster-based quantization groups floating-point values into clusters, mapping each to its centroid—another non-linear approach for integer quantization or weight sharing schemes.

    • Learned transformations, such as LSQ (Learned Step Size Quantization) and its non-uniform variant nuLSQ, optimize quantization step sizes or level spacing via backpropagation. These methods are applied to integer quantization of weights and activations (e.g., 2‑, 3‑, or 4‑bit integer quantization) and involve non-linear quantizer parameterization.

    • In summary, non-linear quantization techniques are relevant for integer quantization workflows, where they redefine integer mapping to better match value distributions. Floating‑point quantization (e.g., float32 \(\rightarrow\) float16/bfloat16), while structurally non-linear due to its exponent/mantissa hierarchy, does not employ these learned or clustering-based non-linear integer mapping schemes.

Floating-Point Quantization

  • Floating-point quantization is implemented by truncating or rounding the mantissa and exponent fields in the IEEE 754 representation—e.g. conversion from float32 to float16 or bfloat16—preserving the format structure but reducing bit-width. This form of quantization (i.e., bit‑width reduction) is non-linear in effect because the quantization steps vary by exponent range: numbers near zero have finer granularity than large values due to the floating-point exponent scaling.
  • This approach aligns with a well-known model in signal quantization theory often referred to as the compressor–quantizer–expander model. In this framework, the exponential scaling of floating-point numbers acts as a compressor that non-linearly maps real values into a domain where uniform quantization (truncation of mantissa bits) is applied. The quantizer then discretizes the mantissa (hidden quantizer), and the expander step reconstructs the approximate value from the compressed representation. This structure enables efficient representation of a wide dynamic range with relatively coarse quantization, especially benefiting smaller values close to zero.
  • Common APIs include casting methods like model.half() in PyTorch and PyTorch’s support for float16 static quantization configs (e.g. float16_static_qconfig). Floating-point quantization halves memory footprint with minimal accuracy loss on GPU inference platforms.
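
  • A minimal sketch of this bit-width reduction in PyTorch (layer size chosen arbitrarily; the GPU inference step assumes CUDA is available):

import torch
import torch.nn as nn

model = nn.Linear(4096, 4096)                          # float32 parameters by default
fp32_bytes = sum(p.element_size() * p.nelement() for p in model.parameters())

model_fp16 = model.half()                              # cast weights and buffers to float16
fp16_bytes = sum(p.element_size() * p.nelement() for p in model_fp16.parameters())
print(f"float32: {fp32_bytes / 1e6:.1f} MB, float16: {fp16_bytes / 1e6:.1f} MB")  # roughly 2x smaller

if torch.cuda.is_available():
    model_fp16 = model_fp16.cuda()
    x = torch.randn(8, 4096, dtype=torch.float16, device='cuda')
    with torch.no_grad():
        y = model_fp16(x)                              # runs end-to-end in float16 on the GPU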

Dequantization Considerations

  • Dequantization is not always needed during inference, and its necessity depends on the type of quantization and the underlying hardware. In integer-only quantization pipelines—commonly used for inference on mobile CPUs or edge devices—computations are performed entirely in the integer domain (e.g., int8 or uint8), and dequantization is typically only applied at the final stage, such as for logits or output activations. This avoids floating-point operations altogether during inference.
    • However, in hybrid quantization workflows, where some layers are quantized (e.g., to int8) and others remain in higher precision (float32 or float16), intermediate dequantization is required at the layer boundaries to enable compatibility between quantized and non-quantized components. This is common in models that cannot fully tolerate quantization across all layers due to accuracy degradation or unsupported ops.
    • In contrast, when quantizing to lower-precision floating-point formats like float16, dequantization is not needed at all, because these formats are still natively supported by GPU hardware. For example, NVIDIA Tensor Cores are optimized for float16 (and bfloat16) matrix operations, so models using float16 quantization can be executed directly end-to-end without converting back to float32. All computations remain in low-precision floating-point format, maintaining performance while avoiding the complexity of dequantization logic entirely.

Quantization Workflows

  • There are three main workflows/approaches to apply quantization:

    • Dynamic / Runtime Quantization: This method quantizes model weights statically (e.g. to int8 or float16), while activations remain in full precision until runtime, where they are quantized dynamically at each inference step right before computation. It requires no calibration dataset and no fine‑tuning, making it the easiest quantization method provided by PyTorch. It is particularly effective for models dominated by weight‑heavy layers such as torch.nn.Linear, recurrent layers (nn.LSTM, nn.GRU), and transformers. In PyTorch, this is implemented via the function quantize_dynamic, for example:

      import torch
      quantized_model = torch.ao.quantization.quantize_dynamic(
          model_fp32,
          {torch.nn.Linear, torch.nn.LSTM},
          dtype=torch.qint8
      )
      
      • With this approach, the quantized model is memory‑efficient and can accelerate inference for NLP architectures, often with negligible accuracy loss compared to PTQ—though with lower benefit on convolution‑heavy vision models.
      • Look up the Dynamic / Runtime Quantization section for a detailed discourse on this topic.
    • Post-Training Quantization (PTQ): Converts a fully trained high-precision model to a lower-precision format without retraining. PTQ typically uses a calibration dataset to compute appropriate scale and zero-point values using strategies like min-max range or percentile clipping. It is simple to use and well-supported in frameworks such as TensorFlow Lite and PyTorch, but may incur accuracy loss—particularly for sensitive or activation-heavy layers.

      We generally recommend 16-bit floats for GPU acceleration and 8-bit integer for CPU execution.

      • This reflects hardware preferences: float16 enables faster matrix multiplications on GPU accelerators like Tensor Cores, while int8 is more efficient on CPU architectures with dedicated integer units such as integer SIMD extensions (e.g., AVX, NEON, VNNI).
      • Look up the Post-Training Quantization section for a detailed discourse on this topic.
    • Quantization-Aware Training (QAT): In QAT, quantization effects—specifically the non-linearity introduced by rounding and clipping—are simulated during training, allowing the model to adapt. The model behaves as though it operates in lower precision during the forward pass, using the quantization formula above with fake quantization modules (e.g., FakeQuantize in PyTorch or tf.quantization.fake_quant_with_min_max_vars in TensorFlow). These modules apply quantization and dequantization logic using scale and zero-point. However, backpropagation remains in full precision. In other words, gradients and parameter updates are still computed using full float32 precision. This allows the model to adapt to quantization-induced noise, often resulting in better accuracy retention compared to PTQ. QAT is especially useful when targeting very low-bit formats (such as 4-bit or lower) or when quantizing sensitive components like attention layers in transformers or deploying models in high-accuracy applications.

Benefits and Limitations

  • Quantization can lead to significant efficiency gains, both in terms of memory footprint and computational load. For example, converting float32 weights to int8 reduces storage requirements by 75%, and on supported hardware, can improve inference speed by 2× to 4×. These benefits are amplified on edge devices with limited memory, power, and compute such as mobile phones, IoT sensors, embedded processors, and accelerators that support low-precision execution.

  • While quantization offers compelling advantages, it also has practical limitations. Quantization might not work uniformly well across all architectures or layers, since some operators (powering these layers and architectures) might not be supported in quantized form on all hardware targets. Put simply, some hardware can lack native support for certain quantized operators, since each typically has to be implemented individually. Operators such as group convolutions, custom layers, and normalization approaches (e.g., LayerNorm) may fall back to float32 or require custom low-level kernels, potentially limiting compatibility or efficiency on target platforms. Lastly, layers with small value ranges, heavy outliers, or complex nonlinear interactions may require higher precision (e.g., float16 or float32) to avoid accuracy degradation.

Mitigation Strategies

  • Selective strategies like per-channel, per-group, per-layer, per-tensor, and mixed-precision quantization are commonly used to mitigate limitations such as accuracy loss, hardware incompatibilities, and uneven value distributions across layers.

  • Per-channel quantization (also referred to as channel-wise quantization): This approach assigns separate scale and zero-point values to each output channel of a weight tensor. It is particularly effective in convolutional and linear layers where each output channel (or filter) may have significantly different weight distributions. By capturing channel-wise variations in magnitude, it provides better quantization accuracy, especially in vision models like ResNet, MobileNet, and EfficientNet, as well as transformer-based architectures such as BERT and GPT. PyTorch implements this using the torch.per_channel_affine quantization scheme.

  • Per-group quantization (also referred to as group-wise quantization): A compromise between per-tensor and per-channel quantization. Channels are divided into groups, with each group sharing quantization parameters. This reduces the overhead of storing separate scale/zero-point values for every channel, while still preserving more distributional information than per-tensor quantization. It is particularly useful in resource-constrained deployment scenarios where memory and compute costs must be balanced against accuracy. Though not always exposed via high-level APIs in frameworks like PyTorch, this strategy is supported in certain hardware accelerators and vendor-specific toolchains (e.g., Qualcomm, MediaTek, Xilinx).

  • Per-layer quantization (also referred to as layer-wise quantization): This applies the same quantization parameters (scale and zero-point) across an entire layer’s output or weight tensor. For example, all outputs of a linear layer or all weights of a convolution kernel are quantized using a single shared set of parameters. This method is computationally efficient and requires minimal additional metadata, making it widely used in low-resource settings or for fast prototyping. However, it often leads to higher quantization error in layers with highly varied internal distributions.

  • Per-tensor quantization (also referred to as tensor-wise quantization): A special case of per-layer quantization where quantization is applied uniformly across an entire tensor (typically weight or activation). A single scale and zero-point are calculated for the full tensor, regardless of dimensionality or channel boundaries. This is the simplest and most lightweight quantization method, requiring minimal bookkeeping and fast execution. While effective for layers with narrow and uniform value ranges, it can result in significant information loss when used on tensors with wide dynamic range or uneven channel statistics. This method is often the default in early-stage quantization workflows or on hardware that does not support fine-grained schemes.

  • Mixed-precision quantization: Instead of applying uniform quantization across all model layers, mixed-precision quantization selectively retains higher-precision (e.g., float32 or float16) computation in layers that are sensitive to quantization noise—such as attention heads, layer normalization, or output classifiers. Other layers, particularly early convolution blocks or MLPs, can be safely quantized to int8 or lower. This approach enables developers to achieve a favorable trade-off between accuracy and efficiency. Mixed-precision is supported by most modern inference engines including TensorRT, TVM, XNNPACK, and PyTorch FX Graph Mode Quantization.
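
  • To make the per-tensor vs. per-channel distinction concrete, the toy sketch below quantizes a weight tensor whose two output channels have very different ranges, and compares the reconstruction error on the small-magnitude channel:

import torch

# Two output channels with very different value ranges
w = torch.tensor([[0.01, -0.02, 0.03],
                  [1.00, -2.00, 3.00]])

# Per-tensor: a single scale for the whole tensor, dominated by the large channel
scale = (w.abs().max() / 127).item()
q_tensor = torch.quantize_per_tensor(w, scale=scale, zero_point=0, dtype=torch.qint8)

# Per-channel: one scale per output channel (axis=0)
scales = w.abs().amax(dim=1) / 127
zero_points = torch.zeros(2, dtype=torch.int64)
q_channel = torch.quantize_per_channel(w, scales, zero_points, axis=0, dtype=torch.qint8)

err_tensor = (q_tensor.dequantize() - w)[0].abs().max()
err_channel = (q_channel.dequantize() - w)[0].abs().max()
print("small-channel error, per-tensor :", err_tensor.item())   # large relative error
print("small-channel error, per-channel:", err_channel.item())  # near-lossless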

Weights vs. Activation Quantization

  • Why they’re not the same: Weights and activations have fundamentally different roles and constraints during quantization.

    • Weights are static once training completes. Since they do not vary across inputs, they may be quantized offline using fixed calibration data or analytically via heuristics such as min‑max scaling or percentile statistics. This allows more aggressive optimization techniques such as per‑channel quantization or non‑uniform quantization, and values may be precomputed and stored in compact low‑precision formats like int8, int4, or binary.

    • Activations, by contrast, are dynamic and input‑dependent. Their value ranges vary based on data processed during inference. Therefore, activation quantization must accommodate runtime variability. Two common modes exist:

      • Static Activation Quantization: Calibration datasets estimate typical activation ranges (min/max or histogram‑based) per layer. These statistics are then used to assign fixed scale/zero‑point pairs for quantized representation, using observers such as MinMaxObserver or histogram‑based observers.

      • Dynamic Activation Quantization: During inference, scale and zero‑point values are computed on‑the‑fly from input‑dependent statistics (e.g. dynamic min/max per batch). This avoids calibration datasets but may add latency due to runtime computation.

  • Handling outliers: Outliers in activation or weight distributions can drastically degrade quantization quality by skewing range and reducing effective precision. Several mitigation strategies include:

    • Percentile‑based Clipping: Rather than using absolute min/max, activations may be clipped to a percentile‑based range (e.g. 99.9%) to discard extreme outliers. Techniques include KL divergence minimization or MSE‑based clipping, used in frameworks such as PyTorch or TensorFlow Lite.

    • Per‑Channel Quantization: Instead of applying a single scale across a tensor, per‑channel quantization assigns unique scale and zero‑point values per output channel. This adapts to local distribution variations, particularly in convolutional or linear layers. In PyTorch, this is implemented via PerChannelMinMaxObserver or similar observers using torch.per_channel_affine schemes. Core functions include torch.quantize_per_channel and torch.fake_quantize_per_channel_affine to simulate quantization.

      • Per‑channel quantization is especially effective in:

        • Convolutional neural networks (e.g. ResNet, MobileNet, EfficientNet, YOLO), where filter/channels have distinct distribution ranges.
        • Transformer architectures (e.g. multi‑head attention or feed‑forward layers), where weight and activation distributions vary widely across heads or projection layers.
    • Learned Scale Factors: Methods like Learned Step Size Quantization (LSQ), introduced in the paper “Learned Step Size Quantization” by Esser et al. (2020), enable scale parameters to be optimized via backpropagation during fine‑tuning. This adaptive scaling is especially beneficial in models with skewed distributions—such as transformer-based language models, where operations like softmax produce skewed outputs.

  • Advanced Weight Handling:

    • Per‑Group Quantization: Instead of full per‑channel (i.e. one scale per channel), per‑group quantization assigns one scale per group of channels (e.g. N channels or per row). This balances granularity and memory overhead and is prevalent in formats like ONNX or TensorRT.

    • Activation‑Aware Weight Quantization (AWQ): AWQ techniques tailor weight quantization ranges and groupings using activation patterns observed during calibration. Rather than uniformly quantizing weights, these methods use sensitivity analysis to allocate bit budgets or adjust grouping for performance‑critical weights.

    • Zero‑Point Optimization: For symmetric quantization (typically weight tensors), zero‑point is fixed at zero. For asymmetric quantization—commonly used for activations—the zero‑point shifts the quantized range. Some frameworks (e.g. ONNX, TensorFlow Lite) allow fine‑grained control over zero‑point alignment, which influences both accuracy and hardware compatibility.

  • In practice, weight and activation quantization are applied jointly but with distinct parameter sets and calibration workflows, and modern toolkits expose fine-grained configuration over both.

  • Effective quantization requires balancing statistical rigor, hardware compatibility, and architecture sensitivity. Activations require runtime awareness, while weights benefit from static optimization—and both may leverage learned or adaptive scaling to maintain fidelity in low‑bit regimes.
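
  • As a small stand-alone illustration of percentile-based clipping (not tied to any particular framework's observer implementation), the calibration range below is taken from the 0.1st/99.9th percentiles of the observed activations instead of the raw min/max:

import torch

# Simulated calibration activations with a handful of extreme outliers
acts = torch.randn(10_000)
acts[:5] = 50.0                                  # outliers that would dominate min/max scaling

lo = torch.quantile(acts, 0.001)
hi = torch.quantile(acts, 0.999)

minmax_scale = (acts.max() - acts.min()) / 255   # coarse steps, skewed by the outliers
clipped_scale = (hi - lo) / 255                  # much finer resolution for typical values
print(minmax_scale.item(), clipped_scale.item())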

Quantization with PyTorch

  • Quantization in PyTorch enables the execution of computations and memory accesses with reduced-precision data types, typically int8, leading to improvements in model efficiency, inference speed, and memory footprint. PyTorch provides comprehensive support for quantization, starting from version 1.3, through an API that integrates seamlessly with the existing eager execution model.

  • Quantized Tensor Representation

    • PyTorch introduces special data types for quantized tensors, enabling the representation of weights and activations in reduced precision (typically int8, and sometimes float16). These tensors can be operated on via quantized kernels available under torch.nn.quantized and torch.nn.quantized.dynamic. These quantized operations allow for a 4× reduction in model size and 2-4× improvements in memory bandwidth and inference latency, depending on the hardware and model structure.

    • Quantization in PyTorch relies on calibration, which is the process of gathering statistics on representative inputs to determine optimal quantization parameters (such as scale and zero-point). These parameters are used in quantization functions of the form round(x / scale) + zero_point, enabling a linear mapping between floating point and integer domains.
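
    • As a small illustration (arbitrary values and quantization parameters), a quantized tensor carries its integer storage together with its scale and zero-point:

      import torch

      x = torch.tensor([-1.0, 0.0, 0.5, 1.0])
      qx = torch.quantize_per_tensor(x, scale=1 / 128, zero_point=128, dtype=torch.quint8)

      print(qx.int_repr())                    # underlying uint8 storage
      print(qx.q_scale(), qx.q_zero_point())  # quantization parameters
      print(qx.dequantize())                  # approximate reconstruction of x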

  • Quantization Backends

    • PyTorch leverages optimized backend libraries to execute quantized operations efficiently. FBGEMM (Facebook’s GEMM library) is optimized for server environments (x86 CPUs), while QNNPACK is designed for mobile and embedded environments. These are analogous to BLAS/MKL libraries in floating-point computation and are integrated automatically based on the target deployment platform.
  • Numerical Stability and Mixed Precision

    • One challenge in quantization is maintaining numerical stability, particularly for operations involving accumulation or exponentiation. To address this, PyTorch supports mixed-precision training and inference using the torch.cuda.amp module. AMP (Automatic Mixed Precision) allows portions of the model to be cast to torch.float16 while retaining torch.float32 for operations requiring higher precision, improving performance with minimal loss of accuracy. Although initially introduced for CUDA GPUs, mixed-precision techniques are distinct from quantization but can be complementary in certain scenarios.
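
    • A minimal AMP training-step sketch (assumes a CUDA-capable GPU; the model and data here are placeholders):

      import torch
      import torch.nn as nn

      model = nn.Linear(512, 512).cuda()
      optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
      scaler = torch.cuda.amp.GradScaler()

      for _ in range(10):
          x = torch.randn(64, 512, device='cuda')
          target = torch.randn(64, 512, device='cuda')
          optimizer.zero_grad()
          with torch.cuda.amp.autocast():          # eligible ops run in float16
              loss = nn.functional.mse_loss(model(x), target)
          scaler.scale(loss).backward()            # scale loss to avoid float16 gradient underflow
          scaler.step(optimizer)                   # unscale gradients, then update parameters
          scaler.update()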
  • Quantization Techniques in PyTorch

    • PyTorch provides three primary quantization workflows under the torch.quantization namespace, often referred to collectively as “eager mode quantization”:

    • Dynamic Quantization:
      • Weights are statically quantized and stored in int8 format, while activations are dynamically quantized at runtime before computation. This method requires minimal code changes and no calibration data. It is most effective for models dominated by linear layers (e.g., LSTM, GRU, Transformer-based models).

      • Example:

      torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
      
    • Post-Training Quantization:
      • Both weights and activations are quantized. This approach requires calibration, where representative input data is passed through the model to collect statistics via observer modules. Operator fusion (e.g., Conv + ReLU) and per-channel quantization are supported to improve performance and accuracy.

      • Example sequence:

      model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
      torch.quantization.prepare(model, inplace=True)
      # Run calibration with representative data
      torch.quantization.convert(model, inplace=True)
      
    • Quantization-Aware Training (QAT):
      • This technique inserts fake-quantization modules during training, simulating quantization effects in both forward and backward passes. It typically yields the highest post-quantization accuracy, especially in cases where model accuracy is sensitive to quantization noise.

      • Example sequence:

      model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
      torch.quantization.prepare_qat(model, inplace=True)
      # Train model
      torch.quantization.convert(model.eval(), inplace=True)
      
  • Operator and Layer Coverage

    • Quantization support varies by method. Dynamic quantization supports layers like Linear and RNNs, while static and QAT methods support a broader set including Conv, ReLU, BatchNorm (via fusion), and more. FX Graph Mode Quantization (a newer, graph-level approach not covered here) further expands operator support and streamlines workflows.
  • For additional guidance and end-to-end examples, refer to the official PyTorch blog: Introduction to Quantization on PyTorch.

Dynamic / Runtime Quantization

  • Dynamic quantization is one of the simplest quantization techniques in PyTorch, particularly suitable for models where most computation occurs in linear layers—such as transformer models (e.g. BERT) or recurrent networks (e.g. LSTM)—because these operations are dominated by matrix multiplications, which benefit significantly from int8 acceleration without requiring quantized convolutions.

  • In dynamic quantization, model weights are converted from 32-bit floating point (float32) to a lower precision format such as int8 and are permanently stored in this quantized form. Activations, however, remain in float32 format until runtime. At inference time, these activations are dynamically quantized to int8 immediately before the corresponding computation (i.e., matrix multiplication or linear operation) is executed. After the operation, the result is stored back in float32. This hybrid approach enables significant performance gains—such as reduced latency and memory usage—while maintaining reasonable model accuracy.

  • The aim of dynamic quantization is thus primarily to save compute through faster integer arithmetic, with the reduced weight storage as a secondary benefit.

  • Unlike static quantization or quantization-aware training (QAT), dynamic quantization requires no calibration dataset or retraining. This makes it ideal when representative data is unavailable or ease of deployment is paramount.

  • Quantization parameters (scale and zero-point) for activations are determined dynamically at each invocation based on the input data range, while weights use fixed scale and zero-point values computed ahead of time. Because only the weights are quantized ahead of time—activations stay in float32 and are quantized on the fly at runtime—dynamic quantization is often described as a data-free or weight-only quantization method at preparation time.

  • PyTorch provides the API torch.ao.quantization.quantize_dynamic(...) for applying dynamic quantization. It takes:

    • A model (torch.nn.Module)
    • A specification of target layer types or names (commonly {nn.Linear, nn.LSTM})
    • A target dtype (e.g., torch.qint8).
  • Only supported layer types are quantized—primarily nn.Linear and RNN variants; convolutional layers (e.g., nn.Conv2d) are not supported by dynamic quantization.

  • This approach is particularly effective for transformer and RNN models, where inference throughput is limited by memory-bound weight matrices. For example, quantizing BERT with dynamic quantization often yields up to 4× reduction in model size and measurable speedups in CPU inference latency.

Dynamic Quantization vs. Post-Training Quantization
  • Unlike Post-Training Quantization, dynamic quantization does not use min-max observers or any calibration mechanism. Activation ranges are computed dynamically at runtime based on actual input data, so no observers are required during model preparation.
Example Workflow
  • Below is an example illustrating a typical dynamic quantization workflow:
import torch
from torch import nn
from torch.ao.quantization import quantize_dynamic

# Assume `model` is a pretrained floating‑point nn.Module, in eval mode
model.eval()

# Apply dynamic quantization to Linear and LSTM layers
quantized_model = quantize_dynamic(
    model,
    {nn.Linear, nn.LSTM},
    dtype=torch.qint8
)

# Run inference: activations will be quantized at runtime
batch_size, seq_length, feature_dim = 8, 32, 128  # example shapes; must match the model's expected input
input_data = torch.randn(batch_size, seq_length, feature_dim)
output = quantized_model(input_data)
Workflow Explanation
  • model.eval() ensures deterministic behavior during quantization.
  • quantize_dynamic(...) replaces supported layers with their dynamic quantized implementations.
  • Activations remain in float32 until needed.
  • At runtime, activations are quantized to int8 on-the-fly, and computations are performed with int8 weights and mixed precision accumulators.
  • After the operation, results return to float32.
Typical Benefits
  • Model size reduced by ~75%, thanks to int8 weights.
  • Latency improvements, especially on CPU-bound operations.
  • No calibration or fine-tuning required.
Notes & Trade‑offs
  • Dynamic quantization does not support convolution or custom layers unless manually wrapped.
  • Because activation ranges are computed at runtime, dynamic quantization handles widely varying input distributions more gracefully than static quantization, which relies on fixed calibration ranges.
  • For CNN models or workloads where activations must also be quantized ahead of time, static quantization or QAT may yield better performance and accuracy.
Further Reading
  • A comprehensive end-to-end tutorial for dynamic quantization on BERT is available here.
  • For a more general example and advanced usage guide, see the dynamic quantization tutorial.
  • The full API documentation for torch.quantization.quantize_dynamic is available here.

Post-Training Quantization

  • Post-Training Quantization (PTQ) is a technique in PyTorch that enables the conversion of a model’s weights and activations from floating-point (typically float32) to 8-bit integers (int8), significantly improving inference efficiency in terms of speed and memory usage. This method is particularly well-suited for deployment scenarios on both server and edge devices, where latency and resource constraints are critical.

  • To facilitate this process, PyTorch inserts special modules known as observers into the model. These modules capture the activation ranges at various points in the network. Once sufficient data has been passed through the model during calibration, the observers record min-max values or histograms (depending on the observer type), which are then used during quantization.

  • A key benefit of static quantization is that it allows quantized values to be passed between operations directly, eliminating the need for costly float-to-int and int-to-float conversions at each layer. This optimization significantly reduces runtime overhead and enables end-to-end execution in int8.

  • PyTorch also supports several advanced features to further improve the effectiveness of static quantization:

    • Observers:

      • Observer modules are used to collect statistics on activations and weights during calibration. These can be customized to suit different data distributions or quantization strategies. PyTorch provides default observers like MinMaxObserver and HistogramObserver, and users can register them via the model’s qconfig.
      • Observers are inserted using torch.quantization.prepare.
    • Operator Fusion:

      • PyTorch supports the fusion of multiple operations (e.g., convolution + batch normalization + ReLU) into a single fused operator. This reduces memory access overhead and improves both runtime performance and numerical stability.
      • Modules can be fused using torch.quantization.fuse_modules.
    • Per-Channel Weight Quantization:

      • Instead of applying the same quantization parameters across all weights in a layer, per-channel quantization independently quantizes each output channel (particularly in convolution or linear layers). This approach improves accuracy while maintaining the performance benefits of quantization.
      • Final conversion to the quantized model is done using torch.quantization.convert.
Post-Training Quantization vs. Dynamic Quantization
  • Unlike dynamic quantization, which quantizes activations on-the-fly during inference, static quantization requires an additional calibration step. This calibration involves running representative data through the model to collect statistics on the distribution of activations. These statistics guide the quantization process by determining appropriate scaling factors and zero points for each tensor.
  • Put simply, in PTQ, while weights are quantized ahead of time, activations are quantized using calibration data collected via observers, enabling fully quantized inference across all layers.
Example Workflow
  • Below is an example illustrating a typical PTQ workflow:
import torch
import torch.nn as nn
import torch.quantization
from torch.quantization import QuantStub, DeQuantStub

# Step 1: Define or load the model
# (a minimal stand-in model; QuantStub/DeQuantStub mark the float <-> int8 boundaries)
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        return self.dequant(x)

model = SimpleModel()
model.eval()  # static quantization requires eval mode

# Step 2: Set the quantization configuration
# Choose backend depending on target device
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')  # for ARM/mobile
# model.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # for x86/server

# Step 3: Fuse modules (e.g., Conv + BN + ReLU)
model_fused = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])

# Step 4: Insert observer modules
model_prepared = torch.quantization.prepare(model_fused)

# Step 5: Calibrate the model with representative data
# For example, run a few batches of real or synthetic inputs
example_batch = torch.randn(8, 3, 32, 32)  # stand-in for representative calibration data
model_prepared(example_batch)

# Step 6: Convert to a quantized model
model_quantized = torch.quantization.convert(model_prepared)
  • This static quantization pipeline can yield 2× to 4× speedups in inference latency and a 4× reduction in model size, with minimal degradation in accuracy when calibrated effectively.
Workflow Explanation
  • Here is a breakdown of each step in the workflow above and its purpose:

    • Model Preparation: The model must be in eval mode (model.eval()) so that observers and quantization stubs function deterministically. Depending on the backend, model.qconfig = torch.ao.quantization.get_default_qconfig('x86') or 'qnnpack' sets the appropriate quantization configuration.
    • Operator Fusion: Use torch.quantization.fuse_modules to merge modules like Conv‑BatchNorm‑ReLU into a single fused operator. This improves numerical stability and reduces redundant quant‑dequant steps.
    • Observer Insertion: Invoke torch.quantization.prepare to automatically insert observer modules (e.g., MinMaxObserver). These record activation statistics during the calibration phase.
    • Calibration: Run representative real-world input data through the prepared model to collect min/max or histogram statistics via observers. Approximately 100‑200 mini‑batches often suffice for good calibration.
    • Conversion to Quantized Model: Use torch.quantization.convert to replace observed layers with quantized counterparts, applying pre-determined scales and zero points. The resulting model executes end‑to‑end in int8 arithmetic.
Typical Benefits
  • Model size is typically reduced by ≈4× (since int8 requires only 1 byte per parameter instead of 4) and memory bandwidth requirements drop significantly.
  • Inference latency improves—often 2× to 4× faster than FP32—by eliminating repeated float‑int conversions and enabling optimized integer kernels on CPU and mobile.
  • Enables uniform quantized execution across the network, which improves cache locality and enables hardware acceleration on supported platforms.
Notes & Trade‑offs
  • Requires a representative calibration dataset. If the input distribution drifts significantly, fixed quantization ranges may degrade accuracy over time.
  • Slight accuracy loss compared to floating‑point baseline—although typically small (~1‑2%)—especially on highly non-linear or sensitive models. For critical accuracy use‑cases, Quantization Aware Training may be more suitable.
  • Not all operators are supported for eager/static quantization. While convolution, linear, and RNN layers are supported, custom or unsupported layers may need manual handling or fallbacks. Per-channel quantization support is available, but requires proper qconfig settings.
  • The quantization workflow in PyTorch uses either Eager Mode or FX Graph Mode. FX mode can automate fusion and support functional operators, but may require model refactoring. Eager Mode offers more manual control but with limited operator coverage.

Quantization‑aware Training (QAT)

  • QAT is typically the most accurate of PyTorch’s three quantization workflows. With QAT, all weights and activations are subject to “fake quantization” during both forward and backward passes: values are rounded to simulate int8 quantization, while computations remain in floating‑point. Consequently, weight updates occur with full awareness that the model will eventually operate in int8. As a result, models trained with QAT generally achieve higher post‑quantization accuracy than those produced by post‑training quantization or dynamic quantization.

  • The principle is straightforward: the training process is informed about the ultimate quantized inference format. During training, activations and weights are rounded appropriately, so gradient flow reflects the quantization effects. However, the backpropagation itself—the gradient descent—is executed using full‑precision arithmetic.

  • After QAT training and conversion, the final model stores both weights and activations in the int8 quantized format, making it suitable for efficient inference on quantization-compatible hardware.

  • To implement QAT in PyTorch’s eager‑mode workflow, one typically follows these steps:

    1. Fuse suitable modules (e.g. Conv+ReLU, Conv+BatchNorm) via torch.quantization.fuse_modules.
    2. Insert QuantStub and DeQuantStub modules to manage quantization boundaries.
    3. Assign .qconfig to modules—e.g. via torch.quantization.get_default_qat_qconfig('fbgemm') or 'qnnpack'.
    4. Prepare the model using torch.ao.quantization.prepare_qat() or torch.quantization.prepare_qat().
    5. Train or fine‑tune the model in training mode.
    6. After training, apply torch.ao.quantization.convert() or torch.quantization.convert() to produce the fully quantized int8 model.
  • A code snippet in PyTorch invoking QAT:

qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(qat_model, inplace=True)
# train or fine‑tune qat_model ...
quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)
  • The fake quantization modules, which simulate the effects of quantization during both forward and backward passes, internally use observers (e.g., MinMaxObserver or HistogramObserver) to track activation and weight ranges during training. These fake quantization modules, which are inserted during training, are typically replaced with real quantized operators in the converted model.
Quantization‑aware Training vs. Post-Training Quantization vs. Dynamic Quantization
  • Note that unlike dynamic quantization (which only quantizes weights statically and activations on-the-fly during inference), QAT simulates quantization for both weights and activations during training. This allows the model to learn parameters that are robust to quantization-induced errors introduced at inference time. Put simply, it allows the parameters to adapt to quantization noise during inference and typically results in significantly better accuracy, especially for models with activation-sensitive layers such as convolutional networks.
  • Furthermore, unlike post-training quantization, QAT does not require a separate calibration phase after training. Instead, it uses the observer modules during training itself to learn and track the necessary quantization parameters, effectively integrating calibration into the training loop.
Example Workflow
  • This sub‑section illustrates a complete workflow for applying QAT to a small convolutional neural network:
import torch
import torch.nn as nn
import torch.quantization
from torch.quantization import QuantStub, DeQuantStub

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(16*32*32, 10)
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        x = x.flatten(1)
        x = self.fc(x)
        x = self.dequant(x)
        return x

# 1. Load pre‑trained or float‑trained model
model = MyModel()
model.eval()

# 2. Fuse conv, bn, relu
torch.quantization.fuse_modules(model, [['conv','bn','relu']], inplace=True)

# 3. Attach QAT config
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# 4. Prepare QAT (prepare_qat expects the model in training mode)
model.train()
torch.quantization.prepare_qat(model, inplace=True)

# 5. Fine‑tune QAT model
# run training loop for several epochs ...

# 6. Convert to quantized model
quantized_model = torch.quantization.convert(model.eval(), inplace=False)

# 7. Evaluate quantized_model for accuracy and inference performance
Workflow Explanation
  • This workflow enforces quantization effects during training by simulating rounding and clamping via fake quantization modules. QuantStub and DeQuantStub demarcate where data transitions between float and quantized domains. qconfig controls observer placement and quantization schemes (e.g. symmetric vs affine, per‑tensor vs per‑channel).
  • Fake quantization is active during training, guiding the network to adapt to the constraints of int8 inference arithmetic. Only after fine‑tuning does convert() replace fake quant modules with actual quantized operators for efficient int8 inference execution.
Typical Benefits
  • Higher quantized accuracy than post‑training static or dynamic quantization, often reducing performance degradation to minimal levels.
  • Improved robustness to quantization noise, particularly important for convolutional networks and vision models.
  • Retains compression benefits: reduced model size (≈25% of float model) and faster inference on hardware optimized for int8.
Notes & Trade‑offs
  • QAT requires additional training or fine‑tuning, increasing overall development time.
  • Careful scheduling is needed: a small learning rate is recommended to avoid instability introduced by straight‑through estimator (STE) approximations.
  • Model preparation steps such as layer fusion and correct placement of quant stubs are critical. Missing fusions can degrade accuracy.
  • Not all operators or model architectures are fully quantization‑aware; some require manual adaptation.
  • The converted model’s behavior may differ subtly from the fake‑quant version: real int8 kernels can produce outputs that diverge slightly from the fake‑quantized simulation seen during training, an effect that has been reported even on toy models.

Comparative Analysis

  • Here is a detailed comparative analysis of Dynamic Quantization, PTQ, and QAT:
| Aspect | Dynamic / Runtime Quantization | Post-Training Quantization (PTQ) | Quantization-Aware Training (QAT) |
| --- | --- | --- | --- |
| Primary Use Case | Fast, easy quantization of models with primarily linear operations (e.g., LSTM, BERT) | Quantization of convolutional or more complex models with moderate accuracy tradeoff | High-accuracy quantization for models sensitive to quantization (e.g., CNNs, object detectors) |
| Requires Retraining? | No | No | Yes |
| Requires Calibration Data? | No | Yes (to collect activation statistics) | No separate calibration; statistics are collected during training |
| When Activations Are Quantized | At runtime (dynamically before each operation) | Statically, using observer statistics from calibration | During training, using fake quantization modules |
| Quantization of Weights | Done ahead of time (static) | Done ahead of time (static) | Simulated during training, finalized during conversion |
| Quantization of Activations | Dynamically quantized at inference time | Statically quantized using calibration ranges | Simulated via fake quantization during training |
| Typical Accuracy | Moderate loss (acceptable for linear-dominant models) | Slight to moderate loss | Minimal loss; best accuracy among all methods |
| Complexity of Setup | Very low | Moderate | High |
| Computation Format During Training | Full float32 | Full float32 | Simulated int8 via fake quantization in float32 |
| Final Inference Format | int8 weights, float32 activations outside ops | int8 weights and activations | int8 weights and activations |
| Deployment Readiness | Easy and quick to apply, suitable for rapid deployment | Requires calibration workflow | Requires full training pipeline |
| Main Benefit | Faster inference via int8 compute; no data or retraining needed | Reduced latency and memory with moderate setup | Maximum accuracy with full int8 inference |
| Target Operators | Mostly linear layers (e.g., nn.Linear, nn.LSTM) | Broad operator support with fused modules (e.g., Conv+ReLU) | Full model coverage with operator fusion and training |
| Memory Footprint Reduction | Partial (activations still float32) | Full (weights and activations in int8) | Full (weights and activations in int8) |
| Primary Optimization Goal | Compute efficiency (faster matmuls), leading to latency savings | Latency and memory savings | Accuracy preservation under quantization |
| Example Usage | torch.quantization.quantize_dynamic | prepare, calibrate, convert | prepare_qat, train, convert |
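  • As a concrete example of the lowest-effort row in the “Example Usage” line above, dynamic quantization can be applied to a trained model with a single call (a minimal sketch; the toy model, module set, and dtype are illustrative):

import torch
import torch.nn as nn

# Toy float model; in practice this would be a trained LSTM/Transformer-style network.
float_model = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 10))

# Quantize the weights of all nn.Linear modules to int8; activations are
# quantized on-the-fly at inference time, so no calibration data is needed.
dynamic_model = torch.quantization.quantize_dynamic(
    float_model, {nn.Linear}, dtype=torch.qint8
)
print(dynamic_model)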

Modern Quantization Techniques

Activation-Aware Weight Quantization (AWQ)

  • Introduced in AWQ: Activation‑aware Weight Quantization for LLM Compression and Acceleration by Lin et al. (2024), AWQ is a post‑training, weight‑only quantization technique tailored for LLMs. It identifies and protects salient channels—those with large activations—via per‑channel scaling derived from calibration, enabling accurate INT4/INT3 quantization without retraining. Reference implementation (AutoAWQ and CUDA kernels) is available on GitHub.
  • AWQ offers a highly practical and accurate low-bit, weight-only quantization path. By using activation statistics to protect the most critical channels, it achieves near full-precision accuracy under INT4/3 quantization without training. It retains the efficiency of group-wise kernels for deployment while minimizing model size and speeding up inference for LLMs suited to edge and GPU environments.
Process
  1. Calibration Pass

    • Run a small calibration dataset through the unquantized model to gather per-channel activation statistics, typically the expected absolute value \(\mathbb{E}[\|x_i\|]\) of input to each weight channel \(i\).
  2. Group-Wise Weight Quantization Baseline

    • Use group size \(G\) (e.g. 32 channels) to quantize weights \(\mathbf{w}\) via a uniform symmetric scheme:
    \[Q(w) = \Delta \cdot \text{Round}\left(w / \Delta\right), \quad \Delta = \frac{\max |w|}{2^{b-1}-1}\]
    • With plain group-wise quantization, the impact of a weight’s quantization error on the layer output scales with the magnitude of the corresponding input activation, not with the weight’s magnitude alone.
  3. Compute Activation-Based Scaling Factors

    • Let \(s_i^{(x)} = \mathbb{E}_x [\|x_i\|]\) represent average per-channel activation magnitude.
    • For scaling exponent \(\alpha \in [0,1]\), define:
    \[s_i^{(\alpha)} = (s_i^{(x)})^\alpha\]
    • Use a small grid search over \(\alpha\) to minimize an approximate MSE:
    \[\mathbb{E}_x \left\lVert Q(\mathbf{w} \cdot \mathrm{diag}(s^{(\alpha)}))\bigl(s^{(\alpha)}\bigr)^{-1}x - \mathbf{w}x \right\rVert^2\]
    • Choose the \(\alpha^*\) that yields the lowest simulated error (no backprop required).
  4. Scale, Quantize, and Fuse

    • Transform weights and activations as:
    \[\tilde w_i = w_i \cdot s_i^*, \quad \tilde x_i = x_i / s_i^* \quad \text{such that} \quad \tilde w^\top \tilde x = w^\top x\]
    • Apply INT4 or INT3 group-wise quantization to \(\tilde w\) using a single scale per group.
    • Fuse the activation scaling \(s_i^{-1}\) directly into the preceding layer normalization or linear layer, avoiding runtime rescaling and preserving inference speed by leveraging existing CUDA kernels from weight-only quantization libraries.
  5. Deployment Optimization

    • AWQ pairs with TinyChat, an inference engine optimized for 4-bit weight-only transformers with kernel fusion and reorder-free dequantization.
    • Uses platform‑aware weight packing to maximize throughput on GPUs (observed ~3× speedup over Hugging Face FP16 with negligible accuracy drop).
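  • The scaling search in steps 1–4 can be sketched in a few lines of PyTorch (an illustrative toy version, not the AutoAWQ reference implementation; per-tensor symmetric quantization stands in for the real group-wise kernels):

import torch

def quantize_sym(w, bits=4):
    # Uniform symmetric quantization with a single scale (real AWQ uses per-group scales).
    delta = w.abs().max() / (2 ** (bits - 1) - 1)
    return delta * torch.round(w / delta)

def awq_alpha_search(W, X, alphas=(0.0, 0.25, 0.5, 0.75, 1.0), bits=4):
    s_x = X.abs().mean(dim=0)              # per-input-channel activation magnitude E[|x_i|]
    y_ref = X @ W.T                        # full-precision reference output
    best = (None, float('inf'))
    for a in alphas:
        s = s_x.clamp(min=1e-8) ** a       # s_i = (E[|x_i|])^alpha
        Wq = quantize_sym(W * s, bits)     # scale salient channels up, then quantize
        err = ((X / s) @ Wq.T - y_ref).pow(2).mean().item()
        if err < best[1]:
            best = (a, err)
    return best

W = torch.randn(64, 128)                               # linear weight, y = x @ W.T
X = torch.randn(256, 128) * (1 + 5 * torch.rand(128))  # calibration activations with outlier channels
print(awq_alpha_search(W, X))                          # best alpha and its simulated MSE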
Pros
  • Salient-channel preservation: By scaling up high-activation channels, AWQ protects the most influential weights using only ~1% additive precision, significantly reducing quantization error.
  • Training‑less: Requires no finetuning or backpropagation—calibration and closed-form scaling search are sufficient, preserving generalization across domains including instruction-tuned and multi-modal LMs.
  • Hardware‑efficient: Retains group-wise quantization kernels; activation rescaling is fused into existing linear or layer-norm layers, maintaining inference latency and memory efficiency.
Cons
  • Calibration dependency: Requires representative activation samples and search over \(\alpha\), which adds one preprocessing pass but no training.
  • Limited activation reduction: Activations are not quantized (typically kept in FP16), so runtime memory use is not halved.
  • Architecture constraints: Fusion of scaling into preceding layernorm assumes alignment between weight input channels and layernorm channels; may require adaptation for custom architectures.

SmoothQuant

Process
  1. Analyze Activation Outliers: Activation tensors in LLMs often have long-tailed distributions, leading to a high dynamic range. These outliers cause large quantization errors when mapping to low-bit formats like int8.

  2. Offline Scaling of Input and Weights: To reduce activation outliers, SmoothQuant proposes to pre-scale input activations and inversely scale the associated weight matrices before quantization. This is done by computing per-channel maximum absolute values of activation tensors and applying the following scaling transformation:

    \[x_{\text{scaled}} = \frac{x}{s}, \qquad W_{\text{scaled}} = W \cdot s\]
    • Here, \(s\) is the scaling factor (per input channel), \(x\) is the activation input, and \(W\) is the weight matrix.
    • This transformation preserves the original linear operation exactly, since the per-channel scaling cancels:

      \[x W = \left(\frac{x}{s}\right)(W \cdot s)\]
  3. Quantize Scaled Tensors: Apply standard post-training quantization (e.g., torch.quantize_per_tensor) on the scaled weight and activation tensors using uniform int8 quantization.

  4. No Runtime Overhead: The inverse scaling factors \(s\) are folded into the preceding layers during offline preprocessing. At inference, the quantized model does not require additional computation to reverse the scaling—hence, preserving speed.
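  • A minimal sketch of the offline scaling step (assumptions: toy tensors, per-input-channel scales from calibration statistics, and the \(\alpha\)-based scale formula from the SmoothQuant paper with the typical migration strength \(\alpha = 0.5\)):

import torch

def smooth_scales(x_absmax, w_absmax, alpha=0.5, eps=1e-8):
    # s_j = max|X_j|^alpha / max|W_j|^(1 - alpha), one scale per input channel j
    return (x_absmax.clamp(min=eps) ** alpha) / (w_absmax.clamp(min=eps) ** (1 - alpha))

X = torch.randn(32, 128) * (1 + 9 * torch.rand(128))   # calibration activations with outlier channels
W = torch.randn(256, 128)                              # linear weight, y = x @ W.T

s = smooth_scales(X.abs().amax(dim=0), W.abs().amax(dim=0))

# Equivalent reparameterization: the scaling cancels, so the layer output is unchanged,
# but part of the activation outlier difficulty is migrated into the weights.
X_s, W_s = X / s, W * s
print(torch.allclose(X @ W.T, X_s @ W_s.T, rtol=1e-4, atol=1e-3))   # True: function preserved
print(X.abs().max().item(), X_s.abs().max().item())                 # activation range is reduced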

Pros
  • Training-free and calibration-light—requires only a few batches of representative data for statistics.
  • Allows fully static W8A8 quantization for transformers, which were previously hard to quantize due to activation outliers.
  • Compatible with major LLMs like Llama, OPT, BLOOM, and GLM.
  • Hardware-friendly: int8 inference is highly optimized on modern CPUs (e.g., VNNI, AMX) and GPUs.
  • Achieves substantial efficiency improvements:

    • ~2× reduction in memory footprint.
    • ~1.5× to 2× speedup on supported backends.
    • <0.5% accuracy loss on common NLP benchmarks.
Cons
  • Limited to 8-bit formats—does not address extreme quantization (e.g., 4-bit or binary).
  • Effectiveness depends on the distribution of activations. Improper scaling (e.g., due to poor calibration data) may still degrade performance.
  • Static scale determination can be fragile in models with dynamic context (e.g., prompts of variable length).

SpinQuant

  • Introduced in SpinQuant: LLM Quantization with Learned Rotations by Liu et al. (2024), SpinQuant reduces quantization error by applying learned orthonormal rotations to weights, activations, and KV-cache blocks. These rotations normalize tensor distributions, mitigate outliers, and enable accurate W4A4KV4 quantization. The implementation is available on GitHub.
Process
  1. Parameterize rotation matrices:

    • SpinQuant uses blockwise orthonormal rotation matrices initialized using Hadamard, shortcut, or random orthonormal bases.
    • These rotations are applied to groups of weights, activations, and KV-cache blocks (e.g., \(Q\), \(K\), \(V\) matrices or FFN weights), where outliers may exist.
  2. Optimize via Cayley-SGD on the Stiefel manifold:

    • A small calibration set is passed through a simulated W4A4KV4 pipeline.
    • Quantization error (e.g., MSE or KL divergence) between full-precision and quantized outputs is computed.
    • Gradients are backpropagated through the rotation parameters using Cayley-SGD, a method that maintains orthonormality constraints by optimizing directly on the Stiefel manifold.
  3. Fold optimized rotations into model weights:

    • Once optimized, the learned rotation matrices are fused into the model weights and biases (e.g., replacing \(W\) with \(R^T W R\)) during preprocessing.
    • This ensures that no extra computation is introduced at inference time—quantization is performed on the already rotated tensors.
  4. Apply W4A4KV4 quantization:

    • Post-rotation, weights, activations, and KV-cache blocks are quantized to 4-bit using standard uniform quantization schemes.
    • The rotations have distributed the influence of large-magnitude outliers, allowing for a tighter and more efficient quantization range.
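  • A toy sketch of the core idea, folding an orthonormal rotation into a linear layer so the rotated tensors can be quantized while the layer’s function is preserved (illustrative only; SpinQuant learns its rotations with Cayley-SGD rather than sampling them randomly):

import torch

torch.manual_seed(0)
d = 64
W = torch.randn(d, d)                    # weight of a linear layer, y = x @ W.T
x = torch.randn(8, d)
x[:, 0] *= 50                            # inject an outlier channel

# Take the orthonormal factor of a QR decomposition as the rotation
# (Hadamard or learned rotations would be used in practice).
R, _ = torch.linalg.qr(torch.randn(d, d))

# Rotate the activation space and fold the rotation into the weights:
# (x @ R) @ (W @ R).T = x @ R @ R.T @ W.T = x @ W.T, so the output is unchanged.
x_rot, W_rot = x @ R, W @ R
y_ref, y_rot = x @ W.T, x_rot @ W_rot.T
print(torch.allclose(y_ref, y_rot, rtol=1e-4, atol=1e-3))   # True: function preserved

# The rotation "smears" the outlier across dimensions, shrinking the peak-to-mean
# ratio that would otherwise dominate a 4-bit quantization range.
print((x.abs().max() / x.abs().mean()).item(), (x_rot.abs().max() / x_rot.abs().mean()).item())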
Pros
  • Outlier mitigation via distribution normalization:

    • Rotations “smear” large-magnitude values across dimensions, redistributing peak energies that would otherwise dominate quantization bins.
    • This normalization significantly reduces the impact of extreme values and improves low-bit quantization fidelity.
  • Accuracy preservation:

    • Achieves within ~2.9 points of full precision on Llama 2 (7B) zero-shot tasks.
    • Outperforms existing techniques like AWQ, SmoothQuant, and QuaRot by 19–45% in accuracy retention.
  • No runtime overhead:

    • Unlike some quantization techniques that add inference complexity, SpinQuant’s learned rotations are merged offline.
    • At inference, the model behaves identically to a standard quantized model, with no additional compute.
Cons
  • Involves optimization and calibration:

    • Cayley-SGD optimization introduces computational overhead during preprocessing.
    • Requires a small validation set to simulate quantization and compute gradients.
  • Preprocessing complexity:

    • Folding rotations into model weights adds engineering complexity, especially when targeting hardware deployment.
    • Though folded offline, the rotated tensors may have slightly increased numerical range, requiring careful scale selection.
  • Larger intermediate tensors:

    • While inference cost remains low, merged rotated weights can increase storage slightly due to loss of weight sparsity or alignment.

AWEQ: Activation‑Weight Equalization

  • Introduced in AWEQ: Post‑Training Quantization with Activation‑Weight Equalization for LLMs (Nov 2023) by Li et al., AWEQ is a training‑free post‑training quantization technique designed to facilitate both ultra‑low‑bit and 8‑bit weight+activation (e.g., W8A8) quantization in large language models such as LLaMA and OPT. It works by shifting quantization hardness from activations to weights to reduce error.
  • AWEQ effectively balances activation and weight ranges channel-wise via per-channel equalization followed by uniform quantization, and incorporates bias correction to reduce residual errors. It achieves significantly improved quantization accuracy, particularly in the W8A8 setting relative to floating-point baselines, without any training or runtime slowdown, making it a strong choice for production deployments that require both efficiency and fidelity.
Motivation
  • Large‑scale LLM activations often exhibit long‑tailed per‑channel distributions with large outliers, making activation quantization challenging even at 8 bits. AWEQ addresses this by balancing activation and weight ranges so that differences in range (and therefore quantization difficulty) are harmonized channel‑wise, reducing wastage in the quantization grid and improving uniform quantization performance.
Process (Implementation Overview)
  1. Channel Range Analysis

    • Run forward passes over a small calibration dataset to compute per‑channel activation range \(r(X)_i = \max(X_i) - \min(X_i)\) and weight range \(r(W)_i\) for each linear or attention block tensor. Range refers to max minus min values across all elements in that channel.
  2. Equalization Factor Computation

    • Compute a scale vector \(s \in \mathbb{R}^C\) to equalize ranges via:

      \[\tilde{X}_i = \frac{X_i}{s_i}, \quad \tilde{W}_i = s_i \cdot W_i\]
      • The objective is typically set such that \(r(\tilde{X}_i) = r(\tilde{W}_i)\) for all \(i\), thereby maximizing per‑channel quantization precision as defined through product of normalized ranges (see Equations 3–10 in the original text).
  3. Tensor Scaling (Fusion)

    • Apply channel‑wise scaling at the input boundaries of transformer modules, such as prior to the self‑attention query/key/value projections and the FFN layers.
    • Merge activation scaling into preceding layers (e.g., LayerNorm or Linear) to eliminate runtime overhead. For example, transform internal \(W \leftarrow \mathrm{diag}(s) \, W\); hence, activations use quantizable ranges without additional scaling logic.
  4. Uniform Quantization

    • Quantize the equalized tensors using per‑tensor uniform affine quantization (e.g., 4‑bit or 8‑bit symmetric). Activation quantization thresholds can be fused with the input block for efficient inference.
  5. Quantization Bias Correction (BC)

    • Because quantization after scaling and symmetric clipping can introduce a bias \(\epsilon = W_f - W\), AWEQ applies post‑hoc bias correction:
    \[\tilde{y} = y_e - \epsilon \cdot \mathbb{E}[x]\]
    • where \(\mathbb{E}[x]\) is estimated over calibration data, and \(y_e = (W + \epsilon)x\). This corrects the expected error per layer without changing runtime performance, enhancing stability in deep LLMs without BatchNorm.
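  • A minimal sketch of the range-equalization idea in steps 1–3 (an illustrative assumption: choosing \(s_i = \sqrt{r(X)_i / r(W)_i}\) makes the per-channel activation and weight ranges equal; AWEQ derives its factors from the objective in Equations 3–10 of the paper):

import torch

def channel_range(t, dim=0):
    # r(.)_i = max - min across all elements in channel i
    return t.amax(dim=dim) - t.amin(dim=dim)

X = torch.randn(512, 128) * (1 + 9 * torch.rand(128))   # calibration activations
W = torch.randn(64, 128)                                # linear weight, y = x @ W.T

s = torch.sqrt(channel_range(X) / channel_range(W).clamp(min=1e-8))

X_eq, W_eq = X / s, W * s
print(torch.allclose(X @ W.T, X_eq @ W_eq.T, rtol=1e-4, atol=1e-3))           # function preserved
print(torch.allclose(channel_range(X_eq), channel_range(W_eq), rtol=1e-4))    # ranges now equal per channel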
Pros
  • Training‑free, with no need for quantization-aware training or gradient-based fine-tuning.
  • Supports both W8A8 activation quantization and ultra‑low-bit weight-only quantization, including INT3/4 with robust performance.
  • Hardware-friendly, as it avoids dynamic scaling during inference; changes are statically fused before deployment.
  • Demonstrates best-in-class accuracy on tasks such as zero‑shot Llama 2 7B evaluation (e.g., average: 70.38% over PIQA, HellaSwag, WinoGrande, ARC‑e—all within <0.01 absolute from FP16).
Cons
  • Requires representative calibration data to compute statistics and activation range profiles.
  • Per‑tensor quantization may not perform as well as per‑channel for certain weight distributions, though the equalization helps mitigate this.
  • Slight overhead in computing equalization factors and bias expectations at quantization time (calibration phase only).

FPTQuant

Process
  1. Function-Preserving Activation Transforms:

    • FPTQuant applies mathematically invertible transforms to the activations in attention and feedforward blocks to reduce their dynamic range. These include:

      • Logarithmic transforms: Applied to soften the long-tailed distributions (especially in attention scores or MLP activations).
      • Affine or exponential transforms: Normalize activations without changing the computation graph logic.
  2. Merging Transforms into Weights:

    • Since these transforms are invertible, the effect can be canceled out by adjusting the downstream linear weights.
    • Specifically:

      • Let \(x \rightarrow f(x)\) be the transform applied to activations.
      • Then \(Wx \rightarrow Wf^{-1}(f(x))\) ensures the output remains unchanged.
      • FPTQuant modifies the linear projection weights accordingly so that the transform step is absorbed and the forward function is preserved.
  3. Quantization Step:

    • With the dynamic range compressed, 4-bit symmetric per-channel quantization is applied to the adjusted weights using PTQ methods.
    • Activations are not explicitly quantized, but their transformed form is compatible with W4A16 runtimes (e.g., vLLM) where only weights are quantized.
  4. No Runtime Penalty:

    • All transforms are resolved offline and merged into weights.
    • The runtime model is a standard quantized model with no extra ops introduced during inference.
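  • A generic illustration of the merge step (not the FPTQuant reference code): a per-channel affine transform \(f(x) = (x - b)/a\) compresses and centers the activation range, and its effect is cancelled by folding the inverse into the downstream weights and bias:

import torch

torch.manual_seed(0)
x = torch.randn(16, 128) * (1 + 9 * torch.rand(128))   # wide-range activations
W = torch.randn(64, 128)                               # downstream linear layer, y = x @ W.T + bias
bias = torch.randn(64)

a = x.abs().amax(dim=0)      # per-channel scale of the transform
b = x.mean(dim=0)            # per-channel shift of the transform

x_t = (x - b) / a            # transformed activations: compressed, centered range

# Fold f^{-1} into the layer: W' = W * a (per input channel), bias' = bias + W @ b
W_t = W * a
bias_t = bias + W @ b

y_ref = x @ W.T + bias
y_new = x_t @ W_t.T + bias_t
print(torch.allclose(y_ref, y_new, rtol=1e-4, atol=1e-3))   # True: forward function preserved
print(x.abs().max().item(), x_t.abs().max().item())         # dynamic range is reduced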
Pros
  • Fully training-free, invertible, and architecture-agnostic.
  • Achieves INT4 weight-only quantization with minimal or no accuracy loss, by preserving function through mathematically exact transformation.
  • Compatible with W4A16 systems like vLLM, delivering significant memory and latency improvements without major architectural rework.
Cons
  • Primarily targets weights—does not quantize activations directly, limiting total memory benefits.
  • Still an emerging method—performance and generalization are under ongoing validation across LLM variants (e.g., Mistral, Gemma).
  • May require per-layer transform tuning based on architecture layout (e.g., attention vs MLP blocks).

Palettization (Weight Clustering)

  • Palettization, also known as weight clustering, is a quantization scheme that replaces full‑precision weights with low‑bit indices into a small lookup table (LUT). Each weight value is approximated by the nearest centroid in the LUT, enabling efficient storage and retrieval.
Process
  1. Clustering: Collect all float‑format weights for a layer (or group of layers), then run k-means clustering to derive a set of centroids (typically, \(2^{n}\) entries for n‑bit palettization).
  2. Index Mapping: Each weight is replaced with an integer index pointing to its closest centroid in the LUT. The original full‑precision value is no longer stored.
  3. Granularity Options:
    • Per‑tensor granularity: A single LUT for the entire weight tensor.
    • Per‑group‑channel granularity: The tensor is divided into groups of channels (defined by group_size), each with its own LUT—offering a better accuracy/compression trade-off.
  4. Optional Vector Clustering (cluster_dim > 1): Enables multi-dimensional centroids by clustering weight vectors instead of scalars, improving approximation quality for some architectures.
  5. Post‑processing: Optionally quantize LUT centroids themselves to a lower precision (e.g. int8) for additional compression.
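  • A minimal sketch of steps 1–2 for a single layer with per-tensor granularity (illustrative only, using numpy and scikit-learn; production flows such as palettization-aware training learn the LUT during fine-tuning rather than with one-shot k-means):

import numpy as np
from sklearn.cluster import KMeans

n_bits = 4                                            # 2^4 = 16 centroids in the LUT
W = np.random.randn(256, 128).astype(np.float32)      # float weights of one layer

# 1. Cluster all scalar weights into 2^n centroids (the LUT).
kmeans = KMeans(n_clusters=2 ** n_bits, n_init=4, random_state=0).fit(W.reshape(-1, 1))
lut = kmeans.cluster_centers_.astype(np.float32).ravel()

# 2. Replace every weight with the index of its nearest centroid.
indices = kmeans.labels_.astype(np.uint8)             # 4-bit indices (stored packed in practice)
W_palettized = lut[indices].reshape(W.shape)          # dequantized view, used for accuracy checks

# Storage estimate: n-bit indices plus a tiny LUT instead of 32-bit floats.
orig_bytes = W.size * 4
packed_bytes = W.size * n_bits / 8 + lut.size * 4
print(orig_bytes, packed_bytes, float(np.abs(W - W_palettized).mean()))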
Integration in Workflows
  • Available via Apple’s coremltools.optimize.torch.palettization API, which injects FakePalettize layers into the model for palettization-aware training (PAT).
  • During training, k‑means clustering is applied online, and the LUT and indices are learned through gradient steps. After convergence, the finalize() call folds LUTs and indices into permanent quantized weights.
When to Use Palettization
  • Memory-critical deployment: Edge devices or mobile apps where weight storage is the bottleneck.
  • Aggressive compression: Scenarios requiring sub‑4‑bit representation.
  • Architecture flexibility: Works with both CNNs and transformers when standard affine quantization struggles.
  • Fine‑tunable deployment targets: Fine-tuning after palettization enables high accuracy while still achieving significant compression ratios.
Pros
  • Extreme compression: Supports ultra-low bit‑width representations (e.g. 2–4 bits) for weights.
  • Memory savings: Offers major memory savings—each weight becomes a small index instead of a float.
  • Vector clustering: Multidimensional centroids can preserve structure in weight matrices.
  • Flexible granularity: Per-tensor, per-group, or vector-level control enables tailored compression vs. accuracy trade-offs.
  • Compatible with fine‑tuning: Compatible via PAT, allowing retention of accuracy through fine‑tuning post-clustering.
Cons
  • Requires additional training or fine‑tuning steps (PAT) to compensate for quantization error.
  • Clustering and LUT management add complexity to both the training and inference pipelines, introducing LUT metadata and integer-to-centroid lookup logic at inference time.
  • Higher runtime overhead than standard affine quantization, especially with per-channel or per-group palettization, which must store multiple LUTs and perform additional index lookups.
  • Less intuitive and more complex to implement than simple scale-based quantization.

Comparative Analysis

| Method | Bits | What Is Quantized | Training‑Free? | Key Innovation | Accuracy Retention |
| --- | --- | --- | --- | --- | --- |
| Uniform Quantization | 4–8 bit | Weights ± Activations | ✔ (PTQ) / ✘ (QAT) | Simple affine mapping, per-tensor or per-channel | Good for smooth distributions (~2 pt drop) |
| AWQ | W4 | Weights only | ✔ | Activation‑aware scaling based on calibration | High (near FP16) |
| SmoothQuant | W8A8 | Weights + Activations | ✔ | Migrates activation outlier difficulty to weights | Very High (<0.5% loss) |
| SpinQuant | W4A4KV4 | Weights, Activations, KV-cache | Partially | Learned orthonormal rotations to normalize distributions | Best in class (within ~3 pt) |
| AWEQ | W4 or W8 | Weights + Activations | ✔ | Scale equalization between weight and activation ranges | Strong accuracy retention |
| FPTQuant | W4 | Weights only | ✔ | Invertible transforms on activations to preserve function | Excellent (minimal loss) |
| Palettization (Clustering) | 2–4 bit (index representation) | Weights only (LUT + index mapping) | ✘ (requires Palettization-Aware Training or fine‑tuning) | K‑means or differentiable clustering + lookup table for weight sharing | Tunable via PAT (often strong with tuning) |

Device and Operator Support across Frameworks

PyTorch

  • Quantization in PyTorch is supported for a limited subset of operators, and the availability of these operators depends on the specific quantization approach being employed—static, dynamic, or QAT. The list of supported quantized operators is not exhaustive and evolves with newer PyTorch releases. For an up-to-date reference, consult the official PyTorch quantization documentation here.

  • The implementation of quantization in PyTorch is backend-dependent, meaning that both the quantization configuration (which defines how tensors are quantized) and the set of quantized kernels (which define how arithmetic is performed on quantized tensors) vary based on the target hardware. Currently, PyTorch provides official support for quantized inference only on CPUs, specifically for x86 and ARM architectures. These are supported via two primary backends:

    • fbgemm: Optimized for server-class x86 CPUs.
    • qnnpack: Designed for mobile ARM CPUs.
  • The backend must be explicitly set to ensure compatibility between the model’s quantized representation and the runtime kernels, as shown in the example below:

    import torch
    
    backend = 'fbgemm'  # 'qnnpack' for ARM/mobile inference
    my_model.qconfig = torch.quantization.get_default_qconfig(backend)
    
    # Prepare and convert model
    # Set the backend on which the quantized kernels need to be run
    torch.backends.quantized.engine = backend
    
    # Continue with model preparation and conversion steps...
    
  • QAT in PyTorch is performed in full-precision (float32) mode to leverage existing GPU or CPU hardware during training. This technique simulates the effects of quantization during training to improve model robustness when deployed with quantized weights. QAT is particularly beneficial for convolutional neural networks (CNNs), especially lightweight models such as MobileNet, where static or dynamic post-training quantization may result in unacceptable accuracy degradation.

Integration in torchvision
  • The torchvision library includes integrated support for quantization in several widely used neural network architectures. These include GoogLeNet, InceptionV3, ResNet (various depths), ResNeXt, MobileNet (V2 and V3), and ShuffleNet. The support is provided in three distinct forms to enable a range of workflows:

    1. Pre-trained quantized model weights: These models are fully quantized and can be used directly for inference without additional fine-tuning.
    2. Quantization-ready model definitions: These are versions of the models with quantization stubs pre-inserted, making them suitable for post-training quantization or QAT.
    3. QAT scripts: Scripts are available to perform QAT on supported models. While these scripts are applicable to all the models listed above, empirical evaluations show that QAT tends to yield significant accuracy benefits primarily for lightweight models like MobileNet.
  • PyTorch also provides a dedicated tutorial demonstrating how to perform transfer learning with quantization using pre-trained models from torchvision. This enables developers to take advantage of quantized inference while still adapting models to custom datasets and deployment scenarios.

TensorFlow

  • Quantization support in TensorFlow is primarily centered around deployment through TensorFlow Lite (TFLite). Only a subset of TensorFlow operations are supported in quantized form when converting models to run efficiently on edge devices. For a comprehensive list of quantization-compatible operators, refer to the official TFLite operator compatibility documentation here.

  • Quantized inference in TensorFlow is enabled through TensorFlow Lite (TFLite) delegates, which provide optimized execution across various hardware backends. These include:

    • CPU Delegate: Supports int8 and float16 quantized models using XNNPack kernels, which are enabled by default in modern TFLite runtimes.
    • GPU Delegate: Accelerates inference on mobile and embedded GPUs. It supports float16 quantization and, in limited cases, int8 precision. The delegate is available on both Android and iOS platforms.
    • NNAPI Delegate (Android only): Interfaces with on-device hardware acceleration drivers. Quantized int8 models are typically supported and can see performance improvements depending on the device and vendor-specific drivers.
    • Edge TPU Delegate: Targets Google’s Coral hardware and supports only fully integer quantized models with int8 weights and activations. Due to strict operator and quantization constraints, models must be carefully converted and then compiled using the Edge TPU Compiler.

    • The level of operator support and performance characteristics differ by delegate. For example, the Edge TPU requires that all operations be quantized and supported by its limited op set. Any unsupported operations will result in compilation failure or will require fallback to CPU, which can significantly affect performance. As such, developers must validate operator compatibility prior to deployment by reviewing the TFLite ops compatibility guide and testing with their target delegate.
  • QAT in TensorFlow is implemented using the tfmot.quantization.keras.quantize_model API available through the TensorFlow Model Optimization Toolkit (TFMOT). Similar to PyTorch, QAT in TensorFlow is performed in floating point, allowing the model to simulate quantized behavior during training while still leveraging GPU acceleration. This helps preserve accuracy for models that do not respond well to post-training quantization, such as compact architectures like MobileNet or custom CNNs. The general trade-offs between PTQ and QAT in TensorFlow align closely with those in PyTorch, although some feature and operator support mismatches still exist between the two frameworks.

  • When using post-training quantization or QAT, it’s important to validate that all critical model operations are supported in TFLite with quantized equivalents. Unsupported operations may be automatically left in float, potentially degrading the intended performance benefits of quantization.

Integration in tf.keras.applications
  • While TensorFlow does not provide pre-quantized models in tf.keras.applications, the Model Optimization Toolkit provides utilities to quantize these models post-training or prepare them for QAT. Developers can load a model from tf.keras.applications, apply quantization via TFMOT, and then convert it to TFLite. The process typically involves:

    1. Cloning the model with quantization-aware layers using quantize_model.
    2. Fine-tuning the quantized model if needed.
    3. Converting the trained model to TFLite using the TFLiteConverter (a minimal sketch of these steps appears below).
  • TFLite provides tools and guidelines for performing transfer learning with quantized models, though, as with PyTorch, QAT tends to be necessary mainly for accuracy-sensitive lightweight models.
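  • A minimal sketch of the three steps above using a toy Keras model (illustrative; x_train and y_train are placeholders, and a real workflow would fine-tune on the original training data):

import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Toy float model standing in for a tf.keras.applications network.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation='relu', input_shape=(512,)),
    tf.keras.layers.Dense(10)
])

# 1. Clone the model with quantization-aware layers.
q_aware_model = tfmot.quantization.keras.quantize_model(model)

# 2. Fine-tune the quantization-aware model.
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
q_aware_model.fit(x_train, y_train, epochs=1)

# 3. Convert to a quantized TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()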

CoreML

  • PyTorch-based quantization does not necessarily carry over to other production environments. In particular, when converting to Apple’s CoreML format, you need to use CoreML’s own quantization tooling (which may be limited to 16-bit quantization). When targeting edge devices, also check whether the desired quantization is supported at all; in Apple’s case the hardware already computes everything in fp16 on the GPU, so the main saving is the memory footprint of the network’s weights.

  • TensorFlow follows a similar set of steps, with its examples focused on TFLite: static and dynamic quantization are explained on the Post-training quantization page, and there is a dedicated QAT page. The trade-offs are very similar, though some feature mismatches between PyTorch and TensorFlow remain.

Choosing the right quantization approach

  • The choice of which scheme to use depends on multiple factors:

    • Model/Target requirements: Some models might be sensitive to quantization, requiring QAT.
    • Operator/Backend support: Some backends require fully quantized operators.
  • Currently, operator coverage in PyTorch is limited and may restrict the available choices. The guidance table in PyTorch: Introduction to Quantization on PyTorch provides a guideline for selecting among the three approaches.

Performance Results

  • Quantization provides a 4× reduction in model size and a speedup of 2× to 3× compared to floating-point implementations, depending on the hardware platform and the model being benchmarked; sample results are reported in PyTorch: Introduction to Quantization on PyTorch.

Accuracy results

  • The accuracy tables in PyTorch: Introduction to Quantization on PyTorch compare static quantized computer-vision models against their floating-point counterparts on ImageNet, and report the F1 score of dynamically quantized BERT on the GLUE MRPC benchmark for speech and NLP; refer to the original post for the per-model numbers.

How far can we go?

  • Apparently down to 1 bit! There have been several attempts over the years to create binary neural networks if you want the most extreme version of the accuracy vs. speed tradeoff. For the most part, these are still research projects rather than usable ideas, though XNOR-Net++ seems to have been implemented in PyTorch.

Use-case

  • Quantization’s goal is to increase inference speed. (In contrast, as we’ll see in the section on Mixed Precision Training, Automatic Mixed Precision (AMP)’s main goal is to reduce training time.)

Knowledge Distillation

  • Knowledge distillation is a model compression technique in which a smaller, lightweight student model is trained to replicate the behavior of a larger, typically pre-trained teacher model. Introduced in Distilling the Knowledge in a Neural Network by Hinton et al. (2015), this approach allows for the deployment of computationally efficient models that maintain much of the predictive power of their larger counterparts.

  • The central premise is to transfer the “dark knowledge” captured in the output distributions (logits) of the teacher model to the student model. These softened outputs provide richer learning signals than traditional one-hot labels, capturing inter-class similarities and uncertainty.

  • By learning from soft labels or intermediate representations, student models can inherit the generalization ability of much larger teachers—often improving latency, storage, and robustness in real-world deployment. But success hinges on careful calibration of student capacity, distillation type, and supervision strategies.

  • The image below (source) illustrates the concept of Knowledge Distillation:

Mechanism

  • In standard supervised learning, models are trained to predict hard, one-hot labels. In contrast, knowledge distillation augments this with soft targets—probability distributions produced by the teacher. These targets encode relative likelihoods across all classes, providing richer supervisory signals.

  • The student is trained using a composite loss function that combines:

    • Soft target loss: Kullback-Leibler (KL) divergence between the teacher’s and student’s softened output distributions.
    • Hard target loss: Standard cross-entropy loss against ground-truth labels.
  • The loss function is:

\[L = \alpha \cdot L_{\text{hard}} + (1 - \alpha) \cdot L_{\text{soft}}\]
  • where:

    • \(L_{\text{hard}}\) is the cross-entropy loss with hard/true labels.
    • \(L_{\text{soft}}\) is the KL divergence between the temperature-scaled softmax outputs:
    \[L_{\text{soft}} = \text{KL} \left( \text{Softmax}(z_{\text{teacher}} / T) \,\|\, \text{Softmax}(z_{\text{student}} / T) \right)\]
    • where \(T\) is the temperature parameter that controls the softness of the distribution. A higher \(T\) yields softer probabilities, revealing class similarities.
  • This dual-loss formulation helps the student generalize better by aligning both label fidelity and model semantics.
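  • A minimal PyTorch sketch of this composite loss (assumptions: classification logits; \(\alpha\) and \(T\) are hyperparameters, and the \(T^2\) factor is the usual convention for keeping the soft-loss gradients on the same scale as the hard loss):

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, T=4.0):
    # Hard-target term: standard cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    # Soft-target term: KL divergence between temperature-scaled distributions.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction='batchmean',
    ) * (T * T)
    return alpha * hard + (1 - alpha) * soft

# Toy usage with random logits standing in for real teacher/student outputs.
student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()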

Types of Knowledge Distillation

Response-Based Distillation

  • The most common and classical form.
  • The student is trained to match the final output probabilities of the teacher model.
  • Computationally simple and widely adopted.
  • Used in compact models such as DistilBERT and TinyBERT.

Feature-Based Distillation

  • The student mimics internal hidden states or feature representations of the teacher.
  • Often involves aligning intermediate activations across corresponding layers.
  • Useful in vision tasks where spatial features are important.
  • Examples include FitNets: Hints for Thin Deep Nets by Romero et al. (2015).

Relation-Based Distillation

  • Focuses on matching relationships (e.g., distance, similarity, or angle) between samples in feature space across models.
  • Encourages the student to learn the structural knowledge encoded in the teacher’s representation space.
  • Often used in metric learning and ranking tasks.
  • Example: Relational Knowledge Distillation by Park et al. (2019).

Distillation Modes

Offline Distillation

  • The teacher is pre-trained and fixed.
  • The student is trained using the frozen teacher’s outputs.
  • This is the most common paradigm in industry.

Online Distillation

  • Teacher and student are trained simultaneously.
  • The teacher may itself be evolving (e.g., ensemble of students).
  • Allows for dynamic refinement of knowledge but adds training complexity.

Self-Distillation

  • The teacher and student share the same architecture.
  • The teacher is typically an earlier version of the student (e.g., exponential moving average of weights).
  • Demonstrated to improve performance even without model compression.

Why Use Knowledge Distillation Instead of Training Small Models from Scratch?

  1. Richer Supervision leading to Enhanced Generalization: The soft labels/targets from the teacher offer added supervisory signal by encoding subtle inter-class similarities (e.g., cat vs. tiger) and uncertainties that hard labels miss. This can guide the student to generalize better.
  2. Data Augmentation Effect: The additional information in soft labels effectively augments the supervision signal without needing more data.
  3. Performance Boost: Student models trained via distillation often outperform the same architecture trained directly on hard labels.
  4. Compression with Retention: Distillation enables substantial reduction in model size and latency, with minimal loss in accuracy.
  5. Regularization Effect: Soft labels (and the resulting dense supervision) lead to smoother gradients, which can act as a form of regularization, improving robustness.
  6. Data Efficiency: The student often requires fewer training epochs and can converge with less labeled data.
  7. Architecture Agnosticism: Students need not replicate the teacher’s structure, offering flexibility in design.
  8. Latency Reduction: Distilled students exhibit significant inference speedups, sometimes halving latency.

Why Knowledge Distillation Works

  • Soft targets: Soft targets offer weak but informative signals, guiding the student toward nuanced generalizations. Suppose a large teacher model is trained on CIFAR-100 and has convolutional filters that respond to features like pointy ears. When shown a Batman mask labeled “mask,” the model might still activate its cat filters slightly. This leads to a 0.1 probability for “cat.” Training the student on this soft distribution imparts a weak but useful signal that improves its understanding of cat features.

  • Ensemble effects: If an ensemble of models each captures different features—say, one detects pointy ears, another whiskers—distillation into a single student helps consolidate these distinct patterns, enhancing generalization.

  • Multiple views and theoretical foundations: Distillation behaves like weak supervision or multi-view learning. As explored in Towards Understanding Ensemble, Knowledge Distillation and Self-Distillation in Deep Learning by Allen-Zhu et al. (2023), distilled students can approximate ensemble behavior using soft targets.

Distillation in Practice

  • Knowledge distillation intersects with adversarial robustness, privacy preservation, and transfer learning.

  • Most widely used form is response-based distillation, but feature-based and relation-based variants are active research areas.

  • Implementation nuances—teacher-student architecture differences, scheduling, or layer alignment—often require trial and error.

  • In On the Efficacy of Knowledge Distillation, Cho and Hariharan (2019) found that large teacher models can hurt student performance if there’s a capacity mismatch. Bigger is not always better.

  • In Improved Knowledge Distillation via Teacher Assistant, Mirzadeh et al. (2019) emphasized that the gap in capacity between teacher and student should be moderate for best results.

  • Thus, a practical takeaway is to perform offline, response-based distillation using a slightly smaller student model for performance gains with minimal tuning.

  • Recent work such as Well-Read Students Learn Better by Turc et al. (2019) shows that Pre-trained Distillation (PD)—pretraining compact models before distillation—yields better results in NLP tasks. The recommended 3-step process is:

    1. Pre-train the compact model (student) on the same masked language modeling (MLM) objective used in BERT.
    2. Distill from a large, task-specific teacher model using response-based offline distillation. For example, if the downstream task is Natural Language Inference (NLI), use the teacher to produce logits for each class (entailment, contradiction, neutral), and minimize KL divergence with the student’s logits.
    3. Fine-tune the distilled student on the task-specific dataset, such as training on the CoNLL 2003 dataset for Named Entity Recognition (NER).
    • This procedure has the advantage of being architecture-agnostic, making it practical for real-world deployment where the student architecture may differ substantially from the teacher.
  • Pretrained distilled models are now widely accessible. In NLP, libraries such as Hugging Face provide compact and distilled versions like DistilBERT and TinyBERT, often with task-specific checkpoints. In computer vision, Facebook’s d2go and DeiT offer mobile-ready image classification models that were distilled from larger vision transformers.

  • Practitioners should consider leveraging these pretrained distilled models when seeking lower latency or deployment efficiency, especially when retraining from scratch is resource-prohibitive.

Reverse Distillation

  • In reverse distillation, a small model acts as the teacher and a larger model is the student.

  • Particularly useful in noisy datasets (e.g., CTR prediction) where large models overfit easily.

  • In Investigating the Impact of Model Width and Density on Generalization in Presence of Label Noise, Xue et al. (2024) demonstrate this technique in high-label-noise regimes. In particular, they show that using soft labels from a smaller model can regularize and stabilize large-model training under label noise.

  • Process:

    1. Train a small, clean model on a curated subset.
    2. Use it to produce soft labels for the noisy data.
    3. Train the large model using these regularized targets.

Weak Supervision via Distillation

  • Distillation enables semi-supervised learning. Here’s the procedure:

    1. Train a high-capacity teacher on a small labeled set.
    2. Use it to label a larger unlabeled dataset.
    3. Train a compact student model on the combined data.
  • This approach has been successfully used in real-world settings such as Lessons from building acoustic models with a million hours of speech by Parthasarathi and Strom (2019), which trained acoustic models with over a million hours of speech data.

Limitations and Challenges

  • Capacity Gap: On the Efficacy of Knowledge Distillation by Cho and Hariharan (2019) demonstrates that extremely large teacher models can be poor mentors if the student’s capacity is too low—it cannot mimic the teacher effectively. Put simply, if the student model is too weak, distillation may degrade performance. Early stopping of the teacher may yield less “over‑fitted” soft labels that are easier for the student to learn from.

  • Architecture Mismatch: Effectiveness can be reduced when the teacher and student architectures differ substantially.

  • Poor transfer on complex tasks: Results on datasets like ImageNet often trail behind simpler benchmarks unless carefully tuned.

  • Higher tuning cost: In contrast to quantization and pruning, distillation often requires more experimentation and task-specific adaptation.

Model Pruning

  • Model pruning is a technique for compressing and accelerating deep neural networks by eliminating redundant parameters—either individual weights or structured components such as entire neurons, filters, or layers—without significantly degrading model performance. This compression facilitates faster inference, lower memory usage, and often improved generalization when carefully applied.

Formal Definition

  • Let \(f(x; \theta)\) be a trained neural network model parameterized by \(\theta \in \mathbb{R}^n\). Pruning aims to construct a sparsified parameter vector \(\theta' \in \mathbb{R}^n\) such that:

    \[\theta'_i = \begin{cases} 0, & \text{if } i \in P \\ \theta_i, & \text{otherwise} \end{cases} \quad \text{where } P \subset \{1, \ldots, n\}\]
    • and the pruned model \(f(x; \theta')\) satisfies:
    \[L(f(x; \theta')) \approx L(f(x; \theta))\]
    • where \(L\) denotes the loss function over the task of interest.

Rationale and Theoretical Motivation

“A dense, randomly-initialized neural network contains a subnetwork that, when trained in isolation, can match the performance of the original network.”

  • This hypothesis, known as the Lottery Ticket Hypothesis (Frankle and Carbin, 2019), supports the idea that training a large model enables discovery of a performant sparse subnetwork, which can be extracted via pruning.

Types of Pruning

Unstructured Pruning

  • Unstructured pruning eliminates individual weights regardless of their position in the weight tensor. It is formally defined by selecting a mask \(M \in \{0, 1\}^n\) applied element-wise to the parameter vector \(\theta\), resulting in a sparse model:

    \[\theta' = M \odot \theta\]
    • where \(\odot\) denotes the Hadamard product. Criteria for zeroing elements commonly include:

      • Magnitude-based pruning (e.g., removing the weights with the smallest magnitudes \(|w_i|\)),
      • Gradient-based importance metrics,
      • Second-order methods (e.g., Hessian-based sensitivity).
  • Although unstructured pruning achieves high sparsity, it often provides limited inference acceleration on conventional hardware due to irregular memory access patterns.

Implementation:

  • PyTorch: torch.nn.utils.prune supports various unstructured pruning strategies.
  • TensorFlow: tensorflow_model_optimization.sparsity.keras allows weight pruning during training.

  • Example:
import torch.nn as nn
import torch.nn.utils.prune as prune
module = nn.Linear(512, 256)  # any layer with a 'weight' parameter
prune.l1_unstructured(module, name='weight', amount=0.8)  # zero out the 80% of weights with smallest |w|

Structured Pruning

  • Structured pruning removes entire structural components of the model, such as:

    • Filters in convolutional layers,
    • Neurons in fully-connected layers,
    • Heads in transformers,
    • Layers or blocks in residual networks.
  • This results in a reduced model size and faster inference due to preservation of dense matrix formats.

  • Common importance metrics include:

    • L1/L2 norm of filters, proposed in Han et al. (2015),
    • Average activation magnitude,
    • Taylor approximation of loss change,
    • Shapley values.
  • Implementation:

    • TorchPruner (repo) automates structured pruning for linear and convolutional layers.
    • Torch-Pruning (repo) provides advanced pruning methods and dependency tracking across layers.

Pruning Workflow

  • A typical pruning pipeline consists of the following stages:

Step 1: Train the Full Model

  • Train the original model to convergence using standard procedures.

Step 2: Apply Pruning Mask

  • Determine and apply pruning masks using one of the supported strategies. This can occur:

    • Post-training: Prune after the model is fully trained.
    • During training: Gradually prune weights over several epochs.

Step 3: Fine-Tune the Pruned Model

  • Retrain the pruned model to recover lost accuracy. An effective method is learning rate rewinding, where the surviving weights are retrained using the learning-rate schedule rewound to an earlier point in the original training run.

  • Alternatively, weight rewinding may be used to reset weights of surviving parameters to their earlier values (e.g., 1/3 of training completed).

Practical Considerations

Target Sparsity

  • Target sparsity (e.g., 80%) must be tuned experimentally. Aggressive sparsity often requires multiple pruning–fine-tuning cycles.

Compatibility

  • Pruning can be difficult for architectures with:

    • Skip connections (e.g., ResNets),
    • Attention modules with tight dimensional constraints.
  • Custom pruning logic may be required.

Deployment Readiness

  • Sparse inference is not universally supported. Quantization-aware or hardware-specific pruning (e.g., \(N:M\) sparsity) may be necessary for real-world acceleration. As noted in What is the State of Neural Network Pruning? by Blalock et al. (2020), real benefits often depend on framework and hardware support.

Comparative Analysis

| Aspect | Unstructured Pruning | Structured Pruning |
| --- | --- | --- |
| Targets | Individual weights | Filters, neurons, heads, layers |
| Benefits | Compression | Compression + inference acceleration |
| Implementation Ease | High | Moderate to low |
| Framework Support | TensorFlow, PyTorch | TorchPruner, Torch-Pruning |
| Inference Speedup | Limited | Significant |

Implementing Pruning in PyTorch and TensorFlow

  • Modern deep learning frameworks provide built-in utilities to simplify pruning workflows. Below, we describe practical approaches in both PyTorch and TensorFlow.

PyTorch Pruning

  • PyTorch offers flexible pruning utilities via the torch.nn.utils.prune module. This supports both unstructured and structured pruning.

  • Unstructured Pruning Example (L1-based weight pruning):

import torch
import torch.nn.utils.prune as prune
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

# Prune 50% of weights in the first Linear layer
prune.l1_unstructured(model[0], name='weight', amount=0.5)

# To remove the pruning reparameterization and finalize pruning
prune.remove(model[0], 'weight')
  • Structured Pruning Example (entire neuron/channel pruning):
# Prune 30% of output neurons (rows of the weight matrix) in the second Linear layer, by L2 norm
prune.ln_structured(model[2], name='weight', amount=0.3, n=2, dim=0)
prune.remove(model[2], 'weight')

TensorFlow Pruning

  • TensorFlow provides pruning support via the tensorflow_model_optimization toolkit. It supports sparsity-aware training by gradually zeroing out weights.

  • Basic Workflow:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Define prunable model
model = tf.keras.Sequential([
    prune_low_magnitude(tf.keras.layers.Dense(256, activation='relu'), 
                        input_shape=(512,)),
    prune_low_magnitude(tf.keras.layers.Dense(10))
])

# Compile and train; the UpdatePruningStep callback is required to advance the pruning schedule
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(x_train, y_train, epochs=5,
          callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

# Strip pruning wrappers before saving/export
model = tfmot.sparsity.keras.strip_pruning(model)
  • Set Pruning Schedule (passed to prune_low_magnitude via its pruning_schedule argument):
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0,
    final_sparsity=0.8,
    begin_step=0,
    end_step=1000
)

Mixed Precision Training

Overview

  • Mixed precision is a technique for substantially reducing neural-network training time by performing as many operations as possible in half-precision floating point (float16) instead of the PyTorch-default single-precision floating point (float32). It therefore uses both 16-bit and 32-bit floating-point types during training to make training run faster and use less memory. By keeping certain parts of the model in 32-bit types for numeric stability, the model achieves a lower step time while training equally well in terms of evaluation metrics such as accuracy.
    • The term “numeric stability” refers to how a model’s quality is affected by the use of a lower-precision dtype instead of a higher precision dtype. An operation is “numerically unstable” in float16 or bfloat16 if running it in one of those dtypes causes the model to have worse evaluation accuracy or other metrics compared to running the operation in float32.
  • Today, most models use the float32 dtype, which takes 32 bits of memory. However, there are two lower-precision dtypes, float16 and bfloat16, each which take 16 bits of memory instead. Modern accelerators can run operations faster in the 16-bit dtypes, as they have specialized hardware to run 16-bit computations and 16-bit dtypes can be read from memory faster.
  • Put simply, the idea behind Automatic Mixed Precision (AMP) is that not all layers and operations require float32 precision, so lower precision can be used where it is safe. AMP automatically decides which precision to use for each operation, which ultimately speeds up training.
  • Mixed precision tries to match each op to its appropriate datatype, which as a by-product, can reduce your network’s runtime and memory footprint.

Under-the-hood

  • NVIDIA GPUs can run operations in float16 faster than in float32, and TPUs can run operations in bfloat16 faster than float32. Therefore, these lower-precision dtypes should be used whenever possible on those devices. However, variables and a few computations should still be in float32 for numeric reasons so that the model trains to the same quality (cf. numerical stability in the section on Mixed Precision Overview).
  • Recent generations of NVIDIA GPUs come loaded with special-purpose tensor cores specially designed for fast fp16 matrix operations. Thus, max performance gains are observed on Tensor Core-enabled GPU architectures as we’ll see below in the section on How Tensor Cores Work.
  • However, until recently these tensor cores were difficult to use, as exploiting them required writing reduced-precision operations into your model by hand. This is where the automatic in automatic mixed-precision training comes in. The torch.cuda.amp API allows you to implement mixed precision training in your training scripts in just five lines of code!

How mixed precision works

  • Before we understand how mixed precision training works, let’s review a little bit about floating point numbers.
  • In computer engineering, decimal numbers like 1.0151 or 566132.8 are traditionally represented as floating point numbers. Since we can have infinitely precise numbers (think \(\pi\)), but limited space in which to store them, we have to make a compromise between precision (the number of decimals we can include in a number before we have to start rounding it) and size (how many bits we use to store the number).
  • Building upon what we discussed in the Background: Precision, the technical standard for floating point numbers, IEEE 754 (for a deep dive please refer to the PyCon 2019 talk “Floats are Friends: making the most of IEEE754.00000000000000002”), sets the following standards:
    • fp64, aka double-precision or “double”, max rounding error of ~\(2^{-52}\).
    • fp32, aka single-precision or “single”, max rounding error of ~\(2^{-23}\).
    • fp16, aka half-precision or “half”, max rounding error of ~\(2^{-10}\).
  • Python uses fp64 for the float type. PyTorch, which is much more memory-sensitive, uses fp32 as its default dtype instead.
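  • To see these precision and range limits concretely, here is a quick check using torch.finfo (a minimal sketch; the exact printed values depend on your PyTorch build):

import torch

# Machine epsilon (max relative rounding error) and smallest normal value per dtype.
for dtype in (torch.float64, torch.float32, torch.float16):
    info = torch.finfo(dtype)
    print(dtype, "eps =", info.eps, "smallest normal =", info.tiny)

# Values far below fp16's representable range underflow (round) to zero.
print(torch.tensor(1e-8, dtype=torch.float16))  # tensor(0., dtype=torch.float16)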

The basic idea behind mixed precision training is simple: halve the precision (fp32 \(\rightarrow\) fp16), halve the training time.

  • The hard part is doing so safely.
  • Notice that the smaller the floating point format, the larger the rounding errors it incurs. Any operation performed on a “small enough” floating point number will round the value to zero! This is known as underflowing, and it’s a problem because many of the gradient update values created during backpropagation are extremely small but nevertheless non-zero. Rounding error accumulation during backpropagation can turn these numbers into zeroes or NaNs; this creates inaccurate gradient updates and prevents your network from converging.
  • The 2018 ICLR paper Mixed Precision Training found that naively using fp16 everywhere “swallows” gradient updates smaller than \(2^{-24}\) in value — around 5% of all gradient updates made by their example network.

  • Mixed precision training is a set of techniques which allows you to use fp16 without causing your model training to diverge. It’s a combination of three different techniques.
    1. Maintain two copies of the weights matrix, a “master copy” in fp32, and a half-precision copy of it in fp16. Gradient updates are calculated using the fp16 matrix but applied to the fp32 matrix. This makes applying the gradient update much safer.
    2. Different vector operations accumulate errors at different rates, so treat them differently. Some operations are always safe in fp16, but others are only reliable in fp32. Instead of running the entire neural network in fp16, run some parts in halves and others in singles. This mixture of dtypes is why this technique is called “mixed precision”.
    3. Use loss/gradient scaling. Loss scaling means multiplying the output of the loss function by some scalar number (the paper suggests starting with 8) before performing back-propagation. Multiplicative increases in the loss values create multiplicative increases in gradient update values, “lifting” many gradient update values above the \(2^{-24}\) threshold for fp16 safety. Just make sure to undo the loss scaling before applying the gradient update, and don’t pick a loss scaling so large that it produces inf weight updates (overflowing), causing the network to diverge in the other direction.
  • Combining these three techniques allowed the authors to train a variety of networks to convergence in significantly less time. For benchmarks, please refer to the paper.
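  • As a rough illustration of technique (3) in isolation, the sketch below shows what manual loss scaling looks like with a fixed, hypothetical scale factor and a toy model; it only demonstrates the mechanics (the gradient math), not actual fp16 execution. The torch.cuda.amp.GradScaler object discussed below automates this and additionally checks for overflow:

import torch
import torch.nn as nn

# Minimal sketch of manual loss scaling with a fixed, hypothetical scale factor.
model = nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
x, y = torch.randn(8, 16), torch.randn(8, 1)
loss_scale = 1024.0

optimizer.zero_grad()
loss = loss_fn(model(x), y)
(loss * loss_scale).backward()          # scaled gradients are less likely to underflow in fp16

for p in model.parameters():            # undo the scaling before the update...
    if p.grad is not None:
        p.grad.div_(loss_scale)

optimizer.step()                        # ...so the scale factor doesn't alter the effective learning rate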

How tensor cores work

  • While mixed precision training saves memory everywhere (an fp16 matrix is half the size of a fp32 one), it doesn’t provide a model training speedup without special GPU support. There needs to be something on the chip that accelerates half-precision operations. In recent generations of NVIDIA GPUs, there is: tensor cores.
  • Tensor cores are a new type of processing unit that’s optimized for a single very specific operation: multiplying two \(4 \times 4\) fp16 matrices together and adding the result to a third \(4 \times 4\) fp16 or fp32 matrix (a “fused multiply add”).

  • Larger fp16 matrix multiplication operations can be implemented using this operation as their basic building block. And since most of backpropagation boils down to matrix multiplication, tensor cores are applicable to almost any computationally intensive layer in the network.

The catch: the input matrices must be in fp16. If you’re training on a GPU with tensor cores and not using mixed precision training, you’re not getting 100% out of your GPU! A standard PyTorch model defined in fp32 will never land any fp16 math onto the chip, so all of those fp16 cores will remain idle.

  • Tensor cores were introduced in late 2017 in the last-gen Volta architecture, saw improvement in current-gen Turing, and will see further refinements in the still-forthcoming Ampere. The two GPUs generally available on the cloud that support them are the V100 (5120 CUDA cores, 640 tensor cores) and the T4 (2560 CUDA cores, 320 tensor cores).
  • One other piece of the puzzle worth keeping in mind is the software stack. Although tensor core operations have been supported since CUDA 9.0 / cuDNN 7.0, early implementations are reputedly very buggy, so it’s important to be on CUDA 10.0 or higher.
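  • As a quick sanity check (a sketch; assumes a CUDA-capable machine), you can confirm from PyTorch whether your GPU and CUDA build meet these requirements:

import torch

# Volta (compute capability 7.0) and newer ship fp16 tensor cores.
major, minor = torch.cuda.get_device_capability(0)
print(f"GPU: {torch.cuda.get_device_name(0)} (compute capability {major}.{minor})")
print(f"CUDA version of this PyTorch build: {torch.version.cuda}")
print("fp16 tensor cores available:", (major, minor) >= (7, 0))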

How PyTorch automatic mixed precision works

  • With that important background out of the way, we’re finally ready to dig into the new PyTorch amp API.

  • Mixed precision training has technically been possible forever: run sections of your network in fp16 manually and implement loss scaling yourself. The exciting thing about automatic mixed-precision training is the “automatic” part. There are just a couple of new API primitives to learn: torch.cuda.amp.GradScaler and torch.cuda.amp.autocast. Enabling mixed precision training is as simple as slotting these into the right places in your training script!

  • To demonstrate, here’s an excerpt of the training loop for a network using mixed-precision training. # NEW marks spots where new code got added.

self.train()
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

optimizer = torch.optim.Adam(self.parameters(), lr=self.max_lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, self.max_lr,
    cycle_momentum=False,
    epochs=self.n_epochs,
    steps_per_epoch=int(np.ceil(len(X) / self.batch_size)),
)
batches = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X, y),
    batch_size=self.batch_size, shuffle=True
)

# NEW: gradient scaler that applies/undoes loss scaling and skips steps on overflow
scaler = torch.cuda.amp.GradScaler()

for epoch in range(self.n_epochs):
    for i, (X_batch, y_batch) in enumerate(batches):
        X_batch = X_batch.cuda()
        y_batch = y_batch.cuda()
        optimizer.zero_grad()

        # NEW: run the forward pass and loss computation under autocast (mixed fp16/fp32)
        with torch.cuda.amp.autocast():
            y_pred = model(X_batch).squeeze()
            loss = self.loss_fn(y_pred, y_batch)

        # NEW: scale the loss so small gradients don't underflow in fp16
        scaler.scale(loss).backward()
        lv = loss.detach().cpu().numpy()
        if i % 100 == 0:
            print(f"Epoch {epoch + 1}/{self.n_epochs}; Batch {i}; Loss {lv}")

        # NEW: step the optimizer via the scaler (skips the step on inf/NaN), then update the scale
        scaler.step(optimizer)
        scaler.update()
        
        scheduler.step()

Loss/Gradient Scaling

  • If the forward pass for a particular op has float16 inputs, the backward pass for that op will produce float16 gradients. Gradient values with small magnitudes may not be representable in float16. These values will flush to zero (“underflow”), so the update for the corresponding parameters will be lost.
  • To prevent underflow, “gradient scaling” multiplies the network’s loss(es) by a scale factor and invokes a backward pass on the scaled loss(es). Gradients flowing backward through the network are then scaled by the same factor. In other words, gradient values have a larger magnitude, so they don’t flush to zero.
  • The new PyTorch GradScaler object is PyTorch’s implementation of loss scaling. Recall from the section “How mixed precision works” that some form of loss scaling is necessary to keep gradients from rounding down to 0 during training. The optimal loss multiplier is one sufficiently high to retain very small gradients, but not so high that it causes very large gradients to round up to inf, creating the opposite problem.

  • However, there is no one loss multiplier that will work for every network. The optimal multiplier is also very likely to change over time, as gradients are typically much larger at the start of training than at the end. How do you find the optimal loss multiplier without giving the user another hyperparameter that they have to tune?

  • PyTorch uses exponential backoff to solve this problem. GradScaler starts with a small loss multiplier, which it doubles every so often. This gradual doubling continues until GradScaler encounters a gradient update containing inf values. GradScaler then discards that batch (i.e., the gradient update is skipped), halves the loss multiplier, and resets its doubling cooldown.

  • Stepping the loss multiplier up and down in this way allows PyTorch to approximate the appropriate loss multiplier over time. Readers familiar with TCP congestion control should find the core ideas here very familiar! The exact numbers used by the algorithm are configurable, and you can read the defaults right out of the docstring:
torch.cuda.amp.GradScaler(
    init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5,
    growth_interval=2000, enabled=True
)
  • GradScaler needs to exert control over the gradient update calculations (to check for overflow) and over the optimizer (to turn discarded batches into a no-op) to implement its behavior. This is why loss.backward() is replaced with scaler.scale(loss).backward() and optimizer.step() is replaced with scaler.step(optimizer).

  • It’s notable that GradScaler will detect and stop overflows (because inf is always bad), but it has no way to detect and stop underflows (because 0 is often a legitimate value). If you pick an init_scale that’s too low and a growth_interval that’s too high, your network may underflow and diverge before GradScaler can intervene. For this reason it’s probably a good idea to pick a very large starting value, and with the default init_scale=65536 (\(2^{16}\)) that does seem to be the approach that PyTorch is following.

  • Finally, note that GradScaler is a stateful object. Checkpointing a model that uses this feature will require writing it to and reading it from disk alongside your model weights. This is easy to do using the state_dict and load_state_dict object methods (covered in the PyTorch docs).
  • As an implementation detail, note that in PyTorch, each parameter’s gradient (.grad attribute) should be unscaled before the optimizer updates the parameters, so the scale factor does not interfere with the learning rate.
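  • For example, if you want to clip gradients, the PyTorch AMP docs have you explicitly unscale them first so that clipping operates on gradients in their true units. A minimal sketch, continuing with the scaler, loss, model, and optimizer from the training loop above (max_norm=1.0 is just an illustrative value):

scaler.scale(loss).backward()

scaler.unscale_(optimizer)              # .grad attributes are now in unscaled units
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

scaler.step(optimizer)                  # already-unscaled gradients are not unscaled again
scaler.update()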

autocast context manager

  • The other half of the automatic mixed-precision training puzzle is the torch.cuda.amp.autocast context manager. Autocast implements the fp32 \(\rightarrow\) fp16 casting behavior. Recall from “How mixed precision works” that, because different operations accumulate errors at different rates, not all operations are safe to run in fp16. The amp module documentation lists how autocast treats each of the operations available in PyTorch:

  • This list predominantly consists of two things: matrix multiplications and convolutions. The simple linear function is also present. (A quick dtype check illustrating this behavior appears at the end of this section.)

  • These operations are safe in fp16, but have up-casting rules to ensure that they don’t break when given a mixture of fp16 and fp32 input. Note that this list includes two other fundamental linear algebraic operations: matrix/vector dot products and vector cross products.

  • Logarithms, exponents, trigonometric functions, normal functions, discrete functions, and (large) sums are unsafe in fp16 and must be performed in fp32.

  • Looking through the list, it seems to me that most layers would benefit from autocasting, thanks to their internal reliance on fundamental linear algebra operations, but most activation functions would not. Convolutional layers stand out as potentially the biggest winner.

  • Enabling autocasting is pretty simple. All you need to do is wrap the forward pass of your model using the autocast context manager:

with torch.cuda.amp.autocast():
    y_pred = model(X_batch).squeeze()
    loss = self.loss_fn(y_pred, y_batch)
  • Wrapping the forward pass in this way also covers the corresponding backward pass (i.e., loss.backward()) with the same dtype choices, so you don’t need to invoke autocast a second time.
  • So long as you follow best practices for using PyTorch (avoiding in-place operations, for example), autocasting basically “just works”.
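  • As a quick dtype check of the behavior described above (a minimal sketch; assumes a CUDA device, and the op categories reflect the autocast op reference at the time of writing):

import torch

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

with torch.cuda.amp.autocast():
    mm = a @ b                          # matmul autocasts down to float16
    sm = torch.softmax(mm, dim=-1)      # softmax autocasts up to float32 for stability

print(mm.dtype)  # torch.float16
print(sm.dtype)  # torch.float32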

Multiple GPUs

Mixed Precision with TensorFlow

Performance benchmarks

  • Let’s look at some real-world performance benchmarks over three very different neural networks with and without automatic mixed precision. The training setup involved V100s (last-gen tensor cores) and T4s (current-gen tensor cores), using the Spell API on AWS EC2 instances, p3.2×large and g4dn.xlarge respectively, and a recent PyTorch build with CUDA 10.0.
  • All of the models converged equally, i.e., none of the models saw any difference in training loss between the mixed precision and vanilla runs. The networks trained were a small feedforward network, a medium-sized convolutional model (UNet), and a large transformer (BERT).
  • The results (training times on the V100 and T4, with and without automatic mixed precision) are discussed below:

  • Observations from the results:
  • Because the feedforward network is very small, it gets no benefit from mixed precision training.
  • UNet, a medium-sized convolutional model with 7,703,497 total parameters, sees significant benefits from enabling mixed precision training. Interestingly, though the V100 and T4 both benefit from mixed precision training, the benefit to the T4 is much greater: a 5% time save versus a whopping 30% time save.
  • BERT is a large model, and it’s where the time savings of using mixed precision training go from “nice” to “must-have”. Automatic mixed precision will cut training time for large models trained on Volta or Turing GPUs by 50 to 60 percent! 🔥
  • This is a huge, huge benefit, especially when you take into account the minimal complexity required — just four or five LOC to your model training script.

Based on the aforementioned training-time uplifts, mixed precision should be one of the first performance optimizations you make to your model training scripts.

What about memory?

  • As explained in the section above on How mixed precision works, a fp16 matrix is half the size of a fp32 matrix in memory, so another purported advantage of mixed precision training is memory usage. GPU memory is much less of a bottleneck than GPU compute, but it’s still pretty valuable to optimize. The more efficient your memory usage, the larger the batch sizes you can fit on the GPU.
  • PyTorch reserves a certain amount of GPU memory at the beginning of the model training process and holds onto that memory for the duration of the training job. This keeps other processes from claiming that memory mid-training, which would otherwise force the PyTorch training script to crash with an OOM error.

  • Here is the impact that enabling mixed precision training has on the PyTorch memory reservation behavior:

  • Interestingly enough, while both of the larger models saw benefit from the swap to mixed precision, UNet benefited from the swap a lot more than BERT did. PyTorch memory allocation behavior is pretty opaque to me, so I have no insight into why this might be the case.

Conclusion

  • Automatic mixed precision training is an easy-to-use and powerful new feature which promises to speed up larger-scale model training jobs running on recent NVIDIA GPUs by up to 60%.
  • While this technique has been around for a while (see, e.g., Chip Huyen’s notes on scaling), it’s not been very accessible to the average user because, until now, it has never had a native PyTorch API.
  • To learn more about mixed precision training directly from the source, see the automatic mixed precision package and automatic mixed precision examples pages in the PyTorch master docs.

Key takeaways

  • The torch.cuda.amp mixed-precision training module delivers on its promise, yielding speed-ups of 50–60% in large model training jobs with just a handful of new lines of code.

Use-case

  • Automatic Mixed Precision (AMP)’s main goal is to reduce training time. (In contrast, as we saw in the section on Quantization, quantization’s main goal is to reduce inference time.)
  • Refer to the PyTorch Automatic Mixed Precision recipe for hands-on usage.

Low-Rank Decomposition & Adaptation

Overview

  • Large language models (LLMs) and deep neural networks often rely on massive weight matrices that are expensive to store and compute. However, in many cases, these matrices exhibit redundancy—meaning they can be approximated by lower-rank structures without significantly compromising the model’s predictive performance. Low-rank decomposition exploits this insight by factorizing large matrices into the product of two smaller matrices.

  • The motivation is twofold:

    1. Efficiency: Reducing the number of parameters and operations accelerates training and inference.
    2. Compression: Enables model deployment on memory-constrained devices by lowering storage requirements.
  • This technique is especially attractive when used with techniques like quantization, pruning, or transfer learning, as it complements them without requiring extensive retraining.

Formal Definition

  • The core idea: replace a large dense weight matrix

    \[W \in \mathbb{R}^{d \times d}\]
    • with two smaller matrices:
    \[A \in \mathbb{R}^{d \times r}, \quad B \in \mathbb{R}^{r \times d}\]
    • where \(r \ll d\). The original weight matrix is then approximated as:
    \[W \approx A B\]
  • This reduces the parameter count from \(O(d^2)\) to \(O(rd)\), which is a significant gain when \(r \ll d\). During inference or training, instead of computing:

    \[y = Wx\]
    • we compute:
    \[y = A (B x)\]
  • This allows models to retain much of their representational power while becoming faster and more compact.
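  • As a concrete (toy) illustration of this factorization, the sketch below uses a truncated SVD to build \(A\) and \(B\) for a hypothetical \(1024 \times 1024\) weight matrix. A random matrix is not actually low-rank, so the point here is the mechanics and the parameter savings, not the approximation quality:

import torch

d, r = 1024, 64
W = torch.randn(d, d)                       # stand-in for a dense weight matrix

# Truncated SVD: keep the top-r singular triplets and fold the singular values into A.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
A = U[:, :r] * S[:r]                        # shape (d, r)
B = Vh[:r, :]                               # shape (r, d)

x = torch.randn(d)
y_full = W @ x                              # O(d^2) multiply-adds
y_lowrank = A @ (B @ x)                     # O(rd) multiply-adds

print("params:", W.numel(), "vs", A.numel() + B.numel())   # d^2 vs 2rd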

Concept

  • Freeze the original pretrained weights; add trainable adapters \(A, B\) for each dense layer such that the effective weight becomes \(W' = W + AB\), with only \(A\) and \(B\) updated during fine-tuning (a minimal sketch appears after this list).

  • Efficiency: Only a small number of parameters are trained (e.g., \(2dr\) per layer). For large LLMs, this can reduce the number of fine-tuned parameters by \(10^3\)–\(10^4\times\) while achieving competitive results.
  • Popular in federated settings (see above).
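  • The minimal sketch below wires such adapters around a frozen nn.Linear. Layer sizes, rank, and the alpha/r scaling convention are illustrative; real implementations such as Hugging Face peft differ in details:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """y = W x + (alpha / r) * B(A(x)), with the pretrained W frozen."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False                 # freeze pretrained weights
        self.A = nn.Linear(base.in_features, r, bias=False)
        self.B = nn.Linear(r, base.out_features, bias=False)
        nn.init.normal_(self.A.weight, std=0.01)
        nn.init.zeros_(self.B.weight)               # adapter starts as a no-op
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * self.B(self.A(x))

layer = LoRALinear(nn.Linear(1024, 1024), r=8)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)                                    # 2 * 1024 * 8 = 16384 adapter parameters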

Low-Rank Correction for Quantization

  • Counteract quantization-induced error in activation domains. We approximate the full-precision weight matrix:

    \[W \in \mathbb{R}^{d \times d}\]
    • as the product of two low-rank matrices:
    \[W \approx A B, \quad \text{where} \quad A \in \mathbb{R}^{d \times r}, \quad B \in \mathbb{R}^{r \times d}, \quad \text{and } r \ll d\]
    • This reduces the parameter count from \(O(d^2)\) to \(O(rd)\), significantly lowering storage and compute cost.

    • In quantization-aware settings, this is often extended to:

    \[W \approx Q + A B\]
    • where:

      • \(Q\) is the quantized version of \(W\), e.g., INT4 or INT8
      • \(A B\) is a low-rank full-precision correction term on unquantized activations
      • \[A \in \mathbb{R}^{d \times r}, B \in \mathbb{R}^{r \times d}\]
    • This hybrid scheme (quantized base + low-rank residual) allows retaining much of the accuracy of full-precision models while gaining the memory and speed benefits of quantization.

  • The decomposition is typically solved via joint optimization: alternating minimization fits the quantized and low-rank components so as to minimize the output reconstruction error (a simplified sketch follows below). With ranks at 10% of the weight size, activation error gaps can be halved; with 30% rank, they can be closed completely.
  • Fits well with post-training quantization pipelines. Works across calibration sets without full retraining.
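  • The sketch below is a simplified version of this idea: it quantizes \(W\), then fits the low-rank correction to the weight residual \(W - Q\) with a single truncated SVD. The methods referenced above instead minimize activation/output reconstruction error via alternating minimization, but the structure \(W \approx Q + AB\) is the same:

import torch

def symmetric_int8(W):
    """Per-tensor symmetric int8 fake-quantization; returns the dequantized matrix Q."""
    scale = W.abs().max() / 127.0
    return torch.clamp((W / scale).round(), -127, 127) * scale

d, r = 512, 64                              # hypothetical sizes (r is ~12% of d)
W = torch.randn(d, d)
Q = symmetric_int8(W)

# Fit the low-rank correction to the quantization residual: W - Q ≈ A @ B.
U, S, Vh = torch.linalg.svd(W - Q, full_matrices=False)
A = U[:, :r] * S[:r]
B = Vh[:r, :]

print("||W - Q||_F           =", (W - Q).norm().item())
print("||W - (Q + A @ B)||_F =", (W - (Q + A @ B)).norm().item())  # smaller after correction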

Quantized Low-Rank Adaptation Techniques

  • Low-Rank Adaptation (LoRA) techniques allow efficient fine-tuning of large-scale pre-trained models by updating only a small number of additional low-rank matrices while freezing the original weights. Several recent advancements extend this idea by incorporating quantization, yielding memory-efficient and compute-friendly training pipelines. Below, we explore three major variants: LQ-LoRA, QLoRA, and QA-LoRA.

LQ‑LoRA: Quantized + Low-Rank Adaptation

  • LQ-LoRA was introduced in Guo et al., 2023 as a memory-efficient fine-tuning approach that combines low-bit quantization with learnable low-rank adapters.
Overview
  • Each full-precision weight matrix \(W \in \mathbb{R}^{d \times d}\) is decomposed into:
\[W \approx Q + AB\]
  • \(Q \in \mathbb{R}^{d \times d}\): a low-bit quantized matrix (e.g., 2.75 bits), kept frozen
  • \(A \in \mathbb{R}^{d \times r}\), \(B \in \mathbb{R}^{r \times d}\): low-rank full-precision matrices, learnable
Key Properties
  • The quantized base matrix \(Q\) captures the main representational power of the original model, without requiring updates during fine-tuning.
  • The low-rank matrices \(A\) and \(B\) adapt the model to the downstream task.
  • Enables sub-3-bit quantization without major degradation in task performance.
  • Requires only ~27 GB of GPU memory to fine-tune a Llama 2 70B model, enabling large model training on commodity hardware.
Pros
  • High compression without sacrificing much accuracy.
  • Applicable to ultra-large models (e.g., Llama 2 70B).
  • Training only the LoRA adapters ensures stability even at low precision.
Cons
  • Fixed quantized base may limit adaptability in highly domain-shifted settings.
  • Quantization granularity and calibration are critical for performance.

QLoRA: Quantized LoRA with 4-bit Base Model

  • QLoRA, proposed in Dettmers et al., 2023, is an efficient fine-tuning method using a 4-bit quantized model backbone and LoRA adapters.
Overview
  • Applies 4-bit NormalFloat (NF4) quantization to the pretrained weights.
  • Performs double quantization to reduce memory further.
  • Freezes the quantized weights and trains LoRA adapters over them.
  • Uses paged optimizers and activation checkpointing for memory efficiency.
Architecture
\[W_{\text{finetuned}} = \text{Quantize}_{\text{NF4}}(W_{\text{pretrained}}) + \Delta W_{\text{LoRA}}\]
  • Only \(\Delta W_{\text{LoRA}} = AB\) is trained.
  • Entire fine-tuning can be performed in under 48 GB of GPU memory for models like LLaMA‑65B (see the sketch at the end of this subsection).
Pros
  • Fully open-source and hardware-efficient.
  • Well-established tools in the ecosystem (e.g., Hugging Face peft and bitsandbytes).
  • High accuracy retention even with 4-bit quantization.
Cons
  • Limited to NF4 quantization scheme.
  • No learnability in the quantized weights themselves.
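  • A rough sketch of the QLoRA-style setup described above, using the Hugging Face transformers, peft, and bitsandbytes stack. The checkpoint name, rank, and target modules are placeholders, and the parameter names follow those libraries' documentation at the time of writing:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # 4-bit base weights
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,         # double quantization of the quant constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute dtype for matmuls
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",             # placeholder checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],    # which projections receive adapters
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()          # only the LoRA adapters are trainable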

QA-LoRA: Quantization-Aware LoRA

  • QA-LoRA, proposed in Zhang et al., 2023, adds quantization awareness to LoRA training by simulating quantization noise during fine-tuning.
Overview
  • Quantization-aware noise is injected into both the pretrained weights and the LoRA adapters during training.
  • This simulates inference-time quantization effects during training, allowing the adapters to compensate more effectively.
Architecture
\[W_{\text{finetuned}} = \text{QuantNoise}(W) + AB\]
  • \(\text{QuantNoise}(W)\): Simulates quantization-induced errors on the frozen weights.
  • \(AB\): LoRA component trained with awareness of quantization (a hypothetical sketch of this forward pass appears at the end of this subsection).
Key Features
  • Supports ultra-low-bit quantization (e.g., 3- or 4-bit).
  • Enables quantization-aware training (QAT) without modifying the original weight update path.
Pros
  • Improves robustness of LoRA adapters to quantization errors.
  • Achieves lower perplexity and better accuracy than standard LoRA or QLoRA in low-bit settings.
Cons
  • Adds complexity to training pipeline.
  • May require careful tuning of noise injection parameters.
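  • The sketch below is one hypothetical way to realize the \(\text{QuantNoise}(W) + AB\) forward pass: the frozen weight is fake-quantized on the fly (here, per-tensor 4-bit round-to-nearest) so the adapters learn to compensate for the quantization error. The published method's exact quantization operators and grouping may differ:

import torch
import torch.nn as nn
import torch.nn.functional as F

def fake_quant(W, n_bits=4):
    """Simulate round-to-nearest quantization noise (per-tensor, symmetric)."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = W.abs().max() / qmax
    return torch.clamp((W / scale).round(), -qmax, qmax) * scale

class QuantNoiseLoRALinear(nn.Module):
    """Hypothetical layer: y = QuantNoise(W) x + B(A(x)), with W frozen."""
    def __init__(self, base: nn.Linear, r: int = 8, n_bits: int = 4):
        super().__init__()
        self.weight = nn.Parameter(base.weight.detach(), requires_grad=False)
        self.bias = nn.Parameter(base.bias.detach(), requires_grad=False) if base.bias is not None else None
        self.A = nn.Linear(base.in_features, r, bias=False)
        self.B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.B.weight)                # adapters start as a no-op
        self.n_bits = n_bits

    def forward(self, x):
        Wq = fake_quant(self.weight, self.n_bits)    # inference-time quantization, simulated during training
        return F.linear(x, Wq, self.bias) + self.B(self.A(x))

layer = QuantNoiseLoRALinear(nn.Linear(256, 256))
print(layer(torch.randn(4, 256)).shape)              # torch.Size([4, 256])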

Comparative Analysis

  • This taxonomy of LoRA variants shows a clear evolution: from memory-focused quantized adapters (QLoRA) to ultra-low-bit efficient models (LQ-LoRA), to quantization-aware robust fine-tuning (QA-LoRA). Choice of method depends on the target compression level, hardware constraints, and sensitivity to quantization-induced artifacts.
| Feature | LQ‑LoRA | QLoRA | QA‑LoRA |
| --- | --- | --- | --- |
| Quantization Level | ≤ 3-bit (e.g., 2.75) | 4-bit NF4 | 3–4-bit |
| Trainable Params | Low-rank adapters only | LoRA adapters only | LoRA adapters with QAT |
| Quantized Weights | Frozen, used as base | Frozen, used as base | Frozen + perturbed during training |
| Noise Handling | None | None | Simulated quantization noise |
| Memory Efficiency | ~27 GB for Llama 2 70B | <48 GB for LLaMA 65B | Similar to QLoRA |
| Complexity | Medium | Low | High |
| Best Use Case | Ultra-compressed deployment | General-purpose fine-tuning | Robustness under low-bit QAT |

Pros & Cons

  • Pros:

    • Parameter-efficient fine-tuning: Minimal new parameters needed (e.g., LoRA, QLoRA, QA-LoRA).
    • Quantization synergy: Works well with 4-bit quantization (QLoRA) or ultra-low-bit regimes (LQ-LoRA).
    • Quantization-aware robustness: QA-LoRA improves low-bit model accuracy via simulated noise.
    • Adaptable to distributed settings: LoRA-based updates are lightweight and communication-efficient.
    • Accuracy retention: Strong accuracy, even under aggressive quantization.
  • Cons:

    • Effectiveness may degrade if base weights are not sufficiently low-rank (task-dependent).
    • Combining quantization and adaptation (as in QA-LoRA or LQ-LoRA) introduces training complexity.
    • Requires careful tuning of rank, quantization scheme, and noise injection (QA-LoRA).
    • QLoRA assumes compatibility with NF4 quantization and specific tooling (bitsandbytes, Hugging Face PEFT).

Comparison & Use Cases

| Use Case | Suggested Strategy | Benefit |
| --- | --- | --- |
| Parameter-efficient tuning | LoRA / QLoRA / ALoRA | Reduces compute/memory footprint |
| Extreme quantization | LQ‑LoRA or Low-Rank Correction | Sub‑3‑bit performance retention |
| Federated fine‑tuning | Federated LoRA / QLoRA adapters | Minimal communication cost |
| Quantization-aware training | QA‑LoRA + QAT | High fidelity under 3–4 bit settings |
| 4-bit memory-efficient finetuning | QLoRA | Near full accuracy with 4-bit NF4 |
| Robust training under quantization noise | QA‑LoRA | Noise-aware adapters improve generalization |

Key Takeaways

  • Low-rank techniques improve the efficiency of training and inference by decomposing full-rank weight matrices into two smaller matrices (e.g., \(A \in \mathbb{R}^{n \times r} \text{ and } B \in \mathbb{R}^{r \times m}, \text{ where } r \ll n, m\)). These factorizations can replace or augment full-weight updates, as seen in LoRA and its variants.

    • QLoRA combines low-rank adapters with 4-bit NF4 quantized base models, drastically reducing memory usage during training without sacrificing accuracy. It enables finetuning large LLMs (e.g., 65B+) on consumer hardware.
    • QA-LoRA extends this further by injecting simulated quantization noise into the training process, making LoRA adapters inherently robust to downstream quantization.
    • LQ-LoRA targets extremely low-bit regimes (e.g., sub-3-bit) by jointly optimizing low-rank corrections and quantized base weights.
  • Overall, low-rank decomposition plays a central role in enabling quantization-aware fine-tuning pipelines, federated adaptation, and cross-device deployment, all while maintaining high performance and parameter efficiency.

Further Reading

References

Citation

If you found our work useful, please cite it as:

@article{Chadha2020DistilledModelCompression,
  title   = {Model Compression},
  author  = {Chadha, Aman},
  journal = {Distilled AI},
  year    = {2020},
  note    = {\url{https://aman.ai}}
}