swin transformer源码解读
2020 年 5 月,Facebook AI 推出了DERT( Detection Transformer),用于目标检测和全景分割。
2020 年 10 月,谷歌提出了Vit(Vision Transformer),利用 Transformer 对图像进行分类,而不需要卷积网络。
2021年1月,OpenAI 提出两个模型:DALL·E 基于本文直接生成图像,CLIP将图像映射到文本描述的类别中。两个模型都利用 Transformer 。
2021年3月,微软提出Swin Transformer,把CV各大任务给屠榜了。。。。
我能放过它?我不能。。。总结下前段时间看了论文和代码梳理出来的swin_transformer框架和实现。
论文: https://arxiv.org/abs/2103.14030
代码: https://github.com/microsoft/Swin-Transformer
swin_transformer介绍
1. swin_transformer优化点
swin_transformer对比之前Vit有两个改进点:
1.引入了CNN里常用的多层次transformers结构
Vit的尺度是不变的,不易于接入到下游任务中,比如分割的encoder阶段可以方便的接入resnet等backbone网络,而Vit的特征图尺寸是不变的下图(b)。swin_transfomer通过合并image_patchesd的方式引入多层次结构,如下图(a)。
2. swin_transformer如何优化
针对第一个优化点,论文使用的网络架构如下:
代码模块逻辑:
patch_embed + pos_embed
stage1
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
stage2
-BasicLayer
--SwinTransformerBlock(*2)
---WindowAttention
stage3
-BasicLayer
--SwinTransformerBlock(*6)
---WindowAttention
stage4
-BasicLayer
--SwinTransformerBlock(*4)
---WindowAttention
主要模块的代码逻辑:
1.patch_embed:PatchEmbed
首先进行一次patch_embed,patch_embed就是把输入按patch进行一次向量映射。我认为就是卷积操作(标题swin_transfomer,第一步就是卷积~卷积yyds)
设定输入:(3,256,256),patch_size=4,embeding_dim=96
(1)分辨率不够4整除就pad到4的倍数
(2)通用卷积kernel=4,stride=4,将image映射为无重叠的4*4的patchs:(96,64,64)
(3)如果需要norm,再进行一次layerNorm
(4)(3,256,256) 通过patch_embed,特征为(96,64,64)
2.absolute_pos_embed
如果有position_embeding步骤,需要学习一个96,64,64的pos_emded参数。和patch_embed进行concat.
将emded矩阵进行flatten+transpose-->64*64, 96
3.stages
对分辨率缩小*4的特征图进行4个stage的-BasicLayer
BasicLayer
1.attn_mask
设定window_size=7,以stage1为例输入特征图大小为(64,64)。img_mask初始为(70,70),那么通过window_partition就把特征图切分为100个7*7的窗口。
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask:, h, w, : = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
以上代码目的是得到100个49*49的attn_mask。
这里的attn_mask是为后续的cyclic shift,也就是SW-MSA使用。
首先,对img_mask70*70的图进行切分9大块赋值
63*63=0 4*63=1 3*64=2
63*4=3 4*4=4 3*4=5
64*3=6 4*3=7 3*3=8
2.SwinTransformerBlock(*n)
(1)reshape+pad
对输入64*64, 96进行layer_norm+reshape+pad操作。pad作用是要FM的H,W是window_size的倍数。对stage1:64*64, 96-->70,70,96
(2)window_mask_self_attention(W-MSA/SW-MSA)
先看第一阶段W-MSA blcok,也就是不加入cyclic shift。
(a)进行window_partition,将特征图切分为window_size*window_size的patch,1,70*70,96切分为100,7,7,96,再reshape100,49,96
(b) WindowAttention
计算self_attention
然后在X和Y方向计算relative_coords。计算relative_coords第一步加(window_size-1)是为了让值都为正数,在X方向再*(2*window_size-1)是为了后续求和能区分(0,1)和(1,0)这类坐标。
(b)windowAttention
计算attention和上诉步骤一致,只是在步骤a中我们提到了,ABC区域在计算attention时需要mask掉,这里的mask就是我们BasicLayer的第一步获取的attn_mask(100,49,49)~
if mask is not None:
nW = mask.shape0
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
mask主要逻辑,attn假设目前是200,3,49,49,我们计算的attn_mask是(100,49,49),因为是针对窗口位置mask和bs和head_num无关,所以将attn和mask分别reshape到(2, 100, 3, 49, 49)和(1,100,1,49,49)就好了。
最后记得window_rever后,记得把shift_x给sereverse回去。
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
以上就将最复杂的SwinTransformerBlock模块介绍完了~
3.down_sample
downsamp(最后一个stage不需要)使用的是PatchMerging.对FM进行间隔采样达到降采样的目的,再concat低分辨率FM后,通过全连接对C通道裁剪。很像pixelShuffle的反向操作。
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
x = x.view(B, H, W, C)
padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x:, 0::2, 0::2, : # B H/2 W/2 C
x1 = x:, 1::2, 0::2, : # B H/2 W/2 C
x2 = x:, 0::2, 1::2, : # B H/2 W/2 C
x3 = x:, 1::2, 1::2, : # B H/2 W/2 C
x = torch.cat(x0, x1, x2, x3, -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
以上就是一个basicLayer的逻辑,通过四个stage得到不同尺度的特征图(Swin-T)
stage1-->96, 64, 64
stage2-->192, 32, 32
stage3-->384, 16, 16
stage4--> 768, 8, 8
有了这个四个特征图就可以和resnet等结构一样,接入到下游任务了~