Swin Transformer in Practice: A PyTorch Guide to Building SwinT-Tiny from Scratch
1. Environment Setup and Architecture Overview
Before building Swin Transformer, we need to set up a suitable development environment and understand the model's basic architecture. Swin Transformer (Shifted Window Transformer) is a vision Transformer proposed by Microsoft Research Asia in 2021. By introducing hierarchical feature maps and a shifted-window mechanism, it achieves strong results across computer vision tasks.
Environment requirements:
- Python 3.7+
- PyTorch 1.8+
- CUDA 11.0+ (for GPU acceleration)
- torchvision 0.9+
```bash
# Install the base dependencies
pip install torch torchvision timm
```

SwinT-Tiny is the smallest variant in the Swin Transformer family. Its architecture parameters are:
| Parameter | Value |
|---|---|
| Embedding dimension | 96 |
| Depths per stage | [2,2,6,2] |
| Attention heads | [3,6,12,24] |
| Window size | 7 |
| MLP ratio | 4 |
The model consists of four stages. Each stage uses Patch Merging to reduce the spatial resolution and increase the channel count, producing hierarchical feature maps. This design lets the model handle multi-scale features much like a conventional CNN.
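The resolution/channel progression can be verified with quick arithmetic; a small sketch (plain Python, parameters from the table above):

```python
# Per-stage feature-map shapes for SwinT-Tiny (224x224 input, 4x4 patches).
# Each Patch Merging halves H and W and doubles the channel dimension.
def swin_stage_shapes(img_size=224, patch_size=4, embed_dim=96, num_stages=4):
    res = img_size // patch_size  # resolution after patch embedding
    dim = embed_dim
    shapes = []
    for _ in range(num_stages):
        shapes.append((res, res, dim))
        res //= 2                 # Patch Merging: H/2, W/2
        dim *= 2                  # Patch Merging: C -> 2C
    return shapes

print(swin_stage_shapes())
# -> [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 768)]
```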
2. Core Module Implementation
2.1 Patch Embedding and Patch Merging
Patch Embedding splits the input image into non-overlapping patches and embeds them into a vector space:
```python
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm:
            x = self.norm(x)
        return x
```

Patch Merging performs downsampling, increasing the channel count while reducing the spatial resolution:
```python
class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)

        x = self.norm(x)
        x = self.reduction(x)
        return x
```

2.2 Window Attention Mechanism
Window attention is Swin Transformer's core innovation: self-attention is computed within local windows rather than over the whole feature map:
```python
class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Relative position bias table
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        # Pairwise relative position index for tokens inside a window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
```

2.3 Swin Transformer Block
A Swin Transformer block alternates between regular window attention (W-MSA) and shifted window attention (SW-MSA):
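One caveat before the block itself: the implementation below calls two helpers, `window_partition` and `window_reverse`, which this article never defines. Standard implementations, matching the official Swin repository, look like this:

```python
import torch

def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into non-overlapping windows
    of shape (num_windows*B, window_size, window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: merge windows back into (B, H, W, C)."""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
```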
```python
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # The window covers the whole feature map: no shifting needed
            self.shift_size = 0
            self.window_size = min(self.input_resolution)

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # Compute the attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # Partition windows
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # W-MSA / SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        # Merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # Reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
```

3. Full Model Assembly and Training Tips
3.1 Overall Swin Transformer Architecture
Combine the modules above into the complete SwinT-Tiny model:
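The assembly below also references `Mlp`, `DropPath`, `trunc_normal_`, and `BasicLayer`, none of which appear earlier in the article (`trunc_normal_`, like `to_2tuple`, can be imported from `timm.models.layers`). The following are minimal sketches consistent with the official implementation; `BasicLayer` assumes the `SwinTransformerBlock` and `PatchMerging` classes defined above:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

class Mlp(nn.Module):
    """Two-layer feed-forward network used inside each Transformer block."""
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))

class DropPath(nn.Module):
    """Stochastic depth: randomly drop the residual branch per sample."""
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x.div(keep_prob) * mask

class BasicLayer(nn.Module):
    """One Swin stage: `depth` blocks with alternating shift, then optional downsampling."""
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(  # defined in Section 2.3
                dim=dim, input_resolution=input_resolution, num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,  # W-MSA / SW-MSA
                mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop, attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])
        self.downsample = (downsample(input_resolution, dim=dim, norm_layer=norm_layer)
                           if downsample is not None else None)

    def forward(self, x):
        for blk in self.blocks:
            x = cp.checkpoint(blk, x) if self.use_checkpoint else blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
```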
```python
class SwinTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # Split the image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # Optional absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Build the hierarchical feature-extraction stages
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                  patches_resolution[1] // (2 ** i_layer)),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)                     # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
```

3.2 Training Optimization Tips
In practice, the following techniques can noticeably improve Swin Transformer training:
- Learning rate schedule: cosine annealing
- Weight decay: a moderate value (e.g. 0.05)
- Data augmentation: MixUp, CutMix, and RandAugment
- Label smoothing: coefficient 0.1
- Gradient clipping: maximum gradient norm of 1.0
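The label-smoothing and gradient-clipping settings map directly onto PyTorch APIs. A minimal, self-contained training-step sketch (a toy linear model and a random batch stand in for the real model and dataloader; note that the `label_smoothing` argument of `CrossEntropyLoss` requires PyTorch 1.10+):

```python
import torch
import torch.nn as nn

# Training step with label smoothing (0.1) and gradient clipping (max norm 1.0).
# A toy linear model and random batch stand in for the real model/dataloader.
model = nn.Linear(16, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # needs PyTorch 1.10+

inputs = torch.randn(8, 16)
targets = torch.randint(0, 10, (8,))

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
# Clip the global gradient norm before stepping, as recommended above.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```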
```python
def get_optimizer(model, lr=5e-4, weight_decay=0.05):
    # Separate parameter groups: no weight decay on normalization parameters
    param_groups = [
        {'params': [p for n, p in model.named_parameters() if 'norm' in n],
         'weight_decay': 0.},
        {'params': [p for n, p in model.named_parameters() if 'norm' not in n],
         'weight_decay': weight_decay}
    ]
    return torch.optim.AdamW(param_groups, lr=lr)

def get_scheduler(optimizer, num_epochs=300):
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
```

3.3 Handling Custom Input Resolutions
Swin Transformer places constraints on the input resolution: the feature map at each stage must be divisible by the window size. The following utility pads arbitrary inputs to a compatible size:
```python
import torch.nn.functional as F

def pad_to_window_size(x, window_size):
    B, C, H, W = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, pad_w, 0, pad_h))
    return x, (H, W)

class SwinTransformerWithPadding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = SwinTransformer(**config)
        self.window_size = config['window_size']

    def forward(self, x):
        # Note: the wrapped model's img_size must match the padded resolution,
        # since each stage's input_resolution is fixed at construction time.
        x, orig_size = pad_to_window_size(x, self.window_size)
        x = self.model(x)
        return x
```

4. Model Fine-tuning and Transfer Learning
4.1 Fine-tuning on a Custom Dataset
When fine-tuning Swin Transformer on a specific task, the following strategies are recommended:
- Layer-wise learning rates: smaller learning rates for shallow layers, larger for deep layers
- Progressive unfreezing: start from the last layer and gradually unfreeze earlier layers
- Selective weight decay: do not apply weight decay to LayerNorm parameters
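Progressive unfreezing can be implemented with a small helper that toggles `requires_grad` by parameter-name prefix. A sketch with a toy stand-in model (the attribute names `layers` and `head` mirror the SwinTransformer class from Section 3):

```python
import torch.nn as nn

def freeze_all_but(model, trainable_prefixes):
    """Keep requires_grad=True only for parameters whose names start
    with one of `trainable_prefixes`; freeze everything else."""
    for name, p in model.named_parameters():
        p.requires_grad = any(name.startswith(pref) for pref in trainable_prefixes)

# Toy stand-in: 'layers' plays the backbone, 'head' the classifier,
# mirroring the attribute names of the SwinTransformer class above.
model = nn.Sequential()
model.add_module('layers', nn.Linear(8, 8))
model.add_module('head', nn.Linear(8, 2))

freeze_all_but(model, ['head'])  # phase 1: train only the classification head
# Later phases pass more prefixes, e.g. freeze_all_but(model, ['head', 'layers'])
```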
```python
def fine_tune_setup(model, base_lr=1e-4, head_lr=1e-3, weight_decay=0.05):
    param_groups = [
        {'params': [p for n, p in model.named_parameters() if 'head' in n],
         'lr': head_lr},
        {'params': [p for n, p in model.named_parameters() if 'head' not in n and 'norm' in n],
         'lr': base_lr / 2., 'weight_decay': 0.},
        {'params': [p for n, p in model.named_parameters() if 'head' not in n and 'norm' not in n],
         'lr': base_lr}
    ]
    return torch.optim.AdamW(param_groups, weight_decay=weight_decay)
```

4.2 Common Problems and Solutions
Problem 1: Loss does not decrease early in training
- Likely cause: inappropriate learning rate
- Solution: use a learning-rate finder (LR Finder) to pick a suitable learning rate
Problem 2: Large fluctuations in validation performance
- Likely cause: batch size too small
- Solution: use gradient accumulation to simulate a larger batch
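Gradient accumulation is straightforward: scale each mini-batch loss by the number of accumulation steps and only call `optimizer.step()` every `accum_steps` batches. A self-contained sketch with a toy model:

```python
import torch
import torch.nn as nn

# Gradient accumulation: average gradients over `accum_steps` small batches
# before each optimizer step, approximating a batch accum_steps times larger.
model = nn.Linear(16, 10)                     # toy stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
criterion = nn.CrossEntropyLoss()
accum_steps = 4
n_updates = 0

batches = [(torch.randn(4, 16), torch.randint(0, 10, (4,))) for _ in range(8)]

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(batches):
    loss = criterion(model(inputs), targets) / accum_steps  # scale so grads average
    loss.backward()
    if (step + 1) % accum_steps == 0:         # step only every accum_steps batches
        optimizer.step()
        optimizer.zero_grad()
        n_updates += 1
```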
Problem 3: Running out of GPU memory
- Likely causes: input resolution or batch size too large
- Solutions:
- Use gradient checkpointing
- Reduce the batch size and increase the number of gradient-accumulation steps
- Use mixed-precision training
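Gradient checkpointing (the SwinTransformer class above already exposes a `use_checkpoint` flag for this) trades compute for memory by recomputing activations during the backward pass. A minimal sketch using `torch.utils.checkpoint` on a toy stack of layers:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Gradient checkpointing: activations inside checkpointed segments are
# recomputed during backward instead of stored, cutting peak memory.
# A toy nn.Sequential stands in for the Swin stages here.
model = nn.Sequential(*[nn.Linear(32, 32) for _ in range(8)])
x = torch.randn(4, 32, requires_grad=True)

out = checkpoint_sequential(model, 2, x)  # split into 2 checkpointed segments
out.sum().backward()                      # activations are recomputed here
```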
```python
# Mixed-precision training example
scaler = torch.cuda.amp.GradScaler()
for inputs, targets in dataloader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

5. Model Evaluation and Performance Analysis
5.1 Computational Complexity Analysis
The computational cost of Swin Transformer comes mainly from:
- Patch Embedding: O(3×H×W×C)
- Window Attention: O(H×W×M²×C), where M is the window size
- Patch Merging: O(H×W×C²)
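These terms can be sanity-checked with quick arithmetic. Following the Swin paper, one W-MSA layer costs roughly 4HWC² multiply-adds for the QKV and output projections plus 2M²HWC for the attention itself; a sketch for the first SwinT-Tiny stage:

```python
# Rough multiply-add count for one W-MSA layer, following the Swin paper:
# Omega(W-MSA) = 4*H*W*C^2 + 2*M^2*H*W*C
def wmsa_flops(H, W, C, M):
    proj = 4 * H * W * C * C      # QKV + output projections
    attn = 2 * M * M * H * W * C  # QK^T and attn @ V inside each window
    return proj + attn

# First SwinT-Tiny stage: 56x56 tokens, C=96, window size M=7
print(wmsa_flops(56, 56, 96, 7) / 1e6)  # -> 145.108992 (millions of mult-adds)
```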
The table below compares FLOPs at different input resolutions:
| Input resolution | Swin-T FLOPs | ViT-B/16 FLOPs |
|---|---|---|
| 224×224 | 4.5G | 17.6G |
| 384×384 | 13.1G | 55.4G |
| 512×512 | 25.3G | 98.6G |
5.2 Practical Deployment Considerations
When deploying Swin Transformer in production, consider the following optimizations:
- TensorRT acceleration: convert the model to a TensorRT engine
- Quantization: 8-bit integer or 16-bit floating-point quantization
- Pruning: remove unimportant attention heads or neurons
```python
# Dynamic quantization example
model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8)
```

For mobile deployment, consider the more lightweight Swin-T-Nano variant, configured as follows:
- Embedding dimension: 64
- Depths per stage: [2,2,6,2]
- Attention heads: [2,4,8,16]
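Assuming the SwinTransformer class assembled in Section 3, this variant can be expressed as a plain config dict; the head dimension stays at 32 per head at every stage, matching SwinT-Tiny:

```python
# Config for the lightweight variant described above; pass it to the
# SwinTransformer class from Section 3 as SwinTransformer(**swin_nano_cfg).
swin_nano_cfg = dict(
    img_size=224,
    patch_size=4,
    embed_dim=64,             # vs. 96 for SwinT-Tiny
    depths=[2, 2, 6, 2],
    num_heads=[2, 4, 8, 16],  # vs. [3, 6, 12, 24]
    window_size=7,
    mlp_ratio=4.,
    drop_path_rate=0.1,
)

# Sanity check: the per-head dimension is 32 at every stage.
for stage, heads in enumerate(swin_nano_cfg["num_heads"]):
    dim = swin_nano_cfg["embed_dim"] * 2 ** stage
    assert dim % heads == 0 and dim // heads == 32
```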