• This post covers model inference optimization and compression techniques, including quantization and binarization, pruning, more research-oriented topics like knowledge distillation, and a few well-known hacks.

  • Each year, larger and larger models find new ways to extract signal from the noise in machine learning, and language models in particular keep growing. These models are computationally expensive (in both runtime and memory). That can make them costly to serve to customers, or too slow or too large to run in edge environments like a phone.

  • Researchers and practitioners have come up with many methods for optimizing neural networks to run faster or with less memory usage. This article covers some of the state-of-the-art methods.


Background: Precision


  • Quantization generally refers to taking a model whose parameters (weights in all cases, and activations in most cases) were trained at high precision (32 or 64 bits) and reducing the number of bits each value takes (for example, down to 16, 8, or even fewer). In practice, this usually leads to a speedup of 2-4x (in my experience, the speedup is highest for nets with convolutions).

  • Why does this work? It turns out that deep networks don’t need highly precise weight values to work well. With proper hardware support, processing deep learning kernels (a fancy term for mathematical operations) using fewer bits can be faster and more memory efficient simply because there are fewer bits to move and compute on (torch.qint8 is 8 bits and torch.float32 is 32 bits, so the former is 4x smaller).
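The following minimal pure-Python sketch of 8-bit affine quantization (q = round(x / scale) + zero_point) makes the rounding and the 4x storage saving concrete. It deliberately avoids PyTorch. The `quantize`/`dequantize` helpers and the hand-picked `scale` and `zero_point` are illustrative assumptions, not a library API. Dividing by `scale` is the same as the round(x * scalar) form used below, with scalar = 1/scale.

```python
from array import array

def quantize(values, scale, zero_point, qmin=-128, qmax=127):
    """Map floats to int8 codes: q = clamp(round(x / scale) + zero_point)."""
    return [max(qmin, min(qmax, round(x / scale) + zero_point)) for x in values]

def dequantize(codes, scale, zero_point):
    """Recover approximate floats: x_hat = (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in codes]

weights = [0.42, -1.37, 0.05, 0.91, -0.66]
scale, zero_point = 0.011, 0   # hand-picked so the codes cover roughly [-1.4, 1.4]

codes = quantize(weights, scale, zero_point)
approx = dequantize(codes, scale, zero_point)
max_error = max(abs(w - a) for w, a in zip(weights, approx))

# Storage: array('f') holds C floats (4 bytes each), array('b') signed bytes (1 byte each).
fp32_bytes = array('f', weights).itemsize * len(weights)
int8_bytes = array('b', codes).itemsize * len(codes)

print("int8 codes:", codes)
print(f"max abs round-trip error: {max_error:.4f}")
print(f"{fp32_bytes} bytes as float32 vs {int8_bytes} bytes as int8")
```

As long as no value is clamped, each round-trip error is at most scale/2, which is why a well-chosen scale matters so much.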

Downsides: Depending on the level of quantization attempted, you might find that an operation you need (for example, a particular convolutional op or even something as simple as transpose) is not implemented. Of course, as with all methods, you might also find that accuracy drops off too much to be useful.

  • From the TensorFlow docs:

We generally recommend 16-bit floats for GPU acceleration and 8-bit integer for CPU execution.

Quantization with PyTorch

  • PyTorch has support for special quantized tensors, which in its case corresponds to storing data in 8 or 16 bits. It’s important to understand one specific detail about how this works: the best quantization depends on the range of values a tensor actually takes. For example, if your network has a structure that guarantees that at some point all of the outputs are between 0 and 1 (e.g., from a sigmoid), then you might be able to choose a better, more specific quantization for that range. This means that quantization needs to collect some data about how your network runs on representative inputs. In particular, most quantization happens via a method like round(x * scalar), where scalar is a parameter fit to the data (akin to BatchNorm’s statistics).
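The sigmoid case can be sketched as a simple min-max calibration. The scale and zero point are chosen so that the range actually observed on representative inputs maps onto the 256 uint8 codes. This is a simplified, hypothetical version of what an observer does, and the helper names are mine, not PyTorch’s. The sketch compares a range chosen without looking at the data, [-10, 10], with one fit to sigmoid outputs in [0, 1].

```python
def affine_params(lo, hi, qmin=0, qmax=255):
    """Min-max calibration: map the observed range [lo, hi] onto [qmin, qmax]."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)   # keep 0.0 exactly representable
    scale = (hi - lo) / (qmax - qmin)
    zero_point = qmin - round(lo / scale)
    return scale, zero_point

def roundtrip_error(values, scale, zero_point, qmin=0, qmax=255):
    """Worst-case |x - dequantize(quantize(x))| over the given values."""
    err = 0.0
    for x in values:
        q = max(qmin, min(qmax, round(x / scale) + zero_point))
        err = max(err, abs(x - (q - zero_point) * scale))
    return err

# Representative sigmoid outputs: everything lies in [0, 1].
sigmoid_outputs = [i / 1000 for i in range(1001)]

generic = affine_params(-10.0, 10.0)                                 # range picked blindly
fitted = affine_params(min(sigmoid_outputs), max(sigmoid_outputs))  # range fit to data

print("generic range error:", roundtrip_error(sigmoid_outputs, *generic))
print("fitted range error: ", roundtrip_error(sigmoid_outputs, *fitted))
```

The fitted range spends all 256 codes on [0, 1], so its step size (and with it the worst-case rounding error bound of scale/2) is 20x smaller than the blind range’s.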

Support for some of these operations is provided by libraries that are “external” to PyTorch (but loaded as required). Think of them as BLAS or MKL for quantized operations. FBGEMM is an implementation for servers, and QNNPACK is an implementation for mobile devices (now inside PyTorch proper).

  • Quantization occasionally has gotchas. Accumulating in higher-precision data types is often more stable than using lower-precision values, especially if the input data spans a wide range of exponents (magnitudes). Picking the right precision for each operation can be nonobvious, so PyTorch has a torch.cuda.amp package to help you automatically cast different parts of your network to half precision (torch.float16) where possible. If you want to do this manually, there are some helpful tips on that page.
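The following illustration shows why the accumulator’s precision matters. The sketch adds 0.1 ten thousand times: once rounding the running sum to half precision after every addition, and once with ordinary 64-bit Python floats. It uses the standard library’s `struct` 'e' format to round values to IEEE 754 half precision. It is a toy model of fp16 accumulation, not how torch.cuda.amp is implemented.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack('e', struct.pack('e', x))[0]

step = 0.1
n_steps = 10_000

# Accumulate in half precision: the running sum is rounded to fp16 after every add.
fp16_total = 0.0
for _ in range(n_steps):
    fp16_total = to_fp16(fp16_total + to_fp16(step))

# Accumulate in higher precision (Python floats are 64-bit doubles).
fp64_total = 0.0
for _ in range(n_steps):
    fp64_total += step

print("fp16 accumulator:", fp16_total)
print("fp64 accumulator:", fp64_total)
```

The half-precision sum gets stuck far below 1000. Once the total reaches 256, the gap between neighboring fp16 values (0.25) is so large that adding 0.1 rounds back to the same number.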

One of the very first things you can try is to take your existing model that’s all torch.float32, run it using torch.cuda.amp, and check whether its accuracy holds up. Half precision support is still relatively sparse in consumer GPUs, but it works on the very common V100/P100/A100.

  • If you want more control or want to deploy to a non-CUDA environment, there are three levels of manual quantization (under the label “eager mode quantization”) that you can try, depending on why you’re trying to quantize and how much you’re willing to sweat:

    • Dynamic quantization: weights quantized with activations read/stored in floating point and quantized for compute
    • Static quantization: weights quantized, activations quantized, calibration required post training
    • Static quantization-aware training: weights quantized, activations quantized, quantization numerics modeled during training
  • Please see the PyTorch blog post “Introduction to Quantization on PyTorch” for a more comprehensive overview of the tradeoffs between these quantization types.

Layer/operator coverage (Linear/Conv/RNN/LSTM/GRU/Attention) varies between dynamic and static quantization and is captured in the table below. For FX quantization, the corresponding functionals are also supported.

Dynamic/Runtime Quantization

  • The easiest method of quantization PyTorch supports is called dynamic quantization. This involves not just converting the weights to int8 - as happens in all quantization variants - but also converting the activations to int8 on the fly, just before doing the computation (hence “dynamic”). The computations will thus be performed using efficient int8 matrix multiplication and convolution implementations, resulting in faster compute. However, the activations are read and written to memory in floating point format.
  • In other words, we store the network’s weights in the specified quantization. At runtime, activations are dynamically converted to the quantized format, combined with the (quantized) weights, and the result is written to memory at full precision. The next layer then quantizes those values, combines them with its own quantized weights, and so on. Why do it this way? My understanding is that the scalar can be determined on the fly from the data itself, so no calibration data is needed, which makes this a data-free method.
  • How do we do this in PyTorch? PyTorch offers a simple API for dynamic quantization: torch.quantization.quantize_dynamic takes in a model, as well as a couple of other arguments, and produces a quantized model! This end-to-end tutorial illustrates this for a BERT model. As an example:
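import torch
import torch.nn as nn
# `model` is your trained float model (e.g., one containing nn.LSTM and nn.Linear layers)
# quantize the LSTM and Linear parts of our network
# and use the torch.qint8 type to quantize
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)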
  • There are many more knobs you can turn to make this better for your model. See more details in this blog post.
  • See the documentation for the function here, and end-to-end examples in the tutorials here and here.
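Here is a pure-Python sketch of a dynamically quantized linear layer, showing what “dynamic” means mechanically. The weights are quantized once, ahead of time. The activation’s scale is computed from each incoming input, the products are accumulated as integers, and the result is rescaled back to float. The symmetric per-tensor scheme and the helper names are assumptions for illustration. PyTorch’s actual kernels (FBGEMM/QNNPACK) differ in detail.

```python
def symmetric_scale(values, qmax=127):
    """Per-tensor symmetric scale: the largest magnitude maps to qmax."""
    peak = max(abs(v) for v in values)
    return peak / qmax if peak > 0 else 1.0

def quantize(values, scale, qmax=127):
    return [max(-qmax, min(qmax, round(v / scale))) for v in values]

def dynamic_quantized_linear(w_q, w_scale, x):
    """y = W x with int8 weights and an activation scale chosen at runtime."""
    x_scale = symmetric_scale(x)          # computed from this input ("dynamic")
    x_q = quantize(x, x_scale)
    out = []
    for row in w_q:
        acc = sum(wq * xq for wq, xq in zip(row, x_q))  # integer accumulation
        out.append(acc * w_scale * x_scale)             # result goes back out as float
    return out

W = [[0.5, -0.25, 0.1], [-0.8, 0.3, 0.6]]
w_scale = symmetric_scale([v for row in W for v in row])
W_q = [quantize(row, w_scale) for row in W]   # done once, stored as int8 codes

x = [1.2, -0.4, 2.0]
y_float = [sum(w * xi for w, xi in zip(row, x)) for row in W]
y_quant = dynamic_quantized_linear(W_q, w_scale, x)
print("float:    ", y_float)
print("quantized:", y_quant)
```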

Post-Training Static Quantization

  • Runtime conversion to a full-precision type and back is expensive. We can avoid it if we know in advance what the distribution of activations will be, for example by recording real data flowing through the network.
  • One can further improve the performance (latency) by converting networks to use both integer arithmetic and int8 memory accesses. Static quantization performs the additional step of first feeding batches of data through the network and computing the resulting distributions of the different activations (specifically, this is done by inserting “observer” modules at different points that record these distributions). This information is used to determine how specifically the different activations should be quantized at inference time (a simple technique would be to divide the entire range of activations into 256 levels, but PyTorch supports more sophisticated methods as well).
  • Importantly, this additional step allows us to pass quantized values between operations instead of converting these values to floats - and then back to ints - between every operation, resulting in a significant speed-up.
  • The following features are supported that allow users to optimize their static quantization:

    • Observers: you can customize observer modules which specify how statistics are collected prior to quantization to try out more advanced methods to quantize your data.
      • Observers are inserted using torch.quantization.prepare.
    • Operator fusion: When you have access to data flowing through your network, PyTorch can also inspect your model and implement extra optimizations such as quantized operator fusion. You can fuse multiple operations into a single operation, saving on memory access while also improving the operation’s numerical accuracy.
      • To fuse modules, use torch.quantization.fuse_modules.
    • Per-channel quantization: we can independently quantize weights for each output channel in a convolution/linear layer, which can lead to higher accuracy with almost the same speed.
  • After calibration, quantization itself is done using torch.quantization.convert.
  • Here’s an example of setting up the observers, running it with some data, and then exporting to a new statically quantized model:
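import torch

# `model` is your float model with submodules named 'conv', 'bn', 'relu'
# (and QuantStub/DeQuantStub at its input/output); post-training fusion needs eval mode
model.eval()

# this is a default quantization config for mobile-based inference (ARM)
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
# or set quantization config for server (x86)
# model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# this chain (conv + batchnorm + relu) is one of a few sequences
# that are supported by the model fuser
model_fused = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])

# insert observers
model_with_observers = torch.quantization.prepare(model_fused)

# calibrate the model and collect statistics
# (`calibration_loader` is a placeholder for an iterable of representative input batches)
with torch.no_grad():
    for inputs in calibration_loader:
        model_with_observers(inputs)

# convert to quantized version
quantized_model = torch.quantization.convert(model_with_observers)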
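The observer step can be mimicked in a few lines of pure Python. The `MinMaxObserver` class here is a hypothetical stand-in; PyTorch has its own observer classes with a different API. It records the range of activations over calibration batches and freezes a scale and zero point. New activations are then quantized with those fixed parameters, so their integer codes can be passed directly to the next integer operator.

```python
class MinMaxObserver:
    """Hypothetical observer: records the running min/max of every batch it sees."""
    def __init__(self):
        self.lo, self.hi = float('inf'), float('-inf')

    def observe(self, batch):
        self.lo = min(self.lo, min(batch))
        self.hi = max(self.hi, max(batch))

    def qparams(self, qmin=0, qmax=255):
        lo, hi = min(self.lo, 0.0), max(self.hi, 0.0)
        scale = (hi - lo) / (qmax - qmin)
        zero_point = qmin - round(lo / scale)
        return scale, zero_point

def quantize(batch, scale, zero_point, qmin=0, qmax=255):
    return [max(qmin, min(qmax, round(x / scale) + zero_point)) for x in batch]

# 1) Calibration: run representative batches through and record statistics.
calibration_batches = [[0.1, 2.3, 0.7], [1.9, 0.0, 3.1], [0.4, 2.8, 1.5]]
observer = MinMaxObserver()
for batch in calibration_batches:
    observer.observe(batch)

# 2) Freeze the quantization parameters; they no longer depend on the input.
scale, zero_point = observer.qparams()

# 3) Inference: new activations are quantized with the frozen parameters.
new_batch = [0.5, 3.0, 4.0]   # 4.0 lies outside the calibrated range
print(scale, zero_point, quantize(new_batch, scale, zero_point))
```

The last value illustrates the cost of unrepresentative calibration data: 4.0 falls outside the calibrated range and is clamped to the top code.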

Static Quantization-aware Training (QAT)

  • Quantization-aware training (QAT) is the third method, and the one that typically gives the highest accuracy of the three. With QAT, all weights and activations are “fake quantized” during both the forward and backward passes of training: float values are rounded to mimic int8 values, but all computations are still done in floating point. As a result, every weight adjustment during training is made while “aware” that the model will ultimately be quantized, so after quantization this method usually yields higher accuracy than the other two.
  • Put simply, if you tell the training method some fact about how the network is used, the network will adapt to this information. How does this work? During the forward and backward passes, the model’s weights and activations are rounded to the chosen quantization. This means the model gets gradients based on rounded values, so it “adjusts” to its limited capacity.

Very importantly, however, the actual backprop (i.e., the gradient descent of the weights) happens in full precision.
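Here is a toy sketch of fake quantization on a single weight. The forward pass uses a rounded copy of the weight, while the update is applied to a full-precision copy. To pass a gradient through the rounding step, the sketch uses the straight-through estimator (treating the derivative of rounding as 1). That is a common choice but an assumption here. The `fake_quant` helper, the fixed step size, and the toy regression task are all made up for illustration.

```python
SCALE = 0.05   # hypothetical fixed quantization step

def fake_quant(w, scale=SCALE, qmin=-128, qmax=127):
    """Round to an int8 code, but return it as a float (q * scale)."""
    q = max(qmin, min(qmax, round(w / scale)))
    return q * scale

# Toy regression: learn y = 0.337 * x with a single weight.
data = [(x / 10, 0.337 * x / 10) for x in range(-10, 11)]
w = 1.0    # full-precision "shadow" weight, the one the optimizer updates
lr = 0.1

for epoch in range(200):
    grad = 0.0
    for x, y in data:
        w_q = fake_quant(w)              # forward pass sees the rounded weight
        pred = w_q * x
        grad += 2 * (pred - y) * x       # dL/dw_q, passed straight through to w
    w -= lr * grad / len(data)           # the update itself is full precision

print("full-precision weight:", w)
print("deployed (quantized) weight:", fake_quant(w))
```

The full-precision weight settles close to the target, but the deployed value is `fake_quant(w)`, a point on the 0.05 grid. Training was already evaluating the model with those rounded values.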

  • torch.quantization.prepare_qat inserts fake quantization modules to model quantization. Mimicking the static quantization API, torch.quantization.convert actually quantizes the model once training is complete.

  • For example, in the end-to-end example, we load in a pre-trained model as qat_model, then we simply perform quantization-aware training using:

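import torch

# specify quantization config for QAT
qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# prepare QAT (inserts fake-quantization modules; the model must be in training mode)
qat_model.train()
torch.quantization.prepare_qat(qat_model, inplace=True)

# ... train qat_model with your usual training loop ...

# convert to quantized version, removing dropout, to check for accuracy on each epoch
quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)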
  • Note: see the helpful tips under “Model Preparation for Quantization” here before using PyTorch quantization.

Device and Operator Support

  • Quantization support is restricted to a subset of available operators, depending on the method being used. For a list of supported operators, see the documentation here.

  • The set of available operators and the quantization numerics also depend on the backend used to run quantized models. Currently, quantized operators are supported only for CPU inference, in the following backends: x86 and ARM. Both the quantization configuration (how tensors should be quantized) and the quantized kernels (arithmetic with quantized tensors) are backend dependent. One can specify the backend by doing:

import torch

backend = 'fbgemm'  # 'fbgemm' for server (x86), 'qnnpack' for mobile (ARM)
my_model.qconfig = torch.quantization.get_default_qconfig(backend)

# prepare and convert model
# Set the backend on which the quantized kernels need to be run
torch.backends.quantized.engine = backend
  • However, quantization aware training occurs in full floating point and can run on either GPU or CPU. Quantization aware training is typically only used in CNN models when post training static or dynamic quantization doesn’t yield sufficient accuracy. This can occur with models that are highly optimized to achieve small size (such as Mobilenet).

Integration in torchvision

  • PyTorch has also enabled quantization for some of the most popular models in torchvision: Googlenet, Inception, Resnet, ResNeXt, Mobilenet and Shufflenet. We have upstreamed these changes to torchvision in the following forms:

    1. Pre-trained quantized weights so that you can use them right away.
    2. Quantization ready model definitions so that you can do post-training quantization or quantization aware training.
    3. A script for doing quantization-aware training, which is available for any of these models. As you will learn below, though, we only found it necessary for achieving accuracy with Mobilenet.
    4. We also have a tutorial showing how you can do transfer learning with quantization using one of the torchvision models.

Choosing an approach

  • The choice of which scheme to use depends on multiple factors:

    • Model/Target requirements: Some models might be sensitive to quantization, requiring quantization aware training.
    • Operator/Backend support: Some backends require fully quantized operators.
  • Currently, operator coverage in PyTorch is limited and may restrict your choices. The table below from PyTorch: Introduction to Quantization on PyTorch provides a guideline.

Performance Results

  • Quantization provides a 4x reduction in the model size and a speedup of 2x to 3x compared to floating point implementations depending on the hardware platform and the model being benchmarked. The table below from PyTorch: Introduction to Quantization on PyTorch offers some sample results:

Accuracy results

  • The tables below from PyTorch: Introduction to Quantization on PyTorch compare the accuracy of static quantized models with the floating point models on Imagenet. For dynamic quantization, we compared the F1 score of BERT on the GLUE benchmark for MRPC.

Computer Vision Model accuracy

Speech and NLP Model accuracy


Quantization in other frameworks: TensorFlow and CoreML

  • PyTorch-based quantization might not work in other production environments. In particular, when converting to Apple’s CoreML format, you need to use Apple’s own quantization (which might be limited to 16-bit quantization). On edge devices, check carefully that quantization is possible at all. In Apple’s case, the hardware already computes everything in fp16 on the GPU, so you may only save the memory taken by the network’s weights.

  • TensorFlow has a similar set of steps as above, though the examples are focused on TFLite. Essentially, static and dynamic quantization are explained in the Post-training quantization page, and there’s a QAT page. The tradeoffs appear to be very similar, though there’s always some feature mismatch between PyTorch and TF.

How far can we go?

  • Apparently down to 1 bit! There have been several attempts over the years to create binary neural networks if you want the most extreme version of the accuracy vs speed tradeoff. For the most part, these are still research projects rather than usable ideas, though XNOR-Net++ seems to have been implemented in PyTorch.
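The following sketch shows the two ingredients behind XNOR-Net-style binary networks. First, weights are approximated by α · sign(w), with α the mean absolute weight; this is the scaling used in the original XNOR-Net, and XNOR-Net++ changes how such scaling factors are obtained. Second, dot products between ±1 vectors reduce to an XNOR followed by a popcount. The bit-packing helpers are illustrative pure Python, not a real binary kernel.

```python
def binarize(weights):
    """Approximate w by alpha * sign(w), with alpha = mean(|w|) for the tensor."""
    alpha = sum(abs(w) for w in weights) / len(weights)
    signs = [1 if w >= 0 else -1 for w in weights]
    return alpha, signs

def pack_bits(signs):
    """Encode +1 as bit 1 and -1 as bit 0 in a single integer."""
    bits = 0
    for i, s in enumerate(signs):
        if s > 0:
            bits |= 1 << i
    return bits

def binary_dot(a_bits, b_bits, n):
    """Dot product of two {-1, +1} vectors of length n via XNOR and popcount."""
    mask = (1 << n) - 1
    matches = bin(~(a_bits ^ b_bits) & mask).count('1')  # XNOR = positions that agree
    return 2 * matches - n       # agreements contribute +1, disagreements -1

weights = [0.7, -0.2, 0.4, -0.9, 0.1, -0.5]
activations = [1, -1, -1, -1, 1, 1]   # assume inputs are already binarized

alpha, w_signs = binarize(weights)
n = len(weights)
fast = alpha * binary_dot(pack_bits(w_signs), pack_bits(activations), n)
slow = alpha * sum(s * a for s, a in zip(w_signs, activations))
print(alpha, fast, slow)
```

On real hardware, the XNOR and popcount operate on whole machine words at once, which is where the speedup of binary networks comes from.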

Pruning


  • Pruning is removing some weights (i.e., connections) or entire neurons from a neural network after or during training. In practice we can often remove 90% of the parameters in large deep neural networks without significantly affecting model performance.

  • Why does this work? Let’s imagine that your model is a fully connected neural network with just one hidden layer, such that the input is size 1024, the hidden size is 100, and the output is 20 dimensions. Then the number of parameters (without bias) is 104400. If there’s a neuron in the hidden layer that never fires (or is ignored downstream) then removing it from the network saves 1044 parameters. Why not just train the smaller network right away? The most compelling explanation is something called the lottery ticket hypothesis:

Any large network that trains successfully contains a subnetwork that is initialized such that - when trained in isolation - it can match the accuracy of the original network in at most the same number of training iterations.
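The parameter counts in the one-hidden-layer example (1024 inputs, 100 hidden units, 20 outputs, biases ignored) can be checked directly:

```python
def dense_params(sizes):
    """Number of weights (no biases) in a fully connected net with these layer sizes."""
    return sum(a * b for a, b in zip(sizes, sizes[1:]))

full = dense_params([1024, 100, 20])     # 1024*100 + 100*20
pruned = dense_params([1024, 99, 20])    # one hidden neuron removed
print(full, pruned, full - pruned)       # 104400 103356 1044
```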

Structured vs. Unstructured pruning

  • Removing neurons or choosing a subnetwork is what people consider structured pruning. However, a lot of methods (including TensorFlow’s tensorflow_model_optimization toolkit at this time and PyTorch’s torch.nn.utils.prune) are focused on sparsifying model weights so that they are more compressible (usually called unstructured pruning). This means the matrices are the same size, but some values are set to 0. This can save disk space using compression algorithms (such as run-length encoding or byte-pair encoding). When sparse model support fully lands in the various frameworks (i.e., you can multiply a sparse vector and a sparse matrix faster than the dense ones) you might be able to speed up inference as well.

  • For that reason, unstructured pruning doesn’t (currently) seem that useful. In essence, you prune during or after training and pick a target sparsity (e.g., 80% of the weights of your network will be zeroed out). However, there’s a lot of confusion in this area, which makes it hard to recommend anything. TensorFlow has a few guides on pruning both during and after training, and PyTorch has a tutorial on pruning using some set of heuristics after training.
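Here is a minimal sketch of unstructured magnitude pruning, one common heuristic for deciding which weights to zero out. The helper is hypothetical and not part of the TensorFlow or PyTorch API. The pruned matrix keeps its shape and only its values change. That is why the savings come from compression or sparse kernels rather than from smaller matrix multiplies.

```python
def magnitude_prune(matrix, sparsity):
    """Zero out (roughly) the `sparsity` fraction of weights with the smallest |w|.

    Returns the pruned matrix (same shape as the input) and a 0/1 mask of survivors.
    Ties at the threshold are all pruned, so the result can be slightly sparser.
    """
    flat = sorted(abs(w) for row in matrix for w in row)
    k = int(sparsity * len(flat))                 # number of weights to remove
    threshold = flat[k - 1] if k > 0 else -1.0    # largest magnitude being removed
    mask = [[1 if abs(w) > threshold else 0 for w in row] for row in matrix]
    pruned = [[w * m for w, m in zip(row, mrow)] for row, mrow in zip(matrix, mask)]
    return pruned, mask

W = [[0.9, -0.05, 0.3, -0.7, 0.02],
     [0.11, -0.6, 0.04, 0.25, -0.15]]

pruned, mask = magnitude_prune(W, sparsity=0.8)   # target: 80% of weights zeroed
zeros = sum(1 for row in pruned for w in row if w == 0)
print(pruned)
print(f"{zeros}/{len(W) * len(W[0])} weights are zero; the shape is unchanged")
```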

  • In the space of structured pruning, there’s still active research and no clear API. We can pick a metric to compute a relevance score for each neuron, and then remove the neurons with the least information content. Useful metrics here include the Shapley value, a Taylor approximation of the loss function’s sensitivity to a neuron’s activation, or even picking neurons at random. Before you begin, check out PyTorch: Pruning Tutorial. The TorchPruner library implements some of these automatically for nn.Linear and convolution modules (nn.Conv1d, nn.Conv2d, etc.). Another library, Torch-Pruning, supports a few more operations. One of the most well-known older works in this area prunes filters from a convnet using the L1 norm of each filter’s weights.
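Structured pruning, by contrast, actually shrinks the layers. The sketch below applies the L1-norm criterion to the hidden neurons of a small fully connected net; the filter-pruning work applies the same idea to convolutional filters. The function name and toy weights are illustrative. Removing a hidden neuron means deleting its row in the first weight matrix and its column in the second, the kind of bookkeeping that gets harder around skip connections.

```python
def l1_prune_hidden(w1, w2, n_remove):
    """Remove the n_remove hidden neurons whose incoming weights have the smallest L1 norm.

    w1: hidden x input weights (one row per hidden neuron)
    w2: output x hidden weights (one column per hidden neuron)
    """
    norms = [sum(abs(w) for w in row) for row in w1]
    ranked = sorted(range(len(w1)), key=lambda i: norms[i])
    drop = set(ranked[:n_remove])
    keep = [i for i in range(len(w1)) if i not in drop]
    new_w1 = [w1[i] for i in keep]                    # fewer rows in layer 1
    new_w2 = [[row[i] for i in keep] for row in w2]   # matching columns removed in layer 2
    return new_w1, new_w2, sorted(drop)

w1 = [[0.5, -0.4, 0.3], [0.01, 0.02, -0.01], [-0.9, 0.2, 0.6], [0.1, -0.05, 0.0]]
w2 = [[0.3, 1.0, -0.2, 0.4], [-0.7, 0.5, 0.8, 0.1]]

new_w1, new_w2, dropped = l1_prune_hidden(w1, w2, n_remove=2)
print(dropped)                      # [1, 3]
print(len(new_w1), len(new_w2[0]))  # 2 2
```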

Fine tuning

  • In both cases, it’s standard to retrain the network after pruning. Currently, the best method is basically to reset the learning rate schedule (learning rate rewinding) and start retraining the network. If you’d like, you can instead use weight rewinding, which resets the weights of the unpruned parts of the network to their values from earlier in training (e.g., the weights after 1/3 of training). My intuition is that this essentially trains the lottery ticket subnetwork now that we’ve identified it.
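Weight rewinding can be sketched as follows, assuming you saved an early checkpoint and kept the pruning mask. Surviving weights are reset to their early values, pruned weights stay at zero, and every retraining step re-applies the mask. The checkpoint, mask, and gradients below are made-up values. Learning-rate rewinding would instead start from the final trained weights and only restart the learning-rate schedule.

```python
def rewind(early, mask):
    """Reset surviving weights to their early-training values; pruned ones stay zero."""
    return [[w * m for w, m in zip(row, mrow)] for row, mrow in zip(early, mask)]

def masked_sgd_step(weights, grads, mask, lr):
    """One SGD update that keeps pruned weights at exactly zero."""
    return [[(w - lr * g) * m for w, g, m in zip(wr, gr, mr)]
            for wr, gr, mr in zip(weights, grads, mask)]

early = [[0.4, -0.1, 0.2], [0.05, -0.3, 0.6]]   # hypothetical early checkpoint
mask = [[1, 0, 1], [0, 1, 1]]                   # 1 = survived pruning
grads = [[0.1, 0.2, -0.1], [0.3, 0.0, 0.2]]     # hypothetical gradients

w = rewind(early, mask)                  # retraining starts from here
w = masked_sgd_step(w, grads, mask, lr=0.5)
print(w)
```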

Overall, a practitioner who is really interested in trying this should start with TorchPruner or Torch-Pruning and then try fine tuning the resulting network with learning rate rewinding. However, for most architectures (including ResNets because of skip connections) it’ll be pretty non-obvious how to trim the rest of the network around this.

DeepSpeed & ZeRO-Offload

  • Essentially, DeepSpeed is a library that helps train large to extremely large models (e.g., 1bn+ parameters) faster and using less GPU memory. This works by exploiting smart parallelism and better caching. It comes in the form of an extension to PyTorch.

Knowledge distillation

  • Knowledge distillation is a method for creating smaller and more efficient models from large models. In NLP this is also referred to as teacher-student training, because the large (teacher) model trains the smaller (student) model. The reference work in this area is Hinton et al., 2015.

  • In practice, suppose we have a classification task. Suppose our smaller student model is \(f_{\theta}\), where \(\theta\) is the set of parameters. We take either a large model or an ensemble of models (possibly even the same model trained with different initializations), and call it \(F\) (we won’t worry about its parameters). Then we train the student network with the following loss:

\(\mathcal{L}=\sum_{i=1}^{n} \mathrm{KL}\left(F\left(x_{i}\right), f_{\theta}\left(x_{i}\right)\right)\), where \(F\left(x_{i}\right)\) is the probability distribution over the labels produced by passing example \(x_{i}\) through the teacher \(F\), and \(f_{\theta}\left(x_{i}\right)\) is the corresponding distribution from the student.

  • If you want, you can add in the regular cross entropy loss using the proper labels (by passing in the one-hot ground truth distribution to the student as well):
\[\mathcal{L}=\sum_{i=1}^{n}\left(\mathrm{KL}\left(F\left(x_{i}\right), f_{\theta}\left(x_{i}\right)\right)-\beta \cdot \sum_{k=1}^{K} y_{i}[k] \log f_{\theta}\left(x_{i}\right)[k]\right)\]
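  • Here \(y_{i}\) is the one-hot ground-truth label of \(x_{i}\) over \(K\) classes, and \(\beta\) weights the supervised term. Note that \(-\sum_{k=1}^{K} y_{i}[k] \log f_{\theta}\left(x_{i}\right)[k]\) is just the KL divergence from the “true” distribution (i.e., the one-hot distribution from the labels) to the student model, since \(y_{i}\) is one-hot and therefore has zero entropy.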
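The loss above can be written out in a few lines of pure Python. The sketch reads \(\mathrm{KL}(F(x_i), f_{\theta}(x_i))\) as \(\mathrm{KL}(F(x_i)\,\|\,f_{\theta}(x_i))\), with the teacher’s distribution first. The teacher probabilities and student logits in the usage lines are invented numbers (think classes “mask”, “cat”, “dog”). A real implementation would use a framework’s batched softmax and KL operations.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p[k] = 0 contribute 0."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)

def distillation_loss(teacher_probs, student_logits, labels, beta):
    """Sum over examples of KL(F(x_i) || f(x_i)) - beta * sum_k y_i[k] log f(x_i)[k]."""
    total = 0.0
    for p_teacher, z_student, y in zip(teacher_probs, student_logits, labels):
        q = softmax(z_student)
        one_hot = [1.0 if k == y else 0.0 for k in range(len(q))]
        cross_entropy = -sum(yk * math.log(qk) for yk, qk in zip(one_hot, q))
        total += kl(p_teacher, q) + beta * cross_entropy
    return total

teacher = [[0.85, 0.10, 0.05]]    # teacher says: mostly "mask", a bit "cat"
student = [[2.0, 0.5, -1.0]]      # student logits for the same image
labels = [0]                      # ground truth: "mask"
print(distillation_loss(teacher, student, labels, beta=0.5))
```

Setting `beta=0` gives pure response-based distillation. The cross-entropy term equals `kl(one_hot, q)`, which is the identity noted above.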

  • Why does this work? There’s no consensus in the field. The most compelling explanation is that distillation is a form of rough data augmentation. To understand why, the recommended paper is Towards Understanding Ensemble, Knowledge Distillation and Self-Distillation in Deep Learning, which focuses on the idea of multiple views. Here are some thought experiments that might help explain what’s happening under the hood.

Distillation thought experiment

  • Let’s say that we have a large teacher model that is trained to classify images (e.g., CIFAR-100). This model implicitly has a bunch of “feature detectors” built in, e.g., a set of convolutional filters that fire when pointy ears are seen, which increase the probability of a label like “cat”. Let’s say that there’s a training image of a Batman mask, labeled “mask”. The teacher model’s pointy-ear filters might still fire, telling us that the model thinks this looks 10% like a cat.

  • When the student model is trained to match the teacher’s probability distribution, which assigns 0.1 to cat, it still gets a small signal that this image is catlike. That might help the student recognize cats better than it could otherwise. If the student model were trained on just the true labels, it would have no idea that this Batman mask looks a bit like a cat. This logic also explains why the student can outperform the teacher in some cases.

Ensembling thought experiment

  • A similar but slightly different idea explains why ensembles of models (even of the same architecture) might work well. Let’s say there are 3 pictures of a cat in a dataset we’re using for image classification. Image 1 has a cat with feature A (e.g., pointed ears), image 2 has feature B (e.g., whiskers), and image 3 has both A and B.

  • Then, let’s say the neural network learns feature A (e.g., by seeing image 1). When it sees image 3, that set of convolution filters will fire, and so the image will be correctly classified. So, there’ll be no gradient that tunes the net to recognize feature B, even though a good net would learn that.

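  • In other words, once a neural network has become good enough, the learning signal (gradient) it gets from some data points shrinks, so it may never pick up feature B. A differently initialized network might learn B first, which is why combining several of them can help.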
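The shrinking signal can be made concrete with softmax cross-entropy, whose gradient with respect to the logits is softmax(z) minus the one-hot label. The sketch uses a hypothetical two-class setup where class 0 is “cat” and `margin` is how far the cat logit leads. The gradient magnitude drops quickly as the network becomes more confident. An image like image 3, already classified correctly via feature A, therefore contributes little toward learning feature B.

```python
import math

def ce_grad_wrt_logits(logits, label):
    """Gradient of softmax cross-entropy w.r.t. the logits: softmax(z) - one_hot(label)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]

# Class 0 ("cat") is correct; a larger margin means a more confident network.
for margin in [0.0, 2.0, 5.0, 10.0]:
    grad = ce_grad_wrt_logits([margin, 0.0], label=0)
    print(f"margin {margin:>4}: largest gradient component {max(abs(g) for g in grad):.6f}")
```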

Distillation in practice

  • Knowledge distillation is a very deep and wide research area, touching adversarial attacks, knowledge transfer, and privacy.

  • In practice, the method I’ve described above is called response-based distillation. There are also other forms of distillation, including feature-based and relation-based knowledge distillation, which are entire subfields based on what parts (or computations from) the student and teacher model we should tie together.

  • Furthermore, there’s a division between offline distillation (train the student after the teacher), online distillation (train the student and teacher together), and self-distillation (where the teacher model has the same architecture as the student). Together, these make it difficult to keep track of distillation methods in practice; a set of ad hoc, model-specific techniques might be the best general recommendation.

  • In fact, Cho & Hariharan, 2019 found that when the student model’s capacity is too low, knowledge distillation actually hurts training. They also noted that knowledge distillation papers rarely evaluate on ImageNet, and that the methods often don’t work well on difficult problems. Perplexingly, that paper and Mirzadeh et al., 2019 found that better teacher models don’t always mean better distillation, and that the larger the gap between the student’s and the teacher’s capacities, the less effective distillation was. You can find a recent investigation in Tang et al., 2021.

  • All in all, distillation is relatively difficult compared to quantization and pruning. You might be able to get some free performance points by training a student with a slightly smaller capacity and then using vanilla response-based offline distillation.

Distillation as semi-supervised learning

  • You can train a teacher model, which is a much more powerful model than the student, with a small set of labeled data. Next, use the teacher to automatically label unannotated data, which can be used to train a leaner, more efficient “student” network.
  • For example, Parthasarathi and Strom (2019), in Lessons from building acoustic models with a million hours of speech, used a small set of annotated data (green) to train a powerful but impractically slow “teacher” network to convert frequency-level descriptions of audio data into sequences of phones. The teacher, in turn, labeled a much larger set of unannotated data (red). They then used both datasets to train a leaner, more efficient “student” model.


  • Deep learning researchers have spent a lot of time distilling large models using model-specific methods, and if you need to gain some performance, you might be able to find a pre-trained distilled version of the large model you’re currently using. For example, in NLP, HuggingFace makes it easy to access both DistilBert and TinyBert. In computer vision, Facebook Research’s d2go has a bunch of pretrained mobile-ready models, and they’ve specialized some distillation methods in DeiT.

  • Well-Read Students Learn Better: On the Importance of Pre-training Compact Models makes a recommendation (with high quality ablation experiments) that for training BERT architectures, the best approach is:

    1. Pre-train a compact model architecture on the masked language model (MLM) objective developed by the original BERT papers (Devlin et al., 2018).
    2. Take a large task-specific teacher model (e.g., if the task is NLI, the output is a distribution over the 3 classes (entailment, contradiction, neutral)), and perform basic response-based offline distillation on the pre-trained compact model from step 1.
    3. Finally, if required, fine-tune the compact model from step 2 on the task-specific data (e.g., if the task is NER, train over the CoNLL 2003 dataset).
  • One of the biggest advantages of this method (which they call Pre-trained Distillation (PD)) is that it’s architecture-agnostic. If you are going to use a compact NLP model in practice, it’s worth skimming the paper, especially section 6.



If you found our work useful, please cite it as:

  title   = {Model Compression},
  author  = {Chadha, Aman},
  journal = {Distilled AI},
  year    = {2020},
  note    = {\url{}}