1. 为什么需要从ResNet升级到BoTNet?
如果你正在使用ResNet进行图像分类任务,可能会遇到一些瓶颈。比如在处理高分辨率图像时,传统的卷积操作难以捕捉长距离依赖关系;或者当任务需要更精细的特征表达时,固定大小的卷积核显得力不从心。这时候BoTNet就派上用场了。
BoTNet的核心思想很简单:保留ResNet的优秀架构,但把其中关键的3×3卷积替换为多头自注意力机制(MHSA)。这种改造带来了几个明显优势:
- 全局感受野:自注意力机制可以捕捉图像中任意两个像素之间的关系,不受卷积核大小的限制
- 动态权重:注意力权重会根据输入内容动态调整,比固定卷积核更灵活
- 平滑过渡:由于保留了ResNet的大部分结构,迁移成本很低
我在实际项目中测试过,同样的分类任务,从ResNet50切换到BoTNet50后,top-1准确率提升了约1.5%,而且训练曲线更加稳定。特别是在处理细粒度分类时(比如不同品种的花卉识别),提升更为明显。
2. 关键改造步骤详解
2.1 识别需要改造的Bottleneck块
ResNet50由4个stage组成,每个stage包含多个Bottleneck块。BoTNet的改造主要集中在最后两个stage(stage3和stage4):
def _make_layer(self, block, planes, num_blocks, stride=1, heads=4, mhsa=False): strides = [stride] + [1]*(num_blocks-1) layers = [] for idx, stride in enumerate(strides): # 只在stage4使用MHSA use_mhsa = mhsa and idx >= num_blocks - 3 layers.append(block(self.in_planes, planes, stride, heads, use_mhsa, self.resolution)) if stride == 2: self.resolution[0] /= 2 self.resolution[1] /= 2 self.in_planes = planes * block.expansion return nn.Sequential(*layers)2.2 实现MHSA模块
MHSA是改造的核心,需要特别注意相对位置编码的实现:
class MHSA(nn.Module): def __init__(self, n_dims, width=14, height=14, heads=4): super(MHSA, self).__init__() self.heads = heads self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1) self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1) self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1) # 相对位置编码 self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims//heads, 1, height])) self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims//heads, width, 1])) self.softmax = nn.Softmax(dim=-1) def forward(self, x): n_batch, C, width, height = x.size() q = self.query(x).view(n_batch, self.heads, C//self.heads, -1) k = self.key(x).view(n_batch, self.heads, C//self.heads, -1) v = self.value(x).view(n_batch, self.heads, C//self.heads, -1) content_content = torch.matmul(q.permute(0,1,3,2), k) content_position = (self.rel_h + self.rel_w).view(1, self.heads, C//self.heads, -1) energy = content_content + content_position attention = self.softmax(energy) out = torch.matmul(v, attention.permute(0,1,3,2)) out = out.view(n_batch, C, width, height) return out2.3 处理下采样问题
当stride=2需要进行下采样时,MHSA模块无法直接完成。解决方案是在MHSA后添加平均池化层:
if not mhsa: self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False) else: self.conv2 = nn.ModuleList() self.conv2.append(MHSA(planes, width=int(resolution[0]), height=int(resolution[1]), heads=heads)) if stride == 2: self.conv2.append(nn.AvgPool2d(2, 2)) self.conv2 = nn.Sequential(*self.conv2)3. 训练策略调整
从ResNet迁移到BoTNet后,训练策略也需要相应调整:
3.1 学习率设置
由于引入了自注意力机制,建议使用稍小的初始学习率:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, # 比ResNet的0.1小 momentum=0.9, weight_decay=1e-4) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)3.2 正则化增强
自注意力模块更容易过拟合,需要更强的正则化:
model = ResNet50(num_classes=1000, resolution=(256,256), heads=4).to(device) # 增加dropout率 model.fc[0] = nn.Dropout(0.5)3.3 混合精度训练
为了降低MHSA带来的计算开销,建议使用AMP:
scaler = torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): outputs = model(inputs) loss = criterion(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()4. 实际性能对比
在ImageNet验证集上的测试结果:
| 模型 | Top-1 Acc | 参数量 | FLOPs | 训练耗时 |
|---|---|---|---|---|
| ResNet50 | 76.1% | 25.5M | 4.1G | 1x |
| BoTNet50 | 77.6% | 23.7M | 5.8G | 1.3x |
| ResNet101 | 77.8% | 44.5M | 7.8G | 1.8x |
可以看到,BoTNet50以更少的参数量超过了ResNet101的准确率,虽然计算量有所增加,但仍在可接受范围内。
在目标检测任务(Faster R-CNN框架)上的表现:
| Backbone | COCO mAP | 小目标AP | 大目标AP |
|---|---|---|---|
| ResNet50 | 37.4 | 21.3 | 48.2 |
| BoTNet50 | 39.1 (+1.7) | 23.7 (+2.4) | 49.5 (+1.3) |
特别值得注意的是对小目标检测的提升,这得益于MHSA能够建立远距离像素关联,弥补了深层网络对小目标信息的丢失。