Rediscovery Journal (20May’23): Learning Transformer (and Vision Transformer)

The Transformer has been one of the big turning points in deep learning history, on par with ImageNet, CNNs, GANs, and attention modules. I never got the opportunity to experiment with or deploy Transformers at my previous workplace, so this is a good chance to get myself up to date.

All code is based on several reference notebooks, which are listed in the footnotes. All my work is also posted on GitHub at this link.

First, we want to list some of the most important features of a Transformer setup:

  1. Positional Encoding

Unlike an RNN, the multi-head self-attention module has no recurrence, so nothing in the architecture itself tracks the order of the input tokens. How, then, does the model keep track of sequence order? Positional encoding adds position information to each token embedding, which makes sure that the positions of the tokens play a role in the output of the model.

In the original Transformer paper and notebook, the authors propose an absolute positional encoding built from sine and cosine functions of different frequencies.

$$ PE_{(pos,2i)} = \sin\left(pos / 10000^{2i/d_{model}}\right) $$

$$ PE_{(pos,2i+1)} = \cos\left(pos / 10000^{2i/d_{model}}\right) $$

where pos is the position of the token in the sequence, d_model is the number of dimensions of the embeddings, and i indexes pairs of embedding dimensions (dimension 2i uses sine, dimension 2i+1 uses cosine). A further explanation of why both sin and cos are used can be found here.
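As a quick sanity check of the formula, here is a minimal sketch in plain Python (no PyTorch, so the arithmetic stays visible). The function name `sinusoidal_pe` is my own and does not come from the paper or the notebooks; it computes the encoding vector for a single position.

```python
import math

def sinusoidal_pe(pos, d_model):
    """Return the d_model-long positional encoding for one position."""
    pe = []
    for k in range(d_model):
        i = k // 2  # dimensions 2i and 2i+1 share one frequency
        angle = pos / (10000 ** (2 * i / d_model))
        # even dimensions use sine, odd dimensions use cosine
        pe.append(math.sin(angle) if k % 2 == 0 else math.cos(angle))
    return pe

# Position 0 gives sin(0)=0 and cos(0)=1 in every pair
print(sinusoidal_pe(0, 4))  # [0.0, 1.0, 0.0, 1.0]
```

Higher dimension pairs use a larger denominator, so their values change more slowly as pos grows.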

In a vision transformer, the positional encoding is instead a learnable parameter. The image is first broken down into patches of n×n pixels, each patch is flattened into a row vector, and each patch is then treated as a single position in the positional embeddings.

self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

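This initializes the positional embedding with the shape (1, num_patches+1, embed_dim): one vector per patch plus one for the CLS token, with a leading dimension of 1 so that it broadcasts over the batch. Here embed_dim would be 256; it is the size of each token's embedding vector (a model hyperparameter), not the number of colour values in a single pixel.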

Treating images as tokens in an NLP problem

One of the first problems that needs to be solved for a transformer to be applicable to 2-dimensional data such as images is representation. In the original vision transformer paper, images are broken down into smaller patches, and each patch is treated as a “token”.
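To make the patch idea concrete, here is a small sketch in plain Python. It assumes a single-channel image stored as a nested list and non-overlapping square patches read in row-major order; the helper name `image_to_patches` is hypothetical, and the reference notebooks do the same thing with tensor reshapes instead.

```python
def image_to_patches(image, patch_size):
    """Split an H x W image (list of rows) into flattened patches."""
    h, w = len(image), len(image[0])
    assert h % patch_size == 0 and w % patch_size == 0
    patches = []
    for top in range(0, h, patch_size):
        for left in range(0, w, patch_size):
            # flatten one patch_size x patch_size block into a row vector
            patch = [image[r][c]
                     for r in range(top, top + patch_size)
                     for c in range(left, left + patch_size)]
            patches.append(patch)
    return patches

# A 4x4 image with pixel values 0..15, cut into 2x2 patches
image = [[r * 4 + c for c in range(4)] for r in range(4)]
patches = image_to_patches(image, 2)
print(len(patches))  # 4 patches, i.e. 4 "tokens"
print(patches[0])    # [0, 1, 4, 5]
```

Each flattened patch then plays the role of one token in the sequence.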

In the vision transformer code, a class embedding token (CLS token) is prepended to the patch sequence before the positional embeddings are added, and its output is what the final classification is based on. As to why this was done, the author admitted that it was only added in order to be consistent with NLP transformers. You may learn more about it from this Stack Overflow question page.

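# Add CLS token and positional encoding
cls_token = self.cls_token.repeat(B, 1, 1)
# concat, not sum: the CLS token is prepended as an extra position (T -> T+1 tokens)
x = torch.cat([cls_token, x], dim=1)
# the sum happens here: one positional embedding is added to every token
x = x + self.pos_embedding[:, :T + 1]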
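The difference between the concatenation and the sum is easier to see with tiny lists. This is only a sketch in plain Python with made-up numbers; `add_cls_and_positions` is a hypothetical name, and the batch dimension is left out.

```python
def add_cls_and_positions(patch_tokens, cls_token, pos_embedding):
    """Prepend the CLS token, then add one positional vector per token."""
    # concat along the sequence axis: T patch tokens become T+1 tokens
    tokens = [cls_token] + patch_tokens
    assert len(pos_embedding) >= len(tokens)
    # element-wise sum: the shape stays (T+1, embed_dim)
    return [[t + p for t, p in zip(tok, pos)]
            for tok, pos in zip(tokens, pos_embedding)]

patch_tokens = [[1.0, 1.0], [2.0, 2.0]]             # T = 2, embed_dim = 2
cls_token = [0.0, 0.0]
pos_embedding = [[0.5, 0.5], [0.25, 0.25], [0.125, 0.125]]

out = add_cls_and_positions(patch_tokens, cls_token, pos_embedding)
print(len(out))  # 3 tokens: CLS + 2 patches
print(out[0])    # [0.5, 0.5] -> CLS token plus its positional vector
```

So the concat changes the sequence length, while the sum only changes the values.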

There have also been attempts to replace the CLS token and learnable embeddings with conditional positional encodings, together with an added global average pooling at the end prior to the MLP module. The paper reports an increase in AP from 33.7 to 33.9. You may learn more about it here.

  2. Self-attention module

import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def clones(module, N):
    # Helper as defined in The Annotated Transformer: N independent copies
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def attention(query, key, value, mask=None, dropout=None):
    # SDPA implementation
    # (Scaled Dot-Product Attention)
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        # masked positions get a very negative score, so softmax gives ~0
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)

    return torch.matmul(p_attn, value), p_attn

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        # three input projections (query, key, value) plus one output projection
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # same mask applied to all h heads
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all linear projections in batch
        # d_model => h x d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)

        # 3) concat using a view and apply a final linear
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)

        return self.linears[-1](x)
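Steps 1 and 3 of the forward pass are mostly bookkeeping: after the linear projection, each d_model-long vector is cut into h contiguous chunks of size d_k, and after attention the chunks are glued back together. A rough sketch of just that reshaping in plain Python, for a single vector (the names `split_heads` and `merge_heads` are mine, not from the notebook):

```python
def split_heads(vector, h):
    """Cut one d_model-long vector into h chunks of size d_k."""
    d_model = len(vector)
    assert d_model % h == 0
    d_k = d_model // h
    return [vector[i * d_k:(i + 1) * d_k] for i in range(h)]

def merge_heads(heads):
    """Concatenate the per-head chunks back into one vector."""
    return [v for head in heads for v in head]

vector = list(range(8))         # d_model = 8
heads = split_heads(vector, 2)  # h = 2, so d_k = 4
print(heads)                    # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Each head then runs attention on its own d_k-sized slice, which is why d_model must be divisible by h.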

The attention module is used in three places across the architecture: the first is self-attention over the input embeddings (encoder), the second is self-attention over the output embeddings (decoder), and the third combines both input and output embeddings (encoder-decoder attention).

To understand the attention module better, it is important to dig deeper into the concepts of Query, Key, and Value. Using a YouTube search as an analogy: the Query is the search string you type into the search bar, YouTube then uses this query to find the best-matching Keys, and the results returned are the Values. (A deeper explanation can be found here.)

An implementation of this is shown below:

def attention(query, key, value, mask=None, dropout=None):
    # SDPA implementation
    # (Scaled Dot-Product Attention)
    d_k = query.size(-1)
    # similarity of every query with every key, scaled by sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # each row of p_attn sums to 1: how much each value contributes
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)

    return torch.matmul(p_attn, value), p_attn
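To tie the search analogy to actual numbers, here is a toy version of the same computation for one query, written in plain Python without masking or dropout. The helper names `softmax` and `attend` and the example vectors are made up for illustration.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Scaled dot-product attention for a single query vector."""
    d_k = len(query)
    # one score per key: dot(query, key) / sqrt(d_k)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k)
              for key in keys]
    weights = softmax(scores)
    # output is the weighted average of the value vectors
    dim = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values))
              for j in range(dim)]
    return output, weights

keys = [[10.0, 0.0], [0.0, 10.0]]
values = [[1.0, 0.0], [0.0, 1.0]]
# the query points in the same direction as the first key
out, w = attend([10.0, 0.0], keys, values)
print(w)    # almost all weight on the first key
print(out)  # so the output is almost exactly the first value
```

The query matches the first key, so the returned result is essentially the first value, just like the best-matching video in the search analogy.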

How are they represented?

In the encoder, the input embeddings and positional encodings are summed and passed to the attention module. This sum is the sole source of the query, key, and value inputs to the attention heads, hence the name self-attention.

class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        # query, key, and value all come from the same x
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

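In the decoder (the output side), we have two kinds of attention module. The first takes only the output embedding as the source of its query, key, and value. The second takes its query from the output embedding and its key and value from the input side. Memory here refers to the output of the encoder, which is passed into every decoder layer.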

class SublayerConnection(nn.Module):
    """Residual connection followed by layer norm.
    Note: in this code the norm is applied first, to the sublayer input.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        # LayerNorm is the custom module from the reference notebook
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with same size"
        return x + self.dropout(sublayer(self.norm(x)))

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        m = memory
        # first attention module with qkv from the output embed
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        # second attention module: query from the output, key/value from memory
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))

        return self.sublayer[2](x, self.feed_forward)
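The order of operations inside SublayerConnection is easy to miss: normalize, run the sublayer, then add the result back onto the untouched input. Below is a simplified sketch in plain Python for a single vector. It leaves out dropout and the learnable scale and shift, and its `layer_norm` is a basic mean/variance normalization that I am assuming for illustration, not the notebook's exact LayerNorm.

```python
import math

def layer_norm(x, eps=1e-6):
    """Normalize one vector to zero mean and (roughly) unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def sublayer_connection(x, sublayer):
    """Residual connection: x + sublayer(norm(x))."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

x = [1.0, 2.0, 3.0]
# a sublayer that outputs zeros leaves x unchanged, thanks to the residual path
unchanged = sublayer_connection(x, lambda v: [0.0] * len(v))
print(unchanged)  # [1.0, 2.0, 3.0]
```

The residual path means a layer can fall back to passing its input through when the sublayer contributes nothing.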

[Figure: the-annotated-transformer_14_0.png]

Masking

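The third property we need to learn is masking. A mask is applied to the decoder's self-attention over the output embedding (target vector) so that each position can only attend to itself and earlier positions. This prevents the attention module from simply looking at the whole target vector, including the future tokens it is supposed to predict.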

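import numpy as np
import torch

def subsequent_mask(size):
    attn_shape = (1, size, size)
    # ones above the diagonal (k=1) mark the future positions
    future = np.triu(np.ones(attn_shape), k=1).astype('uint8')

    # True where attention is allowed (current and earlier positions)
    return torch.from_numpy(future) == 0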
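To see what the mask looks like and how it is used, here is a small sketch in plain Python. `causal_mask` builds the same lower-triangular pattern without the batch dimension, and `mask_scores` mimics the masked_fill step from the attention function; both names are mine.

```python
def causal_mask(size):
    """True where row (query position) may attend to col (key position)."""
    return [[col <= row for col in range(size)] for row in range(size)]

def mask_scores(scores, mask, fill=-1e9):
    """Replace scores at disallowed positions with a very negative value."""
    return [[s if allowed else fill for s, allowed in zip(s_row, m_row)]
            for s_row, m_row in zip(scores, mask)]

for row in causal_mask(3):
    print(row)
# [True, False, False]
# [True, True, False]
# [True, True, True]

masked = mask_scores([[1.0, 2.0], [3.0, 4.0]], causal_mask(2))
print(masked)  # [[1.0, -1000000000.0], [3.0, 4.0]]
```

After the softmax, the positions filled with -1e9 end up with a weight of practically zero, so the first token never sees the second one.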