The attention mechanism [1,2] improved NLP architectures by allowing them to focus on the relevant part of the input or representation, much as we humans do. For example, humans can usually understand a text as long as the first and last characters of each word are correct [4]. This post examines the inner workings of additive and multiplicative attention, i.e., how the attention mechanism converts a query and hidden states into attention scores.
Introduction and Motivation
Attention enables a model to focus on the relevant part of the input at a given time step. We humans use attention in our daily life for pretty much everything. When answering a question about a comprehension passage, we look for the sentences that matter for the answer, much like an attention mechanism does. The attention mechanism was introduced by Bahdanau et al. [2] in sequence-to-sequence models used for translation. While generating the translation, the decoder focuses on the parts of the input that are relevant to the word being generated. This removed the bottleneck of relying only on the final encoder state in sequence-to-sequence models, and it also allowed the model to handle longer sentences in translation.
There are various types of attention mechanisms, and a good summary is available on Lilian Weng’s blog [3]. We will only discuss additive and dot-product attention and their implementations.
Basic components
The three basic ingredients of attention are:
- Query: The query provides the context for what we are currently looking for and decides where to focus in the sentence. In a seq2seq model, the decoder state from the last time step is used as the query to guide what is generated next.
- Key: The keys are the states generated by the seq2seq encoder, which are matched against the query while decoding.
- Value: The values are what is actually selected and combined. They can be the same as the keys, and usually are.
In most cases, the keys are the states produced by the encoder, one per input position.
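To make the roles of the three ingredients concrete, here is a minimal pure-Python sketch of attention as a soft lookup. The plain dot-product matching function, the helper names `softmax`, `dot` and `attend`, and the toy vectors are all assumptions made for illustration; the additive and scaled dot-product scoring functions discussed in this post would take the place of the `dot` call.

```python
import math

def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attend(query, keys, values):
    """Score every key against the query, then average the values with those weights."""
    weights = softmax([dot(query, k) for k in keys])
    dim = len(values[0])
    context = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return weights, context

# Toy example (hypothetical numbers): three encoder states used as both keys and values.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = keys
query = [2.0, 0.0]
weights, context = attend(query, keys, values)
```

The first and third keys match the query equally well, so they receive the same weight, while the second key, which is orthogonal to the query, receives the smallest one.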
Additive Attention (Bahdanau)
Additive attention uses the decoder state from the last time step as the query (s_t) and the encoder states (h_i) for the whole sequence as keys. The mechanism adds the projected query to each projected encoder state, hence the name additive attention.
This procedure in Bahdanau’s attention mechanism is expressed mathematically as $$ query = W_q s_t $$ $$ keys = [W_k h_i \quad i \in [1 \dotsc T]] $$ $$ scores = v_a \cdot \tanh(query + keys) $$
Transform Query and hidden states using projection layer and merge
We first transform the query and the hidden states with a linear layer (a learned weight matrix) so that they have the same dimension. This step matters because the query and hidden states may have different dimensions, so to compare them we need to project both into a common space.
Define the sizes of the input key, query, and values
# Imports
import torch
import torch.nn as nn
# Size of hidden dimension of model.
hidden_dim = 128
# Size of encoder, considering encoder was bi-directional.
key_size = 2*hidden_dim
# Decoder is not bidirectional, hence not multiplied by 2.
query_size = hidden_dim
value_size = hidden_dim
batch_size = 64
max_seq = 20
Define the dummy key and query and projection layers.
## Using the dummy key, query, values
key = torch.randn(batch_size, max_seq, key_size)
query = torch.randn(batch_size, query_size)
## Define projection layers to map key and query to the same dimension
key_layer = nn.Linear(key_size, hidden_dim, bias=False)
query_layer = nn.Linear(query_size, hidden_dim, bias=False)
The query has shape (batch_size, query_size) whereas the keys have shape (batch_size, max_seq, key_size). We need to add a dimension to the query to account for the sequence length, which is done with the unsqueeze function.
## torch.Size([64, 128]) => torch.Size([64, 1, 128])
## unsqueeze returns a new tensor; query itself keeps its shape
unsqueezed_query = query.unsqueeze(1)
## Doing the actual projection on input
proj_key = key_layer(key)
proj_query = query_layer(unsqueezed_query)
# Add (the query broadcasts over max_seq) and pass through the non-linearity
projected_input = torch.tanh(proj_key + proj_query)
Above is the original formulation by Bahdanau [2]. We can also implement it by concatenating the key and query and passing the result through a single linear layer, which merges the states from query and key in one step. Up to the bias term, this computes the same kind of function as the sum of two projections, because the weight matrix acts on the query block and the key block separately. $$ W_a[s_t;h_i] \quad i \in [1 \dotsc T] $$
# Keys shape is batch_sizeXmax_seqXhidden_dim
key = torch.randn(batch_size, max_seq, hidden_dim)
proj = nn.Linear(2*hidden_dim, hidden_dim)
# Query shape is batch_sizeXhidden_dim
# Add extra dimension for sequence length to query
# batch_sizeXhidden_dim => batch_sizeX1Xhidden_dim
query = query.unsqueeze(1)
# batch_sizeX1Xhidden_dim => batch_sizeXmax_seqXhidden_dim
query = query.expand_as(key)
# Concatenate query and keys
merged = torch.cat((key, query), dim=2)
# Now shape of merged is batch_sizeXmax_seqX(2*hidden_dim)
# Use the projection layer to project the merged input down to hidden_dim
projected_input = proj(merged)
projected_input = torch.tanh(projected_input)
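As a sanity check on how the two variants relate, the sketch below splits the weight matrix of a concatenation layer into a key block and a query block and compares the two computations. The small sizes and the `bias=False` setting are assumptions chosen to keep the comparison clean; nothing here comes from the original paper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, max_seq, hidden_dim = 4, 5, 8   # toy sizes for the example

key = torch.randn(batch_size, max_seq, hidden_dim)
query = torch.randn(batch_size, hidden_dim)

# Single layer applied to the concatenation [key; query].
proj = nn.Linear(2 * hidden_dim, hidden_dim, bias=False)
merged = torch.cat((key, query.unsqueeze(1).expand_as(key)), dim=2)
concat_out = proj(merged)

# Split the same weight matrix into the block that multiplies the key
# and the block that multiplies the query, then add the two projections.
w_key = proj.weight[:, :hidden_dim]
w_query = proj.weight[:, hidden_dim:]
sum_out = key @ w_key.T + (query @ w_query.T).unsqueeze(1)

max_diff = (concat_out - sum_out).abs().max().item()
```

Here `max_diff` is zero up to floating-point error, so in this setting the concatenation layer and the sum of two projections describe the same computation.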
Project down the merged representation and calculate attention scores
The step above gives us a matched representation of the decoder query and the encoder keys, but this is not yet an attention score. Attention scores are values between 0 and 1 that reflect the contribution of each key at the current time step. They are computed with another projection, referred to as v_a, which projects the hidden state down to a single number per key. The result is then passed through a softmax layer to produce a probability distribution over the keys.
Mathematically, this projects the representation further: $$ v_a \cdot \tanh(W_a[s_t;h_i]) \quad i \in [1 \dotsc T] $$
# Projection to lower dimension
energy_layer = nn.Linear(hidden_dim, 1)
# Output of energy is batch_sizeXmax_seqX1
energy = energy_layer(projected_input)
# Dimension of energy is 64X20X1
# Since attention is a probability distribution across all states,
# we convert it to dimension batch_sizeX1Xmax_seq so that multiplying
# the attention scores with the values gives the state formed by applying attention
energy = energy.squeeze(2).unsqueeze(1)
# energy is now 64X1X20
# Pass through softmax
attention_scores = torch.softmax(energy, dim=-1)
# attention_scores is still 64X1X20 but now the values sum to 1 across the sequence.
# Dummy values (one per encoder state), defined here so that the example runs
values = torch.randn(batch_size, max_seq, value_size)
attended_state = torch.bmm(attention_scores, values)
# attended_state shape is now 64X1X128
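The steps above can be collected into one module. This is a minimal sketch under stated assumptions, not a reference implementation: the class name `AdditiveAttention`, the bias-free layers, and the choice to reuse the encoder states as values are all illustrative.

```python
import torch
import torch.nn as nn

class AdditiveAttention(nn.Module):
    """Hypothetical wrapper around the steps above: project, add, tanh, score, softmax."""

    def __init__(self, query_size, key_size, hidden_dim):
        super().__init__()
        self.query_layer = nn.Linear(query_size, hidden_dim, bias=False)
        self.key_layer = nn.Linear(key_size, hidden_dim, bias=False)
        self.energy_layer = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, query, keys, values):
        # query: [batch, query_size]; keys: [batch, seq, key_size]; values: [batch, seq, value_size]
        proj_query = self.query_layer(query).unsqueeze(1)                 # [batch, 1, hidden]
        proj_keys = self.key_layer(keys)                                  # [batch, seq, hidden]
        energy = self.energy_layer(torch.tanh(proj_query + proj_keys))   # [batch, seq, 1]
        weights = torch.softmax(energy.transpose(1, 2), dim=-1)          # [batch, 1, seq]
        context = torch.bmm(weights, values)                              # [batch, 1, value_size]
        return context, weights

torch.manual_seed(0)
attention = AdditiveAttention(query_size=128, key_size=256, hidden_dim=128)
query = torch.randn(64, 128)
keys = torch.randn(64, 20, 256)
values = keys   # assumption: the encoder states double as values
context, weights = attention(query, keys, values)
```

The returned `weights` tensor holds one probability distribution over the 20 encoder positions per batch element, and `context` is the corresponding weighted sum of encoder states.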
Self-Attention (Vaswani)
Self-attention is a multiplicative (dot-product) attention: attention scores are obtained by multiplying the query and key. We will look further into how the query and key are transformed to generate the attention scores.
Let’s assume our sentence has length 10. The attention scores should then form a [10X10] matrix that defines how each word depends on all the other words in the sentence.
Self-attention in the Transformer [1] combines multiple heads with scaled dot-product attention. Below we briefly explain what multiple heads mean and why scaling is needed in attention.
Multi-head attention
In a sentence there may be multiple kinds of dependencies between a word and the other words. The scaled dot-product attention procedure is therefore repeated multiple times in parallel, which allows the model to capture multiple types of dependencies between words. A recent paper, The Story of Heads [6], presents evidence that different heads take on different semantic and syntactic roles. To keep the model computationally efficient, the internal dimension of each head is much smaller than the input and output dimension of the model, and the following dimensions were used in the paper [1]:
$$ d_{head} = d_{key} = d_{value} = d_{query} = \frac{d_{model}}{num\_heads} $$
This is why constructing a Transformer layer in PyTorch with a model dimension that is not a multiple of the number of heads raises the following assertion error.
In [11]: torch.nn.TransformerEncoderLayer(d_model=128, nhead=9)
AssertionError: embed_dim must be divisible by num_heads
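For comparison, here is a small sketch of a configuration that does satisfy the constraint, next to one that does not. It uses `torch.nn.MultiheadAttention` directly instead of the encoder layer; the exact exception type may differ between PyTorch versions, so the sketch catches both `AssertionError` and `ValueError` as a precaution.

```python
import torch

d_model, num_heads = 128, 8

# 128 is divisible by 8, so every head works with 128 // 8 = 16 dimensions.
mha = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads)
head_dim = mha.head_dim

# 128 is not divisible by 9, so construction is expected to fail.
try:
    torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=9)
    rejected = False
except (AssertionError, ValueError):
    rejected = True
```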
Scaled dot-product
Dot-product attention in the paper is scaled by \(\frac{1}{\sqrt{d_{key}}}\). Attention for a single head is calculated using the following equation: $$ softmax\left(\frac{QK^T}{\sqrt{d_{key}}}\right)V $$
In the paper, the authors note that for larger \(d_{key}\) values the dot products grow large in magnitude, which pushes the softmax into regions with very small gradients. They suspect this is why additive attention outperforms unscaled dot-product attention for larger \(d_{key}\).
The smaller gradient comes from the kind of probability distribution that softmax creates when the magnitude of its input is large. In that case softmax produces a peaked distribution in which most elements are close to zero, which results in small gradients. The larger magnitude is in turn a result of the higher variance of the sum in the dot product: if the components of the query and key are independent with mean 0 and variance 1, their dot product has variance \(d_{key}\).
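Both effects can be checked numerically with the standard library alone. The sketch below is an illustration under assumptions: the components are drawn as independent standard normals, and the logits and the value of `d_key` are made-up numbers. It estimates the variance of the dot product for two dimensions, then compares how peaked the softmax is with and without an extra factor of \(\sqrt{d_{key}}\) on the logits.

```python
import math
import random

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

random.seed(0)

def dot_product_variance(d, trials=2000):
    """Empirical variance of q.k when all components are independent standard normals."""
    dots = []
    for _ in range(trials):
        q = [random.gauss(0, 1) for _ in range(d)]
        k = [random.gauss(0, 1) for _ in range(d)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / trials
    return sum((x - mean) ** 2 for x in dots) / trials

var_small = dot_product_variance(4)    # expected to be roughly 4
var_large = dot_product_variance(64)   # expected to be roughly 64

# The same logits at two magnitudes: the larger one gives a much more peaked softmax.
logits = [1.0, 0.5, -0.5, -1.0]
d_key = 64
peak_unscaled = max(softmax([x * math.sqrt(d_key) for x in logits]))
peak_scaled = max(softmax(logits))

# Derivative of the largest softmax output with respect to its own logit is p * (1 - p).
grad_unscaled = peak_unscaled * (1 - peak_unscaled)
grad_scaled = peak_scaled * (1 - peak_scaled)
```

With the large logits almost all of the probability mass sits on one element and `grad_unscaled` is much smaller than `grad_scaled`, which is the situation the \(\frac{1}{\sqrt{d_{key}}}\) factor is meant to avoid.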
Step by step multi-head self-attention explanation
Step 1: Use Embedding to generate Query, Key, Value from input
In self-attention, the query and key are derived from the same input, whereas in general the query can differ from the key, as it did in the additive attention shown above. Self-attention uses linear projections to create the query, key, and value from the input to the layer. We can either share one projection layer across key, query, and value or use a different projection for each, depending on the parameter budget and need. PyTorch's nn.MultiheadAttention by default stores the three projections packed in a single weight matrix.
# Import torch
import torch
# Initialize model dimension and heads for the transformer
d_model = 1024
num_heads = 8
d_head = d_model // num_heads
batch_size = 128
seq_len = 10
inp = torch.randn(batch_size, seq_len, d_model)
# Define the projections for query, key and value.
# Since the internal dimension of each head is d_head,
# we could have `num_heads` projection layers and then concatenate.
# Defining one d_model x d_model layer per role achieves the same result
# and is computationally efficient and convenient.
proj_query = torch.nn.Linear(d_model, d_model)
proj_key = torch.nn.Linear(d_model, d_model)
proj_value = torch.nn.Linear(d_model, d_model)
# Do the actual projection to create query, value and key.
Query = proj_query(inp)
# Query dimension is torch.Size([128, 10, 1024]) == [batch_size, seq_len, d_model]
Key = proj_key(inp)
Value = proj_value(inp)
Step 2: Rearrange tensors to account for multiple heads
After creating the query, we need to account for the multiple heads used by the transformer's attention. The Query created in the last step contains the input for all heads, since d_head is \(\frac{d_{model}}{num\_heads}\). We have to rearrange the tensor so that the heads are split out of the last dimension and moved in front of seq_len, so that the matrix multiplication is computed for all heads in parallel.
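# Split the heads out of the last dimension
Query = Query.contiguous().view(batch_size, seq_len, num_heads, d_head)
# Query dimension is [128, 10, 8, 128] = [batch_size, seq_len, num_heads, d_head]
# Move the heads in front of seq_len and merge them into the batch dimension
Query = Query.transpose(1, 2).reshape(batch_size*num_heads, seq_len, d_head)
# Query dimension now is [1024, 10, 128] == [batch_size*num_heads, seq_len, d_head]
# This enables parallel processing of num_heads as those go into the batch dimension while
# doing the tensor multiplication.
# Note: view(seq_len, batch_size*num_heads, d_head) followed by transpose(0, 1) is only
# correct for sequence-first input of shape [seq_len, batch_size, d_model].
# The same transformation is done for Key and Value
Key = Key.contiguous().view(batch_size, seq_len, num_heads, d_head)
Key = Key.transpose(1, 2).reshape(batch_size*num_heads, seq_len, d_head)
Value = Value.contiguous().view(batch_size, seq_len, num_heads, d_head)
Value = Value.transpose(1, 2).reshape(batch_size*num_heads, seq_len, d_head)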
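Because reshaping mistakes in this step are silent (the shapes look right even when heads get mixed up), it can help to verify the rearrangement on a small tensor. The helpers `split_heads` and `merge_heads` below are hypothetical names, and the sketch assumes batch-first input of shape [batch, seq, d_model].

```python
import torch

def split_heads(x, num_heads):
    """[batch, seq, d_model] -> [batch * num_heads, seq, d_head]."""
    batch_size, seq_len, d_model = x.shape
    d_head = d_model // num_heads
    x = x.reshape(batch_size, seq_len, num_heads, d_head)   # cut the last dimension into heads
    x = x.transpose(1, 2)                                    # [batch, num_heads, seq, d_head]
    return x.reshape(batch_size * num_heads, seq_len, d_head)

def merge_heads(x, num_heads):
    """Inverse of split_heads: [batch * num_heads, seq, d_head] -> [batch, seq, d_model]."""
    _, seq_len, d_head = x.shape
    x = x.reshape(-1, num_heads, seq_len, d_head).transpose(1, 2)
    return x.reshape(-1, seq_len, num_heads * d_head)

torch.manual_seed(0)
x = torch.randn(2, 10, 32)             # small toy sizes, chosen only for the example
heads = split_heads(x, num_heads=4)    # [8, 10, 8]
restored = merge_heads(heads, num_heads=4)
```

Entry `heads[1]` is head 1 of the first sentence, i.e. the slice `x[0, :, 8:16]`, and merging the heads again recovers the original tensor exactly.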
Step 3: Generate scaled attention scores
Generating attention scores corresponds to scoring how much each word depends on all the other words, and the output for each head looks like a matrix of torch.Size([Seq_len, Seq_len]). We transpose Key from \([Seq\_lenXd\_head]\) to \([d\_headXSeq\_len]\), so that multiplying Query by the transposed Key captures the interaction between each word in the query and every other word. The result is divided by \(\sqrt{d_{head}}\), as described in the scaled dot-product section. This is the most important part: it generates the raw attention scores, which after passing through softmax are referred to as self-attention.
Key = Key.transpose(1,2)
# Dimension of Key now is [batch_size*num_heads, d_head, seq_len]
softmax = torch.nn.Softmax(dim=-1)
# Scale the dot products by sqrt(d_head) before the softmax
Raw_Attention = torch.bmm(Query, Key) / d_head ** 0.5
# Raw_Attention dimension is [batch_size*num_heads, seq_len, seq_len]
attention_scores = softmax(Raw_Attention)
Step 4: Prepare the final output from multi-head attention
Multiplying the attention scores with the values generates the raw output of each head. The heads are then rearranged (concatenated) and passed through a linear output projection layer to produce the final output of the multi-head self-attention module. The output dimension is the same as the input dimension.
# Generate the output by multiplying the attention scores with the values
raw_output = torch.bmm(attention_scores, Value)
# raw_output dimensions are [1024, 10, 128] = [batch_size*num_heads, seq_len, d_head]
# Define output projection layer which projects the output back to input dimension of layer
output_proj = torch.nn.Linear(d_model, d_model)
# Rearrange the heads so they are concatenated back along the last dimension
raw_output = raw_output.view(batch_size, num_heads, seq_len, d_head).transpose(1, 2)
raw_output = raw_output.reshape(batch_size, seq_len, d_model)
# raw_output dimension now is [128, 10, 1024]
# Project the raw output to get the output of the multi-head attention module.
output = output_proj(raw_output)
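The four steps can be put together in one function and checked against a direct per-head computation. This is a sketch under assumptions rather than PyTorch's own implementation: the function name `multi_head_self_attention`, the toy sizes, and the separate projection layers are all illustrative, and masking and dropout are left out.

```python
import torch

def multi_head_self_attention(inp, proj_query, proj_key, proj_value, output_proj, num_heads):
    """Batched multi-head self-attention for input of shape [batch, seq, d_model]."""
    batch_size, seq_len, d_model = inp.shape
    d_head = d_model // num_heads

    def split(x):  # [batch, seq, d_model] -> [batch * heads, seq, d_head]
        x = x.reshape(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
        return x.reshape(batch_size * num_heads, seq_len, d_head)

    q, k, v = split(proj_query(inp)), split(proj_key(inp)), split(proj_value(inp))
    scores = torch.bmm(q, k.transpose(1, 2)) / d_head ** 0.5    # [batch * heads, seq, seq]
    weights = torch.softmax(scores, dim=-1)
    heads = torch.bmm(weights, v)                                 # [batch * heads, seq, d_head]
    heads = heads.reshape(batch_size, num_heads, seq_len, d_head).transpose(1, 2)
    return output_proj(heads.reshape(batch_size, seq_len, d_model)), weights

torch.manual_seed(0)
d_model, num_heads, batch_size, seq_len = 32, 4, 2, 5   # toy sizes for the example
layers = [torch.nn.Linear(d_model, d_model) for _ in range(4)]
inp = torch.randn(batch_size, seq_len, d_model)
output, weights = multi_head_self_attention(inp, *layers, num_heads=num_heads)

# Reference: head 0 of the first sentence, computed on its own slice of the projections.
d_head = d_model // num_heads
q0 = layers[0](inp)[0, :, :d_head]
k0 = layers[1](inp)[0, :, :d_head]
reference = torch.softmax(q0 @ k0.T / d_head ** 0.5, dim=-1)
```

The attention matrix for the first head of the first sentence, `weights[0]`, agrees with `reference`, which confirms that the batched reshaping keeps every head's slice intact.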
This concludes the introduction to multiplicative (dot-product) attention.
Reference
- [1] Vaswani, Ashish, et al. “Attention is all you need.” Advances in neural information processing systems. 2017.
- [2] Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio. “Neural machine translation by jointly learning to align and translate.” arXiv preprint arXiv:1409.0473 (2014).
- [3] https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html
- [4] https://www.foxnews.com/story/if-you-can-raed-tihs-you-msut-be-raelly-smrat
- [5] [Annotated Attention](https://nlp.seas.harvard.edu/2018/04/03/attention.html#attention)
- [6] The Story of Heads