在基础UNet网络上改进,主要是在网络结构中添加残差模块和注意力模块,主要通过修改编码器-解码器中的卷积块和跳跃连接实现。以下是具体实现方案,涵盖关键修改位置、代码示例和效果对比。
1. 问题解构:修改目标与位置
基础UNet由对称的编码器(下采样)、解码器(上采样)和跳跃连接构成。修改主要针对:
- 残差模块:替换编码器和解码器中的标准卷积块,缓解梯度消失,提升特征复用能力。
- 注意力模块:嵌入在跳跃连接或解码器中,使模型聚焦于重要特征区域(如番茄果实),抑制无关背景。
2. 方案推演:模块设计与集成策略
2.1 残差模块集成
残差模块的核心是恒等映射与卷积路径的相加。在UNet中,常用残差卷积块(Residual Convolutional Block)替换标准双卷积块。
基础残差块代码示例(PyTorch):
import torch import torch.nn as nn class ResidualBlock(nn.Module): """基础残差块,包含两个卷积层和跳跃连接""" def __init__(self, in_channels, out_channels): super().__init__() # 主路径:两个卷积层 self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(out_channels) # 跳跃连接:如果输入输出通道数不同,需用1x1卷积调整 self.shortcut = nn.Sequential() if in_channels != out_channels: self.shortcut = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1), nn.BatchNorm2d(out_channels) ) def forward(self, x): identity = self.shortcut(x) # 恒等映射或通道调整 out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out += identity # 残差连接 out = self.relu(out) return out集成到UNet编码器层:
将原UNet编码器中连续的两个Conv2d + ReLU替换为ResidualBlock。例如:
# 原基础UNet编码器部分 self.enc1 = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) # 修改为残差块版本 self.enc1 = ResidualBlock(3, 64) # 直接使用残差块替代2.2 注意力模块集成
注意力模块通常加在跳跃连接上,对编码器特征进行重加权后再与解码器特征融合。常用空间注意力(Spatial Attention)或通道注意力(Channel Attention)。
空间注意力模块示例:
class SpatialAttention(nn.Module): """空间注意力:学习特征图中每个位置的重要性""" def __init__(self, kernel_size=7): super().__init__() self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size//2) self.sigmoid = nn.Sigmoid() def forward(self, x): # 沿通道维度计算均值和最大值 avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) concat = torch.cat([avg_out, max_out], dim=1) # 生成空间注意力权重图 attention = self.sigmoid(self.conv(concat)) return x * attention # 特征重加权集成到UNet跳跃连接:
在将编码器特征与解码器特征拼接前,先对编码器特征施加注意力。
class AttentionUNetSkipConnection(nn.Module): """带注意力的跳跃连接""" def __init__(self, in_channels): super().__init__() self.attention = SpatialAttention() # 可选:添加1x1卷积对齐通道数(如果编码器和解码器通道数不同) self.conv1x1 = nn.Conv2d(in_channels, in_channels//2, kernel_size=1) if in_channels != in_channels//2 else nn.Identity() def forward(self, encoder_feat, decoder_feat): # 对编码器特征进行注意力加权 attended_encoder = self.attention(encoder_feat) # 通道对齐(如果需要) attended_encoder = self.conv1x1(attended_encoder) # 与解码器特征拼接(沿通道维度) return torch.cat([attended_encoder, decoder_feat], dim=1)然后将原UNet中直接的torch.cat([encoder_feat, decoder_feat])替换为上述模块。
2.3 完整修改示例:Res-Attention UNet
结合残差块和注意力跳跃连接,构建改进的UNet。
class ResAttentionUNet(nn.Module): def __init__(self, in_channels=3, num_classes=5): # 假设番茄分割有5类 super().__init__() # 编码器部分:使用残差块 self.enc1 = ResidualBlock(in_channels, 64) self.pool1 = nn.MaxPool2d(2) self.enc2 = ResidualBlock(64, 128) self.pool2 = nn.MaxPool2d(2) self.enc3 = ResidualBlock(128, 256) self.pool3 = nn.MaxPool2d(2) self.enc4 = ResidualBlock(256, 512) self.pool4 = nn.MaxPool2d(2) # 瓶颈层 self.bottleneck = ResidualBlock(512, 1024) # 解码器部分:上采样 + 残差块 + 注意力跳跃连接 self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2) self.att4 = AttentionUNetSkipConnection(512) # 注意力跳跃连接 self.dec4 = ResidualBlock(1024, 512) # 拼接后通道数翻倍 self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2) self.att3 = AttentionUNetSkipConnection(256) self.dec3 = ResidualBlock(512, 256) self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) self.att2 = AttentionUNetSkipConnection(128) self.dec2 = ResidualBlock(256, 128) self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) self.att1 = AttentionUNetSkipConnection(64) self.dec1 = ResidualBlock(128, 64) # 最终分类层 self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1) def forward(self, x): # 编码路径 e1 = self.enc1(x) e2 = self.enc2(self.pool1(e1)) e3 = self.enc3(self.pool2(e2)) e4 = self.enc4(self.pool3(e3)) # 瓶颈 b = self.bottleneck(self.pool4(e4)) # 解码路径(带注意力跳跃连接) d4 = self.up4(b) d4 = self.att4(e4, d4) # 注意力融合特征 d4 = self.dec4(d4) d3 = self.up3(d4) d3 = self.att3(e3, d3) d3 = self.dec3(d3) d2 = self.up2(d3) d2 = self.att2(e2, d2) d2 = self.dec2(d2) d1 = self.up1(d2) d1 = self.att1(e1, d1) d1 = self.dec1(d1) return self.final_conv(d1)3. 效果对比与选择建议
不同模块组合在番茄分割任务中的预期效果对比如下:
| 模型变体 | 核心修改 | 优点 | 潜在缺点 | 适用场景 |
|---|---|---|---|---|
| 基础UNet | 标准卷积+跳跃连接 | 结构简单,训练快,小数据集友好 | 深层特征易丢失,对细小目标(如番茄病害斑点)不敏感 | 数据量少、目标明显的初步实验 |
| Res-UNet | 编码器/解码器用残差块替换 | 缓解梯度消失,特征复用能力强,训练更稳定,提升深度网络性能 | 参数稍增,可能过拟合 | 数据集较大,需要训练深层网络 |
| Attention UNet | 跳跃连接加空间/通道注意力 | 聚焦关键区域(果实),抑制背景(枝叶),提升目标边界精度 | 计算量增加,训练时间稍长 | 目标与背景复杂、边界模糊(如重叠番茄) |
| Res-Attention UNet | 同时集成残差和注意力 | 兼具两者优点:特征提取强+聚焦关键区域,通常获得最佳精度 | 模型复杂度最高,需更多数据防过拟合 | 追求高精度,数据充足,计算资源允许 |
实施建议:
- 渐进修改:先单独测试残差或注意力模块,再组合,便于定位性能变化来源。
- 通道对齐:在注意力跳跃连接中,若编码器和解码器特征通道数不同,务必使用
1x1卷积调整。 - 位置选择:注意力模块加在所有跳跃连接上开销大,可仅加在深层(如
att4、att3),因为深层特征语义信息更强。 - 预训练权重:若使用ResNet等预训练编码器,可快速获得良好初始特征,加速收敛。
通过上述方法,可根据番茄分割任务的具体需求(如精度、速度、数据量)灵活修改基础UNet,平衡性能与效率。
参考来源
- UNet网络在图像去模糊方向的应用
- 从ResNet50到Res-Unet:详解残差模块融合与Keras代码实战
- MIMO-UNet学习
- 【大作业-27】Unet系列模型在自己医学数据集上的使用(unet、unet++、r2net、attention unet以及unet的改进)
- 基于SwinTransformer+UNet的遥感图像语义分割
- 基于改进UNET的遥感图像分割系统