All you need to know about Attention


The attention mechanism in deep learning has led to major performance improvements in both NLP and computer vision. Most modern sequence models now implement attention, resulting in state-of-the-art architectures. Let's try to understand how the attention mechanism evolved.

Prior to attention models, many NLP tasks such as language translation, question answering and text summarization relied on a sequential encoder-decoder structure.

Here, the encoder is a Recurrent Neural Network (RNN) cell, usually an LSTM, that takes in the input sentence as word tokens, one at a time. Each input word produces a hidden state, which is passed along with the next input word. The next hidden state therefore holds information about all the words seen so far. This continues over the length of the input sequence until the last hidden state is generated. The last hidden state, which carries information about the entire sequence, is called the context vector and is passed as input to the decoder unit.

The decoder is also an RNN cell that takes in the context vector and generates output tokens one at a time. At each step, it takes the hidden state and output of the previous decoder step and generates the output for the current step.
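To make the sequential flow concrete, here is a toy NumPy sketch of such a plain encoder-decoder loop. The `rnn_cell`, the weights and the way the previous output is fed back are illustrative assumptions, not code from any specific paper.

```python
import numpy as np

d = 8
W_xh, W_hh = np.random.randn(d, d), np.random.randn(d, d)

def rnn_cell(x, h):
    return np.tanh(W_xh @ x + W_hh @ h)           # one recurrent step (stand-in for LSTM/GRU)

inputs = [np.random.randn(d) for _ in range(5)]   # embedded input tokens
h = np.zeros(d)
for x in inputs:                                  # encoder: one token at a time
    h = rnn_cell(x, h)
context = h                                       # last hidden state = context vector

y = np.zeros(d)                                   # stand-in for a start-of-sequence token
s = context
outputs = []
for _ in range(5):                                # decoder: one output token per step
    s = rnn_cell(y, s)
    y = s                                         # in practice y comes from a soft-max over the vocabulary
    outputs.append(y)
```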

While the encoder-decoder architecture is a major improvement over the traditional probabilistic approaches used earlier, it still has two limitations. First, the encoder compresses the entire sentence into a single fixed-length context vector. This does not work well when the sequence becomes long. Increasing the length of the context vector is one option, but training with such a long context vector becomes computationally very expensive. Also, if the input is short relative to a long context vector, the model fails to extract a fair representation.
The second limitation is the sequential nature of the RNN architecture: every word token has to be fed in one at a time, so the model takes a long time to train.

Attention mechanisms address both of these limitations with a simple yet powerful idea: humans don't read an entire sentence in one go. The eye focuses on a small portion of text in 'high resolution' while viewing the remaining text in 'low resolution', continually adjusting its focus.

Attention mechanisms try to emulate this pattern. In an attention architecture, not just the last but all of the hidden states of the encoder are passed to the decoder. The decoder then uses a function to generate a weight for each encoder hidden state. Naturally, the weight is higher for the encoder hidden states the decoder wants to focus on. These weights are also called the alignment vector.

The alignment for an English to French translation sentence can be seen below:

Visual Representation of Alignment Matrix [2]

It is interesting to see how the model learns the weights correctly for the words 'European' (first matrix) and 'environment' (second matrix), even though the positions of these words are not the same in English and French.

So, rather than one long context vector, the decoder now maintains an active connection with every encoder state and can read from any part of the sequence at every decoder step. The context vector for each decoder step i is therefore calculated separately, as a weighted combination of the encoder hidden states:

c_i = Σ_j α_ij · h_j,  with  α_ij = exp(e_ij) / Σ_k exp(e_ik)

where h_j is the hidden state for encoder step j and the weight α_ij is the alignment weight, the soft-max of the score e_ij for encoder step j.
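As a minimal sketch of that formula, with made-up shapes and random numbers standing in for real encoder states and alignment scores:

```python
import numpy as np

def softmax(x):
    x = x - x.max()               # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum()

T_enc, d_h = 5, 8                 # encoder length and hidden size (assumed)
h = np.random.randn(T_enc, d_h)   # encoder hidden states h_j
e = np.random.randn(T_enc)        # alignment scores e_ij for one decoder step i

alpha = softmax(e)                # alignment weights alpha_ij
c_i = alpha @ h                   # context vector: sum_j alpha_ij * h_j, shape (d_h,)
```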

Now that we know what attention is, let's look at the different score functions introduced in the various attention papers.
There are two types of attention: additive and multiplicative.

Additive Attention, introduced in the 2014 Bahdanau paper, computes the alignment score e_ij as a function of the hidden state of the previous decoder step s_(i-1) and the encoder hidden states h_j, passed through a feed-forward layer.

The Bahdanau architecture uses a bi-directional RNN as the encoder, so h_j is the concatenation of the forward and backward hidden states.

Bahdanau Paper 2014. [1]
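A hedged NumPy sketch of the additive score, assuming the commonly cited form e_ij = v_a^T tanh(W_a s_(i-1) + U_a h_j); the parameter names and sizes here are illustrative:

```python
import numpy as np

d_dec, d_enc, d_att, T_enc = 8, 8, 10, 5
W_a = np.random.randn(d_att, d_dec)
U_a = np.random.randn(d_att, d_enc)
v_a = np.random.randn(d_att)

s_prev = np.random.randn(d_dec)           # previous decoder hidden state s_(i-1)
h = np.random.randn(T_enc, d_enc)         # encoder hidden states h_j

# e_ij = v_a^T tanh(W_a s_(i-1) + U_a h_j), computed for all encoder steps j at once
e = np.tanh(s_prev @ W_a.T + h @ U_a.T) @ v_a    # shape (T_enc,)
```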

Multiplicative Attention, introduced in the 2015 Luong paper, provides three different forms of the score function:

Luong Paper 2015. [2]

Here, h_t is the hidden state of the decoder and h̄_s are the encoder hidden states. The first two functions are based on a dot product. While the first (dot) form expects the encoder and decoder hidden states to have the same dimensions, the second (general) form inserts a learned weight matrix between them, which projects the two hidden states into the same dimensional space. The third (concat) form takes a concatenation of the encoder and decoder hidden states and passes it through a feed-forward layer to calculate the score. If this looks very similar to the Bahdanau paper, that's because it is: the Bahdanau function uses the previous decoder state while the Luong concat function uses the current decoder state. Bahdanau also has an extra weight matrix for the encoder hidden states.
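An illustrative NumPy sketch of the three score forms; the names, dimensions and random inputs are assumptions for demonstration only:

```python
import numpy as np

d, T_enc = 8, 5
h_t = np.random.randn(d)                 # current decoder hidden state
h_s = np.random.randn(T_enc, d)          # encoder hidden states (one per source step)

W_a = np.random.randn(d, d)
score_dot     = h_s @ h_t                # dot:     h_t . h_s, needs matching dimensions
score_general = h_s @ (W_a @ h_t)        # general: h_t^T W_a h_s

d_att = 10
W_cat = np.random.randn(d_att, 2 * d)
v_a = np.random.randn(d_att)
concat = np.concatenate([np.tile(h_t, (T_enc, 1)), h_s], axis=1)   # [h_t; h_s] per source step
score_concat = np.tanh(concat @ W_cat.T) @ v_a                     # concat form
```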

Similar to Additive Attention, Multiplicative Attention also takes the soft-max of the scores and produces the context vector c_i with the same formula. The final output, however, comes from a feed-forward layer whose input is the concatenation of the context vector and the decoder hidden state at that step.
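A small sketch of that output step, assuming the usual formulation tanh(W_out [c_t; h_t]); W_out and the sizes are illustrative:

```python
import numpy as np

d = 8
c_t = np.random.randn(d)                     # context vector for this decoder step
h_t = np.random.randn(d)                     # decoder hidden state at this step
W_out = np.random.randn(d, 2 * d)

h_att = np.tanh(W_out @ np.concatenate([c_t, h_t]))   # attentional hidden state
# h_att is then projected to vocabulary size and soft-maxed to predict the next word
```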

While this is termed the Global Attention Model, the Luong paper also introduced a Local Attention Model, which reduces computation further by picking an aligned position p_t and building the context vector only from the encoder states within D positions to its left and right, i.e. the window [p_t - D, p_t + D]. This also means that the context vector has a fixed length in the case of local attention.
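A quick sketch of the windowing idea, with p_t and D chosen arbitrarily for illustration:

```python
import numpy as np

T_enc, d, D = 20, 8, 3
h = np.random.randn(T_enc, d)     # encoder hidden states
p_t = 11                          # aligned position picked for this decoder step (assumed)

lo, hi = max(0, p_t - D), min(T_enc, p_t + D + 1)
window = h[lo:hi]                 # the 2D + 1 states in [p_t - D, p_t + D]
# scores and the soft-max are computed only over `window`,
# so the resulting context vector has a fixed length
```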

Local Attention Model, Luong paper [2]
Global Attention Model, Luong paper [2]

Now we know how attention overcomes the limitation of a single fixed-length context vector shared across every decoder step.

Another paper, released in 2017 and titled Attention Is All You Need, takes the idea of attention a step further. It proposes an architecture called the Transformer, which uses only attention and no recurrent layers in the encoder and decoder. As Transformers take in the input tokens simultaneously rather than sequentially, they add a positional encoding to the input embedding to preserve the position of each token in the sequence (the original paper uses fixed sinusoidal encodings, though learned positional embeddings work comparably). The attention itself is computed by several heads in parallel, which the paper calls Multi-Head Attention.
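A minimal sketch of the sinusoidal positional encoding from the Transformer paper; the sequence length and model dimension here are arbitrary:

```python
import numpy as np

def positional_encoding(seq_len, d_model):
    pos = np.arange(seq_len)[:, None]                   # token positions 0..seq_len-1
    i = np.arange(d_model)[None, :]                     # embedding dimensions
    angles = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])               # even dimensions use sine
    pe[:, 1::2] = np.cos(angles[:, 1::2])               # odd dimensions use cosine
    return pe

embeddings = np.random.randn(10, 16)                    # 10 tokens, d_model = 16 (assumed)
x = embeddings + positional_encoding(10, 16)            # position-aware input to the encoder
```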

Multi-Head Attention [3]

As part of multi-head attention, the model learns three matrices, Query (Q), Key (K) and Value (V), each of which is a linear projection of the input embeddings. The new representation A(Q, K, V), which carries the contextual and positional information of the sequence, is a function of these learned matrices.

The function shown in the multi-head diagram multiplies the Query vector of a token with the Key vectors of all the other tokens to get their scores. The scores are scaled by the square root of the dimension of the Key vector, converted to probabilities using a soft-max and used to weight the Value vectors, which are then summed. The outputs of the individual attention heads are concatenated together and passed to the feed-forward layer. The final embedding now carries information not only about the contextual meaning of the word but also about its position in the sequence.
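A minimal NumPy sketch of this scaled dot-product attention for a single head, with random stand-ins for the projected Q, K and V matrices:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

T, d_k = 6, 16
Q = np.random.randn(T, d_k)                   # queries (stand-ins for learned projections)
K = np.random.randn(T, d_k)                   # keys
V = np.random.randn(T, d_k)                   # values

scores  = Q @ K.T / np.sqrt(d_k)              # query-key dot products, scaled by sqrt(d_k)
weights = softmax(scores, axis=-1)            # attention probabilities per token
output  = weights @ V                         # weighted sum of value vectors, shape (T, d_k)
```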

The Transformer, as described in the paper, has 6 encoder units (on the left) and 6 decoder units (on the right), all stacked together. It uses three kinds of Multi-Head Attention.

 Transformer Architecture [3]

The encoder multi-head attention is self-attention: every position in the encoder attends to every other position simultaneously, letting the model learn the structure of the sequence.

The decoder's masked multi-head attention masks the states of future decoder outputs, allowing the model to learn only from the previous decoder output states.
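A small sketch of how such a causal mask is typically applied, assuming future scores are set to -inf before the soft-max; the score matrix here is random:

```python
import numpy as np

T = 5
scores = np.random.randn(T, T)                        # raw query-key scores for T decoder positions
mask = np.triu(np.ones((T, T), dtype=bool), k=1)      # True above the diagonal = future positions
masked_scores = np.where(mask, -np.inf, scores)       # future positions get -inf
# a soft-max over masked_scores now assigns zero weight to future decoder states
```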

The decoder's second multi-head attention (encoder-decoder attention) attends over all the encoder states, using the output of the masked multi-head attention as its queries.

Both the encoder and the decoder have a position-wise feed-forward layer whose output dimension matches the model size of 512, so the representation size stays consistent at every step. The output of every sub-layer is added to its input through a residual connection and layer-normalized before being passed on to the next step. The final probabilities are produced much as in the Multiplicative Concat model: the output passes through a linear layer and a soft-max turns the scores into probabilities between 0 and 1.
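A rough sketch of the "add and norm" step around a sub-layer; the plain layer_norm and the toy feed-forward weights here are re-implementations for illustration, not library calls:

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)

def add_and_norm(x, sublayer_fn):
    return layer_norm(x + sublayer_fn(x))      # residual connection, then normalization

d_model = 16
x = np.random.randn(6, d_model)                                      # 6 tokens in the sequence
W1, W2 = np.random.randn(d_model, 32), np.random.randn(32, d_model)  # toy feed-forward weights
ffn = lambda t: np.maximum(0, t @ W1) @ W2                           # position-wise feed-forward
out = add_and_norm(x, ffn)                                           # output keeps shape (6, d_model)
```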

Not only did this architecture deliver faster training and reduced computation, it also achieved state-of-the-art performance on many NLP applications.

Transformers are now the building blocks of many state-of-the-art language representation architectures like OpenAI GPT and BERT.

Hope this article was helpful in giving a fair idea about the various attention mechanisms that are known today.

Any questions, feedback, suggestions for improvement are most welcome. 🙂

If you are intrigued by this kind of exciting work and would like to explore your avenues in the AI space, get in touch with our counselors to help you take your career ahead.

References

[1] Bahdanau, D., Cho, K., Bengio, Y. Neural Machine Translation by Jointly Learning to Align and Translate, 2014.
[2] Luong, M.-T., Pham, H., Manning, C. D. Effective Approaches to Attention-based Neural Machine Translation, 2015.
[3] Vaswani, A., et al. Attention Is All You Need, 2017.

