Primers • Model Compression
- Background
- Quantization
- Pruning
- DeepSpeed & ZeRO-Offload
- Knowledge distillation
- Conclusion
- Further Reading
- Citation
Background

This post covers model inference optimization or compression concepts using topics such as model quantization and binarization, pruning, and more research-oriented topics like knowledge distillation, as well as well-known hacks.

Each year, ever larger models find ways to extract more signal from the noise in machine learning; language models in particular keep growing. These models are computationally expensive (in both runtime and memory), which makes them costly to serve to customers and often too slow or too large to run in edge environments like a phone.

Researchers and practitioners have come up with many methods for optimizing neural networks to run faster or with less memory usage. This article covers some of the state-of-the-art methods.
Quantization
Background: Precision
Before we talk about the quantization process, let’s learn about precision. As the NVIDIA blog post “What’s the Difference Between Single-, Double-, Multi- and Mixed-Precision Computing?” explains, per the IEEE 754 floating-point specification, double-precision format uses 64 bits, single-precision format uses 32 bits, and half-precision uses 16 bits.
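To see what these bit widths mean in practice, here is a minimal sketch using Python’s standard struct module, whose 'e', 'f' and 'd' format codes pack IEEE 754 half, single and double precision values. It round-trips 0.1 through each format and reports the storage size and the rounding error; the helper name round_trip is just for illustration.

```python
import struct

def round_trip(value, fmt):
    """Pack value into an IEEE 754 format and unpack it again."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]

# 'e' = half (16-bit), 'f' = single (32-bit), 'd' = double (64-bit)
formats = {"half": "e", "single": "f", "double": "d"}

for name, fmt in formats.items():
    stored = round_trip(0.1, fmt)
    bits = struct.calcsize(fmt) * 8
    print(f"{name:>6}: {bits} bits, stored 0.1 as {stored!r}, error {abs(stored - 0.1):.2e}")
```

Half precision keeps only about three significant decimal digits, which is often still enough for network weights.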
Definition

Quantization generally refers to taking a model with parameters (weights in all cases, and activations in most cases) trained at high precision (32 or 64 bits) and reducing the number of bits that each weight takes (for example, down to 16, 8, or even fewer). In practice, this usually leads to a speedup of 2-4x (highest for nets with convolutions, based on experience).

Why does this work? It turns out that for deep networks to work, we don’t need highly precise values for the network’s weights. With proper hardware support, processing deep learning kernels (a fancy term for mathematical operations) using fewer bits can be faster and more memory-efficient simply because there are fewer bits to compute (torch.qint8 is 8 bits and torch.float32 is 32 bits, so 4x smaller).
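The 4x figure is plain arithmetic on the storage per parameter. A small sketch, assuming a hypothetical model with 110 million parameters and counting only the weights (no activations, optimizer state, or file-format overhead):

```python
def model_size_mb(num_params, bits_per_param):
    """Approximate storage for the weights alone, in megabytes (10**6 bytes)."""
    return num_params * bits_per_param / 8 / 1e6

# Hypothetical model with 110 million parameters
num_params = 110_000_000
for name, bits in [("float32", 32), ("float16", 16), ("int8", 8)]:
    print(f"{name:>7}: {model_size_mb(num_params, bits):.0f} MB")
```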
Downsides: Depending on the level of quantization attempted, you might find that an operation you want (for example, a particular convolutional op or even something as simple as transpose) might not be implemented. Of course, as with all methods, you might also find that accuracy drops off too much to be useful.
From the TensorFlow docs:
We generally recommend 16-bit floats for GPU acceleration and 8-bit integer for CPU execution.
Quantization with PyTorch
PyTorch has support for special quantized tensors, which, in its case, correspond to storing data in 8 or 16 bits. It’s important to understand one specific detail about how this works. If your network has a special structure such that at some point all of the outputs are between 0 and 1 (e.g., from a sigmoid), then you might be able to choose a better, more specific quantization. This means that quantization needs to collect some data about how your network runs on representative inputs. In particular, most quantization happens via a method like round(x * scalar), where scalar is a parameter estimated from data, or in some schemes learned (akin to the running statistics in BatchNorm).
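A minimal sketch of this idea in plain Python, assuming an affine uint8 scheme (a scale plus an integer zero point, so round(x * scalar) corresponds to scalar = 1 / scale). It compares a generic, hypothetical range of [-8, 8] with a range calibrated on sigmoid outputs; the calibrated range spends all 256 levels inside (0, 1) and therefore rounds far more finely.

```python
import math
import random

def make_qparams(x_min, x_max, num_levels=256):
    """Affine (asymmetric) uint8 parameters for the range [x_min, x_max]."""
    scale = (x_max - x_min) / (num_levels - 1)
    zero_point = round(-x_min / scale)
    return scale, zero_point

def quantize(x, scale, zero_point, num_levels=256):
    # round(x * (1 / scale)) plus an offset, clamped to the integer range
    q = round(x / scale) + zero_point
    return max(0, min(num_levels - 1, q))

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

def max_error(values, scale, zero_point):
    """Largest absolute round-trip error over the given values."""
    return max(abs(v - dequantize(quantize(v, scale, zero_point), scale, zero_point))
               for v in values)

random.seed(0)
# Representative data: outputs of a sigmoid always lie in (0, 1)
sigmoid_outputs = [1 / (1 + math.exp(-random.gauss(0, 2))) for _ in range(1000)]

generic = make_qparams(-8.0, 8.0)  # a range chosen without looking at the data
calibrated = make_qparams(min(sigmoid_outputs), max(sigmoid_outputs))

print("generic range max error:   ", max_error(sigmoid_outputs, *generic))
print("calibrated range max error:", max_error(sigmoid_outputs, *calibrated))
```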
Support for some of these operations is provided by libraries that are “external” to PyTorch (but loaded as required). Think of this like BLAS or MKL for quantized operations. FBGEMM is an implementation for servers, and QNNPACK is an implementation for mobile devices (now inside PyTorch proper).
Quantization occasionally has gotchas: accumulating in higher-precision data types is often more stable than accumulating in lower precision, especially if the input data spans a wide range of exponents. Picking the right precision for each operation can be non-obvious, so PyTorch has a torch.cuda.amp package that automatically casts different parts of your network to half precision (torch.float16) where it is safe to do so. If you want to do this manually, there are some helpful tips on that page.
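The accumulation gotcha is easy to reproduce with the standard library alone. This sketch assumes a float16 accumulator behaves like rounding the running sum to the nearest half-precision value after every addition (emulated with struct’s 'e' format; the helper to_half is hypothetical). Once the sum reaches 256, neighboring float16 values are 0.25 apart, so adding roughly 0.1 rounds straight back to the same number, while a float64 accumulator over the very same float16 inputs stays accurate.

```python
import struct

def to_half(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]

values = [to_half(0.1)] * 10_000  # 10,000 inputs already stored in float16

# Accumulate in half precision: round the running sum after every addition
half_sum = 0.0
for v in values:
    half_sum = to_half(half_sum + v)

# Accumulate the same float16 inputs in double precision
double_sum = sum(values)

print("float16 accumulator:", half_sum)
print("float64 accumulator:", double_sum)
```

This is why mixed-precision schemes typically keep inputs in half precision but accumulate in a wider type.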
One of the very first things you can try is to take your existing model that’s all torch.float32, run it using torch.cuda.amp, and see whether it keeps its accuracy. Half precision support is still relatively sparse in consumer GPUs, but it works on the very common V100/P100/A100.

If you want more control or want to deploy to a non-CUDA environment, there are three levels of manual quantization (under the label “eager mode quantization”) that you can try, depending on why you’re trying to quantize and how much you’re willing to sweat:
- Dynamic quantization: weights quantized, with activations read/stored in floating point and quantized for compute.
- Static quantization: weights quantized, activations quantized, calibration required post-training.
- Static quantization-aware training: weights quantized, activations quantized, quantization numerics modeled during training.

See the PyTorch blog post “Introduction to Quantization on PyTorch” for a more comprehensive overview of the trade-offs between these quantization types.
Layer/operator coverage (Linear/Conv/RNN/LSTM/GRU/Attention) varies between dynamic and static quantization and is captured in the table below. For FX quantization, the corresponding functionals are also supported.
Dynamic/Runtime Quantization
The easiest method of quantization PyTorch supports is called dynamic quantization. This involves not just converting the weights to int8 (as happens in all quantization variants) but also converting the activations to int8 on the fly, just before doing the computation (hence “dynamic”). The computations will thus be performed using efficient int8 matrix multiplication and convolution implementations, resulting in faster compute. However, the activations are read and written to memory in floating point format. In other words, we store the weights of the network in the specified quantization, and then at runtime, activations are dynamically converted to the quantized format, combined with the (quantized) weights, then written to memory at full precision. Then the next layer quantizes those, combines them with the next quantized weights, and so on. Why does this happen? My understanding is that scalar can be determined dynamically from the data, which means this is a data-free method (no calibration data is needed).
How do we do this in PyTorch? PyTorch offers a simple API for dynamic quantization: torch.quantization.quantize_dynamic takes in a model, as well as a couple of other arguments, and produces a quantized model! Check out this end-to-end tutorial, which illustrates this for a BERT model. As an example:
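import torch
import torch.nn as nn

# model: an existing float nn.Module (e.g., one with LSTM and Linear layers)
# quantize the LSTM and Linear parts of our network
# and use the torch.qint8 type to quantize
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)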
There are many more knobs you can turn to make this better for your model. See more details in this blog post.
See the documentation for the function here, and end-to-end examples in our tutorials here and here.
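The following toy sketch, in plain Python rather than the PyTorch API, mimics what dynamic quantization does for a single linear layer: weights are quantized once ahead of time, each input’s activation scale is computed on the fly, the dot products run on integers, and the result is rescaled and handed on in floating point. The symmetric per-tensor int8 scales and the class name DynamicQuantLinear are illustrative assumptions; PyTorch’s real kernels differ in details.

```python
def symmetric_qparams(values, num_bits=8):
    """Symmetric scale so that the largest |value| maps to 127."""
    max_abs = max(abs(v) for v in values) or 1.0
    return max_abs / (2 ** (num_bits - 1) - 1)

def quantize_sym(values, scale):
    return [max(-127, min(127, round(v / scale))) for v in values]

class DynamicQuantLinear:
    """Toy y = W x layer: weights quantized once, activations quantized per call."""

    def __init__(self, weight_rows):
        # Weights: quantized ahead of time (one scale for the whole matrix)
        flat = [w for row in weight_rows for w in row]
        self.w_scale = symmetric_qparams(flat)
        self.q_weight = [quantize_sym(row, self.w_scale) for row in weight_rows]

    def __call__(self, x):
        # Activations: scale computed from this particular input ("dynamic")
        x_scale = symmetric_qparams(x)
        q_x = quantize_sym(x, x_scale)
        # Integer dot products (real kernels accumulate in int32)
        acc = [sum(qw * qx for qw, qx in zip(row, q_x)) for row in self.q_weight]
        # Rescale; the output goes to the next layer in floating point
        return [a * self.w_scale * x_scale for a in acc]

weights = [[0.5, -1.2, 0.3], [2.0, 0.1, -0.7]]
layer = DynamicQuantLinear(weights)
x = [0.9, -0.4, 1.5]
exact = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
print("float:    ", exact)
print("quantized:", layer(x))
```

No calibration data appears anywhere: the only activation statistic used is computed from the input currently being processed.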
Post-Training Static Quantization
Runtime conversion to a full precision type and back is expensive. We can avoid that if we know what the distribution of activations will be, say, by recording real data flowing through the network.
One can further improve the performance (latency) by converting networks to use both integer arithmetic and int8 memory accesses. Static quantization performs the additional step of first feeding batches of data through the network and computing the resulting distributions of the different activations (specifically, this is done by inserting “observer” modules at different points that record these distributions). This information is used to determine how specifically the different activations should be quantized at inference time (a simple technique would be to simply divide the entire range of activations into 256 levels, but we support more sophisticated methods as well).
Importantly, this additional step allows us to pass quantized values between operations instead of converting these values to floats (and then back to ints) between every operation, resulting in a significant speedup.
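A minimal sketch of the calibration step in plain Python: an observer records the running min/max of an activation while representative batches pass through, and the resulting scale and zero point are then frozen for inference. The class RangeObserver is a hypothetical toy loosely modeled on the min/max idea, not PyTorch’s observer API, and the data are made up.

```python
class RangeObserver:
    """Records the running min/max of every activation batch it sees."""

    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch):
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))
        return batch  # observers pass data through unchanged

    def qparams(self, num_levels=256):
        # Keep 0.0 exactly representable, as common affine schemes do
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = (hi - lo) / (num_levels - 1) or 1.0
        zero_point = round(-lo / scale)
        return scale, zero_point

def relu(xs):
    return [max(0.0, x) for x in xs]

def quantize(x, scale, zero_point):
    return max(0, min(255, round(x / scale) + zero_point))

# Calibration: run representative batches through the float network once
observer = RangeObserver()
calibration_batches = [[-1.0, 0.5, 2.0], [0.3, 3.1, -0.2], [1.7, 0.0, 2.4]]
for batch in calibration_batches:
    observer.observe(relu(batch))

# After calibration the activation parameters are frozen and reused,
# so no range has to be computed at inference time
scale, zero_point = observer.qparams()
print(f"frozen activation qparams: scale={scale:.5f}, zero_point={zero_point}")
print([quantize(v, scale, zero_point) for v in relu([0.8, -0.5, 3.0])])
```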

The following features are supported that allow users to optimize their static quantization:
- Observers: you can customize observer modules, which specify how statistics are collected prior to quantization, to try out more advanced methods to quantize your data. Observers are inserted using torch.quantization.prepare.
- Operator fusion: When you have access to data flowing through your network, PyTorch can also inspect your model and implement extra optimizations such as quantized operator fusion. You can fuse multiple operations into a single operation, saving on memory access while also improving the operation’s numerical accuracy. To fuse modules, use torch.quantization.fuse_modules.
- Per-channel quantization: we can independently quantize weights for each output channel in a convolution/linear layer, which can lead to higher accuracy with almost the same speed.
Quantization itself is done using torch.quantization.convert.
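A small sketch of why per-channel quantization helps, in plain Python with hypothetical weights: when one output channel has much smaller weights than another, a single per-tensor scale (set by the largest weight) rounds the small channel very coarsely, while a per-channel scale adapts to each row. Symmetric int8 scales are assumed here.

```python
def symmetric_scale(values):
    """Symmetric int8 scale: the largest |value| maps to 127."""
    return max(abs(v) for v in values) / 127

def quant_error(row, scale):
    """Largest absolute round-trip error for one row under a given scale."""
    errors = []
    for w in row:
        q = max(-127, min(127, round(w / scale)))
        errors.append(abs(w - q * scale))
    return max(errors)

# Two output channels with very different magnitudes (hypothetical weights)
weight = [
    [0.01, -0.02, 0.015, 0.005],  # small-magnitude channel
    [1.5, -2.0, 0.7, 1.1],        # large-magnitude channel
]

per_tensor_scale = symmetric_scale([w for row in weight for w in row])
per_channel_scales = [symmetric_scale(row) for row in weight]

for i, row in enumerate(weight):
    print(f"channel {i}: per-tensor error {quant_error(row, per_tensor_scale):.5f}, "
          f"per-channel error {quant_error(row, per_channel_scales[i]):.5f}")
```

Speed is barely affected because only one extra scale per output channel has to be stored and applied.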
 Here’s an example of setting up the observers, running it with some data, and then exporting to a new statically quantized model:
import torch
# model: your float nn.Module, with submodules named 'conv', 'bn' and 'relu'
# and QuantStub/DeQuantStub around the part to be quantized (eager mode);
# example_batch: a representative input batch
model.eval()  # conv + bn fusion for inference requires eval mode
# this is a default quantization config for mobile-based inference (ARM)
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
# or set quantization config for server (x86)
# model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# this chain (conv + batchnorm + relu) is one of a few sequences
# that are supported by the model fuser
model_fused = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])
# insert observers
model_with_observers = torch.quantization.prepare(model_fused)
# calibrate the model and collect statistics
with torch.no_grad():
    model_with_observers(example_batch)
# convert to quantized version
quantized_model = torch.quantization.convert(model_with_observers)
Static Quantization-Aware Training (QAT)
Quantization-aware training (QAT) is the third method, and the one that typically results in the highest accuracy of the three. With QAT, all weights and activations are “fake quantized” during both the forward and backward passes of training: that is, float values are rounded to mimic int8 values, but all computations are still done with floating point numbers. Thus, all the weight adjustments during training are made while “aware” of the fact that the model will ultimately be quantized; after quantizing, therefore, this method usually yields higher accuracy than the other two methods.
Put simply, if you tell the training method some fact about how the network is used, the network will adapt to this information. How does this work? During the forward and backward passes, the model’s weights and activations are rounded to the chosen quantization grid. This means the model gets gradients based on rounded values, which means it “adjusts” to its limited capacity.
Very importantly, however, the actual backprop (i.e., the gradient descent of the weights) happens in full precision.
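A toy sketch of fake quantization in plain Python, assuming a fixed, hypothetical quantization step of 0.1 and the straight-through estimator (STE), a common way to pass gradients through the rounding step by treating its derivative as 1. The forward pass only ever sees the rounded weight, while the optimizer updates a full-precision “master” weight, matching the point above that backprop itself happens in full precision.

```python
def fake_quantize(w, scale, q_min=-128, q_max=127):
    """Round to the int8 grid, then map back to float (the value stays a float)."""
    q = max(q_min, min(q_max, round(w / scale)))
    return q * scale

# Toy regression: learn w so that w * x ~= y, with the weight fake-quantized
scale = 0.1       # hypothetical fixed quantization step
w_float = 0.0     # full-precision master weight that the optimizer updates
lr = 0.05
data = [(1.0, 0.73), (2.0, 1.46), (-1.0, -0.73)]  # generated from y = 0.73 * x

for step in range(200):
    grad = 0.0
    for x, y in data:
        w_q = fake_quantize(w_float, scale)  # forward pass sees the rounded weight
        pred = w_q * x
        # d(loss)/d(w_q) for loss = (pred - y)^2; the STE treats
        # d(w_q)/d(w_float) as 1, so this is also the gradient for w_float
        grad += 2 * (pred - y) * x
    w_float -= lr * grad / len(data)

print("master weight:", w_float)
print("deployed weight:", fake_quantize(w_float, scale))
```

Because the target 0.73 is not on the grid, the master weight ends up hovering near the rounding boundary between 0.7 and 0.8, and the deployed weight is one of those two grid points; the loss was always computed on values the quantized model can actually represent.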

torch.quantization.prepare_qat inserts fake-quantization modules that model the effect of quantization. Mimicking the static quantization API, torch.quantization.convert actually quantizes the model once training is complete.
For example, in the end-to-end example, we load in a pretrained model as qat_model, then we simply perform quantization-aware training using:
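import torch
# qat_model: a pretrained float model prepared for eager-mode quantization
# specify quantization config for QAT
qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
# prepare QAT (inserts fake-quantization modules; the model must be in training mode)
qat_model.train()
torch.quantization.prepare_qat(qat_model, inplace=True)
# ... run the usual training loop on qat_model here ...
# convert to quantized version, removing dropout, to check for accuracy on each epoch
quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)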
 Note: see the helpful tips under “Model Preparation for Quantization” here before using PyTorch quantization.
Device and Operator Support

Quantization support is restricted to a subset of available operators, depending on the method being used. For a list of supported operators, please see the documentation here.

The set of available operators and the quantization numerics also depend on the backend being used to run quantized models. Currently, quantized operators are supported only for CPU inference, in the following backends: x86 and ARM. Both the quantization configuration (how tensors should be quantized) and the quantized kernels (arithmetic with quantized tensors) are backend-dependent. One can specify the backend by doing:
import torch
backend = 'fbgemm'  # 'fbgemm' for server (x86), 'qnnpack' for mobile (ARM)
# my_model: your float model
my_model.qconfig = torch.quantization.get_default_qconfig(backend)
# prepare and convert model
# Set the backend on which the quantized kernels need to be run
torch.backends.quantized.engine = backend
However, quantization-aware training occurs in full floating point and can run on either GPU or CPU. Quantization-aware training is typically only used in CNN models when post-training static or dynamic quantization doesn’t yield sufficient accuracy. This can occur with models that are highly optimized to achieve small size (such as Mobilenet).
Integration in torchvision

PyTorch has also enabled quantization for some of the most popular models in torchvision: Googlenet, Inception, Resnet, ResNeXt, Mobilenet and Shufflenet. We have upstreamed these changes to torchvision in three forms:
- Pretrained quantized weights so that you can use them right away.
- Quantization-ready model definitions so that you can do post-training quantization or quantization-aware training.
- A script for doing quantization-aware training, which is available for any of these models, though, as you will learn below, we only found it necessary for achieving accuracy with Mobilenet.
- We also have a tutorial showing how you can do transfer learning with quantization using one of the torchvision models.
Choosing an approach

The choice of which scheme to use depends on multiple factors:
- Model/Target requirements: Some models might be sensitive to quantization, requiring quantization-aware training.
- Operator/Backend support: Some backends require fully quantized operators.

Currently, operator coverage in PyTorch is limited and may restrict the choices listed in the table below. The table below from PyTorch: Introduction to Quantization on PyTorch provides a guideline.
Performance Results
 Quantization provides a 4x reduction in the model size and a speedup of 2x to 3x compared to floating point implementations depending on the hardware platform and the model being benchmarked. The table below from PyTorch: Introduction to Quantization on PyTorch offers some sample results:
Accuracy results
The tables below from PyTorch: Introduction to Quantization on PyTorch compare the accuracy of static quantized models with the floating point models on ImageNet. For dynamic quantization, we compared the F1 score of BERT on the GLUE benchmark for MRPC.
Computer Vision Model accuracy
Speech and NLP Model accuracy
Conclusion
- To get started on quantizing your models in PyTorch, start with the tutorials on the PyTorch website.
- If you are working with sequence data, start with…
- If you are working with image data, then we recommend starting with the transfer learning with quantization tutorial. Then you can explore post-training static quantization.
- If you find that the accuracy drop with post-training quantization is too high, then try quantization-aware training.
Quantization in other frameworks: TensorFlow and CoreML

PyTorch-based quantization might not necessarily work in other production environments. In particular, when converting to Apple’s CoreML format, you need to use its own quantization (which might be limited to 16-bit quantization). When using edge devices, be careful to check that quantization is possible (in Apple’s case, the hardware already computes everything in fp16 on the GPU, so you may only save the memory taken by the network’s weights).
TensorFlow has a similar set of steps as above, though the examples are focused on TFLite. Essentially, static and dynamic quantization are explained in the Post-training quantization page, and there’s a QAT page. The trade-offs appear to be very similar, though there’s always some feature mismatch between PyTorch and TF.
How far can we go?
Apparently down to 1 bit! There have been several attempts over the years to create binary neural networks, if you want the most extreme version of the accuracy vs. speed trade-off. For the most part, these are still research projects rather than usable ideas, though XNOR-Net++ seems to have been implemented in PyTorch.
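To make the 1-bit idea concrete, here is a sketch of the binary-weight approximation used in the original XNOR-Net (XNOR-Net++ refines how the scaling factors are obtained): each weight tensor is approximated as alpha * sign(W), where alpha is the mean absolute weight, which minimizes the squared approximation error for a fixed sign pattern. The weights are hypothetical, and only the weight side is shown (XNOR-Net also binarizes activations).

```python
def binarize(weights):
    """Binary weight approximation: W ~= alpha * sign(W)."""
    alpha = sum(abs(w) for w in weights) / len(weights)  # mean absolute value
    signs = [1 if w >= 0 else -1 for w in weights]
    return alpha, signs

weights = [0.8, -0.3, 0.5, -1.0, 0.2, -0.6]  # hypothetical filter weights
alpha, signs = binarize(weights)
approx = [alpha * s for s in signs]

print("alpha:", alpha)
print("signs:", signs)  # each sign needs 1 bit instead of 32
print("approximation:", approx)
```

Storage drops to one bit per weight plus a single float per tensor, and multiplications by plus or minus one can be implemented with XNOR and popcount operations on suitable hardware.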
Further Reading
- PyTorch official documentation: Introduction to Quantization on PyTorch
- PyTorch official documentation: Advanced Quantization in PyTorch
- PyTorch official documentation: Quantization
- CoreML Tools documentation: Quantization
Pruning

Pruning is removing some weights (i.e., connections) or entire neurons from a neural network after or during training. In practice we can often remove 90% of the parameters in large deep neural networks without significantly affecting model performance.

Why does this work? Let’s imagine that your model is a fully connected neural network with just one hidden layer, such that the input is size 1024, the hidden size is 100, and the output is 20 dimensions. Then the number of parameters (without bias) is 104400. If there’s a neuron in the hidden layer that never fires (or is ignored downstream) then removing it from the network saves 1044 parameters. Why not just train the smaller network right away? The most compelling explanation is something called the lottery ticket hypothesis:
Any large network that trains successfully contains a subnetwork that is initialized such that, when trained in isolation, it can match the accuracy of the original network in at most the same number of training iterations.
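The parameter arithmetic from the one-hidden-layer example earlier in this section can be checked with a few lines of Python (a sketch; dense_params is a hypothetical helper that ignores biases, as the text does):

```python
def dense_params(sizes):
    """Weight count (no biases) of a fully connected net with the given layer sizes."""
    return sum(a * b for a, b in zip(sizes, sizes[1:]))

full = dense_params([1024, 100, 20])
pruned = dense_params([1024, 99, 20])  # one hidden neuron removed
print(full, full - pruned)  # prints: 104400 1044
```

Removing one hidden neuron deletes its 1024 incoming and 20 outgoing weights, which is where 1044 comes from.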
Structured vs. Unstructured pruning

Removing neurons or choosing a subnetwork is what people consider structured pruning. However, a lot of methods (including TensorFlow’s tensorflow_model_optimization toolkit at this time and PyTorch’s torch.nn.utils.prune) are focused on sparsifying model weights so that they are more compressible (usually called unstructured pruning). This means the matrices are the same size, but some values are set to 0. This can save disk space using compression algorithms (such as run-length encoding or byte-pair encoding). When sparse model support fully lands in the various frameworks (i.e., you can multiply a sparse vector and a sparse matrix faster than the dense ones), you might be able to speed up inference as well.
For that reason, unstructured pruning (currently) doesn’t seem that useful, but essentially you can prune during or after training, and you pick a certain target sparsity (e.g., 80% of the weights of your network will be zeroed out). However, there’s a lot of confusion in this area, which makes it hard to recommend anything. TensorFlow has a few guides on pruning both during and after training, and PyTorch has a tutorial on pruning using some set of heuristics after training.
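A minimal sketch of unstructured magnitude pruning to a target sparsity, in plain Python with hypothetical weights: the smallest-magnitude weights are zeroed via a mask, and the weight list keeps its original length. PyTorch’s torch.nn.utils.prune.l1_unstructured applies the same idea to a module parameter through a mask; this sketch does not use that API.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the smallest-magnitude fraction `sparsity` of a flat weight list."""
    n_prune = int(round(sparsity * len(weights)))
    # Indices sorted from smallest to largest absolute value
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    to_zero = set(order[:n_prune])
    mask = [0 if i in to_zero else 1 for i in range(len(weights))]
    pruned = [w if m else 0.0 for w, m in zip(weights, mask)]
    return pruned, mask

weights = [0.9, -0.05, 0.3, 0.01, -0.7, 0.02, 0.4, -0.15, 0.08, -0.6]
pruned, mask = magnitude_prune(weights, sparsity=0.8)
print(pruned)  # same length as the input, but 80% of the entries are now 0
```

Note that nothing got smaller: without sparse kernels or compression on disk, this zeroed list costs exactly as much to store and multiply as the dense one.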

In the space of structured pruning, there’s still active research and no clear API. We can pick a metric to compute a relevance score for each neuron, and then remove the ones that have the least information content. Metrics that might be useful here are the Shapley value, a Taylor approximation of the loss function’s sensitivity to a neuron’s activation, or even picking neurons at random. Before you begin, check out PyTorch: Pruning Tutorial. The TorchPruner library implements some of these automatically for nn.Linear and convolution (nn.Conv1d, nn.Conv2d, etc.) modules. Another library, Torch-Pruning, has support for a few more operations. One of the most well-known older works in this area prunes filters from a convnet using the L1 norm of the filter’s weights. However, this is still an active area of research.
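A sketch of the L1-norm criterion applied to the simplest analogous case, the hidden neurons of a two-layer fully connected network (for a convnet, a “neuron” would be a whole filter). The weights are hypothetical and the helper name l1_prune_neurons is illustrative. Unlike unstructured pruning, the matrices actually shrink: rows of the first layer and the matching columns of the next layer are removed.

```python
def l1_prune_neurons(w1, w2, n_keep):
    """Structured pruning of the hidden layer in x -> W1 -> h -> W2.

    w1: hidden x input (one row per hidden neuron)
    w2: output x hidden (one column per hidden neuron)
    Keeps the n_keep hidden neurons whose incoming weights have the largest L1 norm.
    """
    norms = [sum(abs(w) for w in row) for row in w1]
    ranked = sorted(range(len(w1)), key=lambda i: norms[i], reverse=True)
    keep = sorted(ranked[:n_keep])
    new_w1 = [w1[i] for i in keep]                   # drop rows of W1
    new_w2 = [[row[i] for i in keep] for row in w2]  # drop matching columns of W2
    return new_w1, new_w2, keep

w1 = [[0.5, -0.4], [0.01, 0.02], [-0.9, 0.3]]  # 3 hidden neurons, 2 inputs
w2 = [[1.0, 2.0, -0.5]]                        # 1 output
new_w1, new_w2, keep = l1_prune_neurons(w1, w2, n_keep=2)
print(keep, new_w1, new_w2)
```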
Fine-tuning
 In both cases, it’s standard to retrain the network after applying the pruning. Currently, the best method is basically to reset the learning rate (learning rate rewinding) and start retraining the network. If you’d like, you can use weight rewinding, which is resetting the weights for the unpruned parts of the network to their value earlier in training (e.g., 1/3 trained weights). My intuition on this is that it’s essentially training the lottery ticket subnetwork now that we’ve identified it.
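The difference between ordinary fine-tuning and learning rate rewinding is only in which learning rates are used during retraining. A sketch, assuming a hypothetical 90-epoch step-decay schedule and 60 retraining epochs: fine-tuning keeps the final, smallest learning rate, whereas rewinding replays the last 60 epochs of the original schedule starting from the pruned final weights. Weight rewinding would replay the same schedule but also reset the surviving weights to their values from that earlier epoch.

```python
def step_schedule(epoch, base_lr=0.1, milestones=(30, 60), gamma=0.1):
    """Original training schedule: step decay at the given epochs (hypothetical values)."""
    return base_lr * gamma ** sum(epoch >= m for m in milestones)

def finetune_lr(retrain_epoch, total_epochs=90):
    # Standard fine-tuning: constant final learning rate, whatever the retraining epoch
    return step_schedule(total_epochs - 1)

def rewound_lr(retrain_epoch, rewind_epochs=60, total_epochs=90):
    # Learning rate rewinding: replay the last `rewind_epochs` epochs of the
    # original schedule, starting from the weights at the end of training
    return step_schedule(total_epochs - rewind_epochs + retrain_epoch)

for e in (0, 20, 40, 59):
    print(e, finetune_lr(e), rewound_lr(e))
```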
Overall, a practitioner who is really interested in trying this should start with TorchPruner or Torch-Pruning and then try fine-tuning the resulting network with learning rate rewinding. However, for most architectures (including ResNets, because of skip connections) it’ll be pretty non-obvious how to trim the rest of the network around this.
DeepSpeed & ZeRO-Offload
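Essentially, DeepSpeed is a library that helps train large to extremely large models (e.g., 1bn+ parameters) faster and using less GPU memory. This works by exploiting smart parallelism and memory optimizations such as ZeRO, which partitions optimizer states, gradients, and parameters across data-parallel GPUs instead of replicating them on every GPU, and ZeRO-Offload, which moves optimizer states, gradients, and the optimizer update to the CPU. It comes in the form of an extension to PyTorch.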
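To make “less GPU memory” concrete, here is a back-of-the-envelope sketch of per-GPU memory for model states under the ZeRO stages. It assumes mixed-precision Adam and the accounting used in the ZeRO paper (2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for fp32 master weights, momentum, and variance). Activations, buffers, and fragmentation are ignored, and ZeRO-Offload is not modeled.

```python
def zero_memory_gb(num_params, num_gpus, stage, k=12):
    """Per-GPU memory (GB) for model states under mixed-precision Adam."""
    p, g, o = 2 * num_params, 2 * num_params, k * num_params
    if stage == 0:    # plain data parallelism: everything replicated
        total = p + g + o
    elif stage == 1:  # partition optimizer states
        total = p + g + o / num_gpus
    elif stage == 2:  # also partition gradients
        total = p + (g + o) / num_gpus
    elif stage == 3:  # also partition parameters
        total = (p + g + o) / num_gpus
    else:
        raise ValueError("stage must be 0, 1, 2 or 3")
    return total / 1e9

for stage in range(4):
    print(f"stage {stage}: {zero_memory_gb(7.5e9, 64, stage):.1f} GB per GPU")
```

For a 7.5-billion-parameter model on 64 GPUs, this goes from 120 GB per GPU with plain data parallelism to under 2 GB with full partitioning.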
Knowledge distillation

Knowledge distillation is a method for creating smaller and more efficient models from large models. In NLP, this has also been referred to as teacher-student methods, because the large model (the teacher) trains the smaller student model. The reference work in this area is Hinton et al., 2015.

In practice, suppose we have a classification task. Suppose our smaller student model is \(f_{\theta}\), where \(\theta\) is the set of parameters. We take either a large model or an ensemble of models (possibly even the same model trained with different initializations), and call it \(F\) (we won’t worry about its parameters). Then we train the student network with the following loss:
\(\mathcal{L}=\sum_{i=1}^{n} \mathrm{KL}\left(F\left(x_{i}\right), f_{\theta}\left(x_{i}\right)\right)\)
where \(F\left(x_{i}\right)\) is the probability distribution over the labels created by passing example \(x_{i}\) through the teacher, and \(f_{\theta}\left(x_{i}\right)\) is the corresponding distribution from the student.
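If you want, you can add in the regular cross-entropy loss using the proper labels (by passing in the one-hot ground-truth distribution \(\hat{y}_{i}\) to the student as well):
\(\mathcal{L}=\sum_{i=1}^{n}\left[\mathrm{KL}\left(F\left(x_{i}\right), f_{\theta}\left(x_{i}\right)\right)+\mathrm{CE}\left(\hat{y}_{i}, f_{\theta}\left(x_{i}\right)\right)\right]\)

Note that this second term is just the KL divergence from the “true” distribution (i.e., the one-hot distribution from the labels) to the student model, since \(\hat{y}_{i}\) is one-hot: in general \(\mathrm{KL}\left(\hat{y}_{i}, f_{\theta}\left(x_{i}\right)\right)=\mathrm{CE}\left(\hat{y}_{i}, f_{\theta}\left(x_{i}\right)\right)-H\left(\hat{y}_{i}\right)\), and a one-hot distribution has zero entropy \(H\).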
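A plain-Python sketch of the distillation loss defined in this section, with softmax turning logits into the distributions \(F(x_i)\) and \(f_{\theta}(x_i)\). The logits and labels are hypothetical, and Hinton et al.’s temperature scaling is deliberately left out to match the formula as written here. The tests also check that, for a one-hot label, the KL term and the cross-entropy term coincide.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p_c = 0 contribute 0."""
    return sum(pc * math.log(pc / qc) for pc, qc in zip(p, q) if pc > 0)

def cross_entropy(y_onehot, q):
    return -sum(yc * math.log(qc) for yc, qc in zip(y_onehot, q) if yc > 0)

def distillation_loss(teacher_logits, student_logits, labels=None):
    """Sum over examples of KL(F(x_i), f_theta(x_i)), plus CE(y_i, f_theta(x_i)) if labels given."""
    loss = 0.0
    for i, (t, s) in enumerate(zip(teacher_logits, student_logits)):
        p_teacher, p_student = softmax(t), softmax(s)
        loss += kl(p_teacher, p_student)
        if labels is not None:
            loss += cross_entropy(labels[i], p_student)
    return loss

teacher = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]  # hypothetical teacher logits
student = [[1.5, 1.2, 0.2], [0.4, 1.9, 0.8]]  # hypothetical student logits
labels = [[1, 0, 0], [0, 1, 0]]               # one-hot ground truth
print(distillation_loss(teacher, student))
print(distillation_loss(teacher, student, labels))
```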

Why does this work? There is no consensus in the field. The most compelling explanation is that distillation is a form of rough data augmentation. To understand why, the paper Towards Understanding Ensemble, Knowledge Distillation and Self-Distillation in Deep Learning, which is focused on the idea of multiple views, is recommended. Here are some thought experiments that might help explain what’s happening under the hood.
Distillation thought experiment

Let’s say that we have a large teacher model that is trained to classify images (e.g., CIFAR-100). This model implicitly has a bunch of “feature detectors” built in, e.g., a set of convolutional filters that fire when pointy ears are seen, which increase the probability of a label like “cat”. Let’s say that there’s a training image of a Batman mask, labeled “mask”. The teacher model’s pointy-ear filters might still fire, telling us that the model thinks this looks 10% like a cat.

When the student model is trained to match the probability distribution of the teacher, because the distribution is 0.1 cat, it will still get a small signal that this image is cat-like, which might help the student model recognize cats better than it could otherwise. If the student model were trained on just the true labels, it would have no idea that this Batman mask looks a bit like a cat. This logic also helps explain why the student can outperform the teacher in some cases.
Ensembling thought experiment

A similar, but slightly different, idea explains why ensembles of models (even of the same architecture) might work well. Let’s say there are 3 pictures of a cat in a dataset we’re using for image classification. Image 1 has a cat with feature A (e.g., pointed ears), image 2 has feature B (e.g., whiskers), and image 3 has both A and B.

Then, let’s say the neural network learns feature A (e.g., by seeing image 1). When it sees image 3, that set of convolution filters will fire, and so the image will be correctly classified. So, there’ll be no gradient that tunes the net to recognize feature B, even though a good net would learn that.

Once a neural network has become good enough, the gradient signal it receives from some data points (like image 3) shrinks.
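For softmax cross-entropy, the gradient with respect to the logits is softmax(z) minus the one-hot target, so it vanishes as the predicted probability of the correct class approaches 1. A small sketch with hypothetical logits for image 3 from the thought experiment: once the “cat” logit is large (because feature A already fires), almost no gradient is left to teach the network feature B.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ce_grad_wrt_logits(logits, target_index):
    """Gradient of cross-entropy w.r.t. the logits: softmax(z) - onehot(y)."""
    probs = softmax(logits)
    return [p - (1.0 if c == target_index else 0.0) for c, p in enumerate(probs)]

# Hypothetical logits for image 3 ("cat" is class 0) as the network grows more confident
for cat_logit in (0.5, 2.0, 5.0, 10.0):
    grad = ce_grad_wrt_logits([cat_logit, 0.0, 0.0], target_index=0)
    norm = math.sqrt(sum(g * g for g in grad))
    print(f"cat logit {cat_logit:>4}: gradient norm {norm:.5f}")
```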
Distillation in practice

Knowledge distillation is a very deep and wide research area, touching adversarial attacks, knowledge transfer, and privacy.

In practice, the method I’ve described above is called response-based distillation. There are also other forms of distillation, including feature-based and relation-based knowledge distillation, which are entire subfields based on which parts (or computations) of the student and teacher models we should tie together.

Furthermore, there’s a division between offline distillation (i.e., train the student after the teacher), online distillation (train the student and teacher together), and self-distillation (where the teacher model has the same architecture as the student). Together, this makes it difficult to track distillation in practice; a set of ad hoc model-specific techniques might be the best general recommendation.

In fact, Cho & Hariharan, 2019 found that when the student model’s capacity is too low, using knowledge distillation will actually adversely affect training. They also observed that knowledge distillation papers rarely use ImageNet, so the methods often don’t work well on difficult problems. Perplexingly, that paper and Mirzadeh et al., 2019 found that better teacher models don’t always mean better distillation, and that the farther apart the student and teacher models’ capacities are, the less effective distillation is. You can find a recent investigation in Tang et al., 2021.

All in all, distillation is relatively difficult compared to quantization and pruning. You might be able to get some free performance points by training a student with slightly smaller capacity and then using vanilla response-based offline distillation.
Distillation as semi-supervised learning
 You can train a teacher model, which is a much more powerful model than the student, with a small set of labeled data. Next, use the teacher to automatically label unannotated data, which can be used to train a leaner, more efficient “student” network.
For example, “Lessons from building acoustic models with a million hours of speech” by Parthasarathi and Strom (2019) used a small set of annotated data (green) to train a powerful but impractically slow “teacher” network to convert frequency-level descriptions of audio data into sequences of phones. The teacher, in turn, labeled a much larger set of unannotated data (red). They then used both datasets to train a leaner, more efficient “student” model.
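A minimal sketch of this teacher-labels-the-pool recipe in plain Python. Everything here is hypothetical and far simpler than the speech setup described above: teacher_predict stands in for an already-trained teacher, the inputs are 1-D toy values, and the student is a tiny logistic regression trained with cross-entropy against soft targets (the hard labels simply become probabilities 0 or 1).

```python
import math

def teacher_predict(x):
    """Hypothetical stand-in for a trained teacher: P(class 1 | x) for 1-D inputs."""
    return 1 / (1 + math.exp(-4 * (x - 0.5)))

labeled = [(0.1, 0), (0.9, 1), (0.2, 0), (0.8, 1)]  # small annotated set
unlabeled = [0.05, 0.3, 0.45, 0.6, 0.7, 0.95]       # larger unannotated pool

# The teacher labels the unannotated pool (soft labels keep its uncertainty)
pseudo_labeled = [(x, teacher_predict(x)) for x in unlabeled]
student_training_set = [(x, float(y)) for x, y in labeled] + pseudo_labeled

# Fit a tiny logistic-regression student on both sets with gradient descent
w, b, lr = 0.0, 0.0, 0.5
for _ in range(2000):
    gw = gb = 0.0
    for x, target in student_training_set:
        p = 1 / (1 + math.exp(-(w * x + b)))
        gw += (p - target) * x  # cross-entropy gradient with a soft target
        gb += p - target
    w -= lr * gw / len(student_training_set)
    b -= lr * gb / len(student_training_set)

print(f"student: w={w:.2f}, b={b:.2f}, decision boundary at x={-b / w:.2f}")
```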
Conclusion

Deep learning researchers have spent a lot of time distilling large models using model-specific methods, and if you need to gain some performance, you might be able to find a pretrained distilled version of the large model you’re currently using. For example, in NLP, HuggingFace makes it easy to access both DistilBert and TinyBert. In computer vision, Facebook Research’s d2go has a bunch of pretrained mobile-ready models, and they’ve specialized some distillation methods in DeiT.

Well-Read Students Learn Better: On the Importance of Pre-training Compact Models makes a recommendation (with high-quality ablation experiments) that for training BERT architectures, the best approach is:
1. Pre-train a compact model architecture on the masked language model (MLM) objective developed by the original BERT paper (Devlin et al., 2018).
2. Take a large task-specific teacher model (e.g., if the task is NLI, the output is a distribution over the 3 classes (entailment, contradiction, neutral)), and perform basic response-based offline distillation on the pre-trained compact model from step 1.
3. Finally, if required, fine-tune the compact model from step 2 on the task-specific data (e.g., if the task is NER, train over the CoNLL 2003 dataset).

One of the best advantages of this method (which they call Pre-trained Distillation (PD)) is that it’s architecture-agnostic. If you are going to use a compact NLP model in practice, it’s worth skimming the paper, especially section 6.
Further Reading
- PyTorch: Quantization
- PyTorch: Introduction to Quantization on PyTorch
- Deep Learning Model Compression by Rachit Singh
- PyTorch: Pruning Tutorial
Citation
If you found our work useful, please cite it as:
@article{Chadha2020DistilledModelCompression,
title = {Model Compression},
author = {Chadha, Aman},
journal = {Distilled AI},
year = {2020},
note = {\url{https://aman.ai}}
}