本文将以模型解释与代码结合起来,来向大家解释Vision Transformer Model(ViT)
论文地址如下:
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
论文代码
论文模型总览
我们将图像分割成固定大小的块,线性嵌入每个块,添加位置嵌入,并将生成的向量序列馈送到标准的Transformer
编码器。为了执行分类,我们使用向序列添加额外可学习的“分类标记”的标准方法。
分块与降维
对于图像来讲,一般是三通道(BGR256)的彩色图片,但是我们想用现成的Transformer
模型对图像进行处理,那么我们就应该对图像进行降维处理。
首先把x p ∈ R H × W × C \mathbf{x}_{p} \in \mathbb{R}^{H \times W \times C} x p ∈ R H × W × C 的图像,变成一个x p ∈ R N × ( P 2 ⋅ C ) \mathbf{x}_{p} \in \mathbb{R}^{N \times\left(P^{2} \cdot C\right)} x p ∈ R N × ( P 2 ⋅ C ) 的sequence of flattened 2D patches 。其可视为一系列的展平的2D块的序列,这个序列中一共有N = H W P 2 N = \frac{ HW }{P^{2}} N = P 2 H W 个展平的2D块,N N N 即为Transformer
输入的sequence
的长度。其中每个块的维度是( P 2 ⋅ C ) \left(P^2 \cdot C \right) ( P 2 ⋅ C ) ,其中H H H 和W W W 是图像的高和宽,P P P 是块大小,C C C 是图片的通道数。
那么这一步在代码中是怎么做的呢?
我们通过from einops import rearrange
来解决这个问题
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)' , p1=p, p2=p)
实例:
1 2 3 4 5 6 7 8 img.shape p = 2 x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)' , p1=p, p2=p) x.shape
参考文章:einops:优雅地操作张量维度
Patch Embedding
上一步我们已经得到了一个二维向量x p ∈ R N × ( P 2 ⋅ C ) \mathbf{x}_{p} \in \mathbb{R}^{N \times\left(P^{2} \cdot C\right)} x p ∈ R N × ( P 2 ⋅ C ) ,要转化为( N , D ) \left(N, D \right) ( N , D ) ,我们需要进行Patch Embedding ,即相当于NLP中的word embedding,做一个线性变换(即全连接层)。具体操作如下:
图像共切分为N = H W P 2 N = \frac{HW}{P^{2}} N = P 2 H W 个patches,这便是sequence的长度,需要注意的是这里直接将Patch拉平为1-D,其特征大小为P 2 ⋅ C P^{2} \cdot C P 2 ⋅ C ,然后通过一个线性变换奖patches映射到D大小的维度,这就是patch的embeddings。(这等同于对x \mathbf{x} x 做一个P × P P \times P P × P 的且stride为P的卷积操作)
z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯ ; x p N E ] + E p o s \mathbf{z}_{0}=\left[\mathbf{x}_{\text {class }} ; \mathbf{x}_{p}^{1} \mathbf{E} ; \mathbf{x}_{p}^{2} \mathbf{E} ; \cdots ; \mathbf{x}_{p}^{N} \mathbf{E}\right]+\mathbf{E}_{p o s}
z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯ ; x p N E ] + E p o s
全连接层就是(1)式中的E \mathbf{E} E ,它的输入维度大小是( P 2 ⋅ C ) \left( P^2 \cdot C \right) ( P 2 ⋅ C ) ,输出维度大小是D D D 。
1 2 3 self.patch_to_embedding = nn.Linear(patch_dim, dim) x = self.patch_to_embedding(x)
注意这里的x class \mathbf{x}_{\text {class }} x class ,假设切成9个块,但是最终到Transfomer输入是10个向量,这是人为增加的一个向量。
这么做的原因可以理解为:ViT其实只用到了Transformer的Encoder,而并没有用到Decoder,而x class \mathbf{x}_{\text {class }} x class 的作用有点类似于解码器中的Query 的作用,相对应的Key, Value 就是其他9个编码向量的输出。
x class \mathbf{x}_{\text {class }} x class 是一个可学习的嵌入向量,它的意义说通俗一点为:寻找其他9个输入向量对应的img 的类别。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 class PatchEmbed (nn.Module ): """ Image to Patch Embedding """ def __init__ (self, img_size=224 , patch_size=16 , in_chans=3 , embed_dim=768 ): super ().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) num_patches = (img_size[1 ] // patch_size[1 ]) * (img_size[0 ] // patch_size[0 ]) self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward (self, x ): B, C, H, W = x.shape assert H == self.img_size[0 ] and W == self.img_size[1 ], \ f"Input image size ({H} *{W} ) doesn't match model ({self.img_size[0 ]} *{self.img_size[1 ]} )." x = self.proj(x).flatten(2 ).transpose(1 , 2 ) return x
Position Encoding
按照Transformer的位置编码的习惯,这个工作也使用了位置编码。引入了一个 Positional encoding E p o s \mathbf{E}_{p o s} E p o s 来加入序列的位置信息 ,同样在这里也引入了pos_embedding ,与最初的Transformer模型不同,这是用一个可训练的变量 。
z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯ ; x p N E ] + E p o s \mathbf{z}_{0}=\left[\mathbf{x}_{\text {class }} ; \mathbf{x}_{p}^{1} \mathbf{E} ; \mathbf{x}_{p}^{2} \mathbf{E} ; \cdots ; \mathbf{x}_{p}^{N} \mathbf{E}\right]+\mathbf{E}_{p o s}
z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯ ; x p N E ] + E p o s
1 2 3 4 5 self.pos_embedding = nn.Parameter(torch.randn(1 , num_patches + 1 , dim)) x = x + self.pos_embed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 def resize_pos_embed (posemb, posemb_new ): _logger.info('Resized position embedding: %s to %s' , posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1 ] posemb_tok, posemb_grid = posemb[:, :1 ], posemb[0 , 1 :] ntok_new -= 1 gs_old = int (math.sqrt(len (posemb_grid))) gs_new = int (math.sqrt(ntok_new)) _logger.info('Position embedding grid-size from %s to %s' , gs_old, gs_new) posemb_grid = posemb_grid.reshape(1 , gs_old, gs_old, -1 ).permute(0 , 3 , 1 , 2 ) posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear' ) posemb_grid = posemb_grid.permute(0 , 2 , 3 , 1 ).reshape(1 , gs_new * gs_new, -1 ) posemb = torch.cat([posemb_tok, posemb_grid], dim=1 ) return posemb
z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯ ; x p N E ] + E pos , E ∈ R ( P 2 ⋅ C ) × D , E p o s ∈ R ( N + 1 ) × D z ℓ ′ = MSA ( LN ( z ℓ − 1 ) ) + z ℓ − 1 , ℓ = 1 … L z ℓ = MLP ( LN ( z ℓ ′ ) ) + z ℓ ′ , ℓ = 1 … L y = LN ( z L 0 ) \begin{array}{ll}
\mathbf{z}_{0}=\left[\mathbf{x}_{\text {class }} ; \mathbf{x}_{p}^{1} \mathbf{E} ; \mathbf{x}_{p}^{2} \mathbf{E} ; \cdots ; \mathbf{x}_{p}^{N} \mathbf{E}\right]+\mathbf{E}_{\text {pos }}, & \mathbf{E} \in \mathbb{R}^{\left(P^{2} \cdot C\right) \times D}, \mathbf{E}_{p o s} \in \mathbb{R}^{(N+1) \times D} \\
\mathbf{z}_{\ell}^{\prime}=\operatorname{MSA}\left(\operatorname{LN}\left(\mathbf{z}_{\ell-1}\right)\right)+\mathbf{z}_{\ell-1}, & \ell=1 \ldots L \\
\mathbf{z}_{\ell}=\operatorname{MLP}\left(\operatorname{LN}\left(\mathbf{z}_{\ell}^{\prime}\right)\right)+\mathbf{z}_{\ell}^{\prime}, & \ell=1 \ldots L \\
\mathbf{y}=\operatorname{LN}\left(\mathbf{z}_{L}^{0}\right) &
\end{array}
z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯ ; x p N E ] + E pos , z ℓ ′ = M S A ( L N ( z ℓ − 1 ) ) + z ℓ − 1 , z ℓ = M L P ( L N ( z ℓ ′ ) ) + z ℓ ′ , y = L N ( z L 0 ) E ∈ R ( P 2 ⋅ C ) × D , E p o s ∈ R ( N + 1 ) × D ℓ = 1 … L ℓ = 1 … L
其中,第1个式子为上面讲到的Patch Embedding 和Positional Encoding 的过程。
第二个式子为Transformer Encoder 的Multi-head Self-attention ,Add and Norm 的过程,重复L次
第二个式子为Transformer Encoder 的Feed forward Network,Add and Norm 的过程,重复L次
最后是一个MLP 的Classification Head ,整个的结构只有这些。