Background: Representation Learning for NLP

  • At a high level, all neural network architectures build representations of input data as vectors/embeddings, which encode useful statistical and semantic information about the data. These latent or hidden representations can then be used to perform something useful, such as classifying an image or translating a sentence. The neural network learns to build better and better representations by receiving feedback, usually via error/loss functions.
  • Conventionally, for Natural Language Processing (NLP), Recurrent Neural Networks (RNNs) build representations of each word in a sentence in a sequential manner, i.e., one word at a time. Intuitively, we can imagine an RNN layer as a conveyor belt, with the words being processed on it autoregressively from left to right. In the end, we get a hidden feature for each word in the sentence, which we pass to the next RNN layer or use for our NLP tasks of choice (see the code sketch after this list).
  • Chris Olah’s legendary blog posts on RNNs and representation learning for NLP are highly recommended for developing a background in this area.
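  • As a concrete picture of the sequential processing described above, here is a minimal PyTorch sketch of an RNN layer updating one word at a time; the dimensions and variable names are illustrative assumptions rather than anything from a particular model.

```python
# Minimal sketch: an RNN layer building a hidden feature for each word,
# one word at a time (left to right). All dimensions are illustrative.
import torch
import torch.nn as nn

embedding_dim, hidden_dim, seq_len = 32, 64, 6
rnn_cell = nn.RNNCell(embedding_dim, hidden_dim)

# A "sentence" of seq_len word embeddings (batch size 1 for clarity).
words = torch.randn(seq_len, 1, embedding_dim)

h = torch.zeros(1, hidden_dim)               # initial hidden state
hidden_features = []
for x_t in words:                            # the "conveyor belt": strictly sequential
    h = rnn_cell(x_t, h)                     # h_t depends on h_{t-1}
    hidden_features.append(h)

hidden_features = torch.stack(hidden_features)  # one feature per word: (seq_len, 1, hidden_dim)
```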

Enter the Transformer

  • Initially introduced for machine translation, Transformers have gradually replaced RNNs in mainstream NLP. The architecture takes a fresh approach to representation learning: doing away with recurrence entirely, Transformers build features of each word using an attention mechanism to figure out how important all the other words in the sentence are with respect to the aforementioned word. Knowing this, the word’s updated features are simply the sum of linear transformations of the features of all the words, weighted by their importance.
  • Back in 2017, this idea sounded very radical, because the NLP community was so used to the sequential – one-word-at-a-time – style of processing text with RNNs. The title of the paper probably added fuel to the fire! For a recap, Yannic Kilcher made an excellent video overview.

Breaking down the Transformer

  • Let’s develop intuitions about the architecture by translating the previous paragraph into the language of mathematical symbols and vectors.
  • We update the hidden feature \(h\) of the \(i\)'th word in a sentence \(\mathcal{S}\) from layer \(\ell\) to layer \(\ell+1\) as follows:
\[h_{i}^{\ell+1}=\text{Attention}\left(Q^{\ell} h_{i}^{\ell}, K^{\ell} h_{j}^{\ell}, V^{\ell} h_{j}^{\ell}\right)\]
  • i.e.,
\[\begin{array}{c} h_{i}^{\ell+1}=\sum_{j \in \mathcal{S}} w_{ij}\left(V^{\ell} h_{j}^{\ell}\right) \\ \text{where } w_{ij}=\operatorname{softmax}_{j}\left(Q^{\ell} h_{i}^{\ell} \cdot K^{\ell} h_{j}^{\ell}\right) \end{array}\]
  • where \(\mathcal{S}\) denotes the set of words in the sentence (indexed by \(j\)) and \(Q^{\ell}, K^{\ell}, V^{\ell}\) are learnable linear weights (denoting the Query, Key and Value for the attention computation, respectively).
  • The attention mechanism is performed in parallel for each word in the sentence to obtain its updated features in one shot. This is another plus point for Transformers over RNNs, which update features word-by-word.
  • We can understand the attention mechanism better through the following pipeline:

  • Taking in the features of the word \(h_{i}^{\ell}\) and the set of other words in the sentence \(\left\{h_{j}^{\ell}\ \forall\ j \in \mathcal{S}\right\}\), we compute the attention weights \(w_{ij}\) for each pair \((i, j)\) through the dot-product, followed by a softmax across all \(j\)’s. Finally, we produce the updated word feature \(h_{i}^{\ell+1}\) for word \(i\) by summing over all the \(\left\{V^{\ell} h_{j}^{\ell}\right\}\)’s weighted by their corresponding \(w_{ij}\). Each word in the sentence undergoes the same pipeline in parallel to update its features.
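  • To make the update rule concrete, here is a minimal PyTorch sketch of the (single-head, unscaled) dot-product attention described above, computed for every word in the sentence at once; the feature dimension, sentence length, and variable names are illustrative assumptions.

```python
# Minimal sketch of the update rule above: dot-product attention weights w_ij
# followed by a weighted sum of value vectors, for all words in parallel.
import torch
import torch.nn as nn
import torch.nn.functional as F

d = 64                            # feature dimension of h_i (illustrative)
n_words = 6                       # |S|, number of words in the sentence

Q = nn.Linear(d, d, bias=False)   # Q^l
K = nn.Linear(d, d, bias=False)   # K^l
V = nn.Linear(d, d, bias=False)   # V^l

h = torch.randn(n_words, d)       # h_j^l for every word j in S

scores = Q(h) @ K(h).T            # (n_words, n_words): Q^l h_i . K^l h_j for every pair (i, j)
w = F.softmax(scores, dim=-1)     # w_ij = softmax_j(...)
h_next = w @ V(h)                 # h_i^{l+1} = sum_j w_ij (V^l h_j), for all i in one shot
```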

Multi-head Attention

  • Getting this straightforward dot-product attention mechanism to work proves to be tricky. Bad random initializations of the learnable weights can destabilize the training process.
  • We can overcome this by performing multiple ‘heads’ of attention in parallel and concatenating the result (with each head now having separate learnable weights):

    \[\begin{array}{c} h_{i}^{\ell+1}=\text{Concat}\left(\text{head}_{1}, \ldots, \text{head}_{K}\right) O^{\ell} \\ \text{head}_{k}=\text{Attention}\left(Q^{k, \ell} h_{i}^{\ell}, K^{k, \ell} h_{j}^{\ell}, V^{k, \ell} h_{j}^{\ell}\right) \end{array}\]
    • where \(Q^{k, \ell}, K^{k, \ell}, V^{k, \ell}\) are the learnable weights of the \(k\)’th attention head and \(O^{\ell}\) is a down-projection to match the dimensions of \(h_{i}^{\ell+1}\) and \(h_{i}^{\ell}\) across layers (a code sketch of this computation follows this list).
  • Multiple heads allow the attention mechanism to essentially ‘hedge its bets’, looking at different transformations or aspects of the hidden features from the previous layer. We’ll talk more about this later.
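  • Here is a minimal PyTorch sketch of the multi-head computation above, with each head owning separate \(Q, K, V\) weights and the concatenated result passed through \(O^{\ell}\); the head count and dimensions are illustrative assumptions.

```python
# Minimal sketch of multi-head attention: K independent heads, concatenated
# and down-projected by O^l. Head count and dimensions are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

d, n_heads, n_words = 64, 4, 6
d_head = d // n_heads

# Separate learnable Q/K/V weights per head, plus the output projection O^l.
Wq = nn.ModuleList([nn.Linear(d, d_head, bias=False) for _ in range(n_heads)])
Wk = nn.ModuleList([nn.Linear(d, d_head, bias=False) for _ in range(n_heads)])
Wv = nn.ModuleList([nn.Linear(d, d_head, bias=False) for _ in range(n_heads)])
O = nn.Linear(n_heads * d_head, d, bias=False)

h = torch.randn(n_words, d)                    # h_j^l for every word j in S

heads = []
for k in range(n_heads):
    scores = Wq[k](h) @ Wk[k](h).T             # per-head attention scores
    w = F.softmax(scores, dim=-1)
    heads.append(w @ Wv[k](h))                 # head_k for every word i

h_next = O(torch.cat(heads, dim=-1))           # Concat(head_1, ..., head_K) O^l
```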

Scaling Issues

  • A key issue motivating the final Transformer architecture is that the features for words after the attention mechanism might be at different scales or magnitudes: (1) this can be due to some words having very sharp or very distributed attention weights \(w_{ij}\) when summing over the features of the other words; (2) at the level of individual feature/vector entries, concatenating across multiple attention heads, each of which might output values at different scales, can lead to the entries of the final vector \(h_{i}^{\ell+1}\) having a wide range of values. Following conventional ML wisdom, it seems reasonable to add a normalization layer into the pipeline.
  • Transformers overcome issue (2) with LayerNorm, which normalizes and learns an affine transformation at the feature level. Additionally, scaling the dot-product attention by the square-root of the feature dimension helps counteract issue (1).
  • Finally, the authors propose another ‘trick’ to control the scale issue: a position-wise 2-layer MLP with a special structure. After the multi-head attention, they project \(h_{i}^{\ell+1}\) to an (absurdly) higher dimension using a learnable weight, apply the ReLU non-linearity, and then project it back to its original dimension, followed by another normalization:
\[h_{i}^{\ell+1}=\mathrm{LN}\left(\mathrm{MLP}\left(\mathrm{LN}\left(h_{i}^{\ell+1}\right)\right)\right)\]
  • To be honest, I’m not sure what the exact intuition behind the over-parameterized feed-forward sub-layer was. I suppose LayerNorm and scaled dot-products didn’t completely solve the issues highlighted, so the big MLP is a sort of hack to re-scale the feature vectors independently of each other. According to Jannes Muenchmeyer, the feed-forward sub-layer ensures that the Transformer is a universal approximator.
  • Thus, projecting to a very high dimensional space, applying a non-linearity, and re-projecting to the original dimension allows the model to represent more functions than maintaining the same dimension across the hidden layer would. The final picture of a Transformer layer looks like this (a code sketch of the full layer follows at the end of this list):

  • The Transformer architecture is also extremely amenable to very deep networks, enabling the NLP community to scale up in terms of both model parameters and, by extension, data. Residual connections between the inputs and outputs of each multi-head attention sub-layer and the feed-forward sub-layer are key for stacking Transformer layers (but omitted from the diagram for clarity).
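  • Putting the pieces together, here is a minimal PyTorch sketch of one (post-LN) Transformer layer: scaled multi-head attention and the position-wise feed-forward sub-layer, each wrapped with a residual connection and LayerNorm. The dimensions match the original paper’s base configuration, but the class and variable names are my own illustrative choices, not a reference implementation.

```python
# Minimal sketch of a full (post-LN) Transformer layer: multi-head attention
# and a position-wise feed-forward sub-layer, each followed by a residual
# connection and LayerNorm. Dimensions follow the original base configuration.
import torch
import torch.nn as nn

class TransformerLayer(nn.Module):
    def __init__(self, d=512, n_heads=8, d_ff=2048):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, n_heads, batch_first=True)  # scaled dot-product inside
        self.norm1 = nn.LayerNorm(d)
        self.mlp = nn.Sequential(      # position-wise 2-layer MLP
            nn.Linear(d, d_ff),        # project up to the (absurdly) higher dimension
            nn.ReLU(),
            nn.Linear(d_ff, d),        # project back down to d
        )
        self.norm2 = nn.LayerNorm(d)

    def forward(self, h):
        # Residual connection around the multi-head attention sub-layer, then LayerNorm.
        attn_out, _ = self.attn(h, h, h)
        h = self.norm1(h + attn_out)
        # Residual connection around the feed-forward sub-layer, then LayerNorm.
        h = self.norm2(h + self.mlp(h))
        return h

layer = TransformerLayer()
sentence = torch.randn(1, 6, 512)   # (batch, words, features)
out = layer(sentence)               # same shape: (1, 6, 512)
```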

Are Transformers learning neural syntax?

  • There have been several interesting papers from the NLP community on what Transformers might be learning. The basic premise is that performing attention on all word pairs in a sentence – with the purpose of identifying which pairs are the most interesting – enables Transformers to learn something like a task-specific syntax. Different heads in the multi-head attention might also be ‘looking’ at different syntactic properties.

Why multiple heads of attention? Why attention?

  • I’m more sympathetic to the optimization view of the multi-head mechanism: having multiple attention heads improves learning and helps overcome bad random initializations. For instance, these papers showed that Transformer heads can be ‘pruned’ or removed after training without significant performance impact. Wouldn’t it be nice for Transformers if we didn’t have to compute pair-wise compatibilities for every pair of words in the sentence?
  • Could Transformers benefit from ditching attention, altogether? Yann Dauphin and collaborators’ recent work suggests an alternative ConvNet architecture. Transformers, too, might ultimately be doing something similar to ConvNets!

Why is training Transformers so hard?

  • Reading new Transformer papers makes me feel that training these models requires something akin to black magic when determining the best learning rate schedule, warmup strategy and decay settings (a sketch of the original paper’s schedule follows this list). This could simply be because the models are so huge and the NLP tasks studied are so challenging.
  • But recent results suggest that it could also be due to the specific ordering of normalization and residual connections within the architecture.
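  • For reference, the schedule proposed in the original paper increases the learning rate linearly over a fixed number of warmup steps and then decays it proportionally to the inverse square root of the step number. A minimal sketch (the constants are the paper’s defaults; everything else is illustrative):

```python
# Minimal sketch of the learning-rate schedule from the original Transformer
# paper: linear warmup for `warmup_steps`, then inverse-square-root decay.
def transformer_lr(step, d_model=512, warmup_steps=4000):
    step = max(step, 1)  # avoid division by zero at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# The peak learning rate is reached at step == warmup_steps.
peak = transformer_lr(4000)   # roughly 7e-4 for d_model=512
```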

Citation

If you found our work useful, please cite it as:

@article{Chadha2020DistilledTransformers,
  title   = {Transformers},
  author  = {Chadha, Aman},
  journal = {Distilled AI},
  year    = {2020},
  note    = {\url{https://aman.ai}}
}