Building an ST-UNet Remote Sensing Segmentation Model from Scratch: A Complete Walkthrough from Environment Setup to Result Visualization
Remote sensing image segmentation is an important application area of computer vision, and ST-UNet, an innovative architecture that combines Swin Transformer with UNet, has demonstrated excellent performance on datasets such as ISPRS Vaihingen. This article walks you through reproducing this state-of-the-art model from scratch.
1. Environment Setup and Data Acquisition
1.1 Basic Environment Configuration
We recommend creating an isolated Python environment with Anaconda to avoid dependency conflicts. The key installation commands are:
```bash
conda create -n stunet python=3.8
conda activate stunet
pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install opencv-python pillow matplotlib tqdm tensorboard
```

Note: the CUDA version must match your GPU driver; run `nvidia-smi` to check which CUDA version your driver supports.
1.2 Dataset Preparation
The ISPRS Vaihingen dataset contains 33 aerial images. We recommend organizing the data in the following structure:
```
Vaihingen/
├── train/
│   ├── images/   # original images
│   └── labels/   # annotation masks
└── val/
    ├── images/
    └── labels/
```

Key preprocessing steps:
- Crop the images into 256×256 patches
- Apply data augmentation (a sketch follows this list):
  - Random horizontal/vertical flips (probability 0.5)
  - Random rotation (0-90 degrees)
  - Color jitter (brightness and contrast, 0.2 each)
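A minimal augmentation sketch using torchvision. The paired image/mask handling below is an illustrative pattern, not the official ST-UNet pipeline: geometric transforms must be applied identically to image and mask, while photometric jitter touches the image only.

```python
import random
import torchvision.transforms.functional as TF
from torchvision import transforms

color_jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2)

def augment(image, mask):
    """Apply identical geometric transforms to a PIL image and its mask;
    color jitter is applied to the image only."""
    if random.random() < 0.5:
        image, mask = TF.hflip(image), TF.hflip(mask)
    if random.random() < 0.5:
        image, mask = TF.vflip(image), TF.vflip(mask)
    angle = random.uniform(0, 90)
    image = TF.rotate(image, angle)
    mask = TF.rotate(mask, angle)  # nearest-neighbor by default, keeps label values intact
    image = color_jitter(image)
    return image, mask
```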
2. Core Module Implementation
2.1 Swin Transformer Block Implementation
The core innovation of ST-UNet lies in combining Swin Transformer with the classic UNet. Below is the key implementation of W-MSA (Window-based Multi-head Self-Attention):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Relative position bias table
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        # Generate pairwise relative position indices within a window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Add the learned relative position bias to the attention map
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        attn = attn + relative_position_bias.permute(2, 0, 1).unsqueeze(0)

        if mask is not None:
            # Apply the shift mask (one mask per window) for shifted-window attention
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x
```
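W-MSA computes attention within non-overlapping windows, so the feature map must be partitioned into windows before calling `WindowAttention` and merged back afterwards. A sketch of these helpers, in the style of the Swin Transformer reference implementation:

```python
def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into (num_windows*B, window_size*window_size, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)


def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: merge windows back into a (B, H, W, C) map."""
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
```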
2.2 Spatial Interaction Module (SIM) Implementation

The SIM module enhances spatial information interaction through vertical and horizontal attention:
```python
class SIM(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 2, 3, padding=2, dilation=2),
            nn.BatchNorm2d(in_channels // 2),
            nn.GELU()
        )
        self.conv_v = nn.Conv2d(in_channels // 2, in_channels // 2, 1)
        self.conv_h = nn.Conv2d(in_channels // 2, in_channels // 2, 1)
        self.conv_final = nn.Sequential(
            nn.Conv2d(in_channels // 2, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.GELU()
        )

    def forward(self, x):
        B, C, H, W = x.shape
        feat = self.conv(x)

        # Vertical attention: pool each row into a single column descriptor
        v_feat = F.avg_pool2d(feat, kernel_size=(1, W))
        v_feat = self.conv_v(v_feat)
        v_feat = v_feat.expand(-1, -1, H, W)

        # Horizontal attention: pool each column into a single row descriptor
        h_feat = F.avg_pool2d(feat, kernel_size=(H, 1))
        h_feat = self.conv_h(h_feat)
        h_feat = h_feat.expand(-1, -1, H, W)

        # Fuse the two directions into a spatial attention map
        attn = torch.sigmoid(v_feat * h_feat)
        out = self.conv_final(feat * attn)
        return out
```
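A quick sanity check (the channel count and spatial size here are arbitrary) shows that SIM preserves the input shape, so it can be inserted into feature paths without extra plumbing:

```python
sim = SIM(in_channels=96)           # illustrative channel count
x = torch.randn(2, 96, 64, 64)      # dummy (B, C, H, W) feature map
assert sim(x).shape == x.shape      # SIM is shape-preserving
```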
3. Model Training and Tuning

3.1 Loss Function Configuration
ST-UNet uses a combination of Dice loss and cross-entropy loss:
```python
class DiceCELoss(nn.Module):
    def __init__(self, weight=None):
        super().__init__()
        self.ce = nn.CrossEntropyLoss(weight=weight)

    def forward(self, inputs, targets, smooth=1):
        # Cross-entropy loss on raw logits
        ce_loss = self.ce(inputs, targets)

        # Dice loss on softmax probabilities against one-hot targets
        inputs = F.softmax(inputs, dim=1)
        targets_onehot = F.one_hot(targets, num_classes=inputs.shape[1]).permute(0, 3, 1, 2).float()
        intersection = (inputs * targets_onehot).sum(dim=(2, 3))
        union = inputs.sum(dim=(2, 3)) + targets_onehot.sum(dim=(2, 3))
        dice_loss = 1 - (2. * intersection + smooth) / (union + smooth)
        dice_loss = dice_loss.mean()

        return ce_loss + dice_loss
```
3.2 Training Strategy Optimization

The following training strategies are recommended to improve model performance:
| Strategy | Setting | Purpose |
|---|---|---|
| LR schedule | Poly policy (power=0.9) | Smooth learning-rate decay |
| Optimizer | SGD (momentum=0.9) | Stable training |
| Weight decay | 1e-4 | Prevent overfitting |
| Batch size | 8 (adjust to GPU memory) | Balance speed and stability |
| Max epochs | 100 | Sufficient training |
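A hedged sketch of this configuration using a LambdaLR-based poly schedule (`model` and `train_loader` are placeholders, the initial learning rate of 0.01 is an assumption, and ST-UNet's exact scheduler stepping may differ):

```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
                            momentum=0.9, weight_decay=1e-4)

max_iters = 100 * len(train_loader)                  # epochs * iterations per epoch
poly_decay = lambda it: (1 - it / max_iters) ** 0.9  # poly policy, power=0.9
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=poly_decay)

# After each training iteration:
# optimizer.step(); scheduler.step()
```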
Tip: mixed-precision training can significantly reduce GPU memory usage. Create a scaler with `scaler = torch.cuda.amp.GradScaler()` and use it to scale the loss, as in the sketch below.
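A minimal sketch of the mixed-precision training step (`model`, `criterion`, `optimizer`, and `train_loader` are placeholders):

```python
scaler = torch.cuda.amp.GradScaler()

for images, targets in train_loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():    # run the forward pass in fp16 where safe
        outputs = model(images)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()      # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)             # unscale gradients, then step
    scaler.update()
```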
4. Result Visualization and Analysis
4.1 Training Monitoring
Monitor the training process with TensorBoard:
```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
for epoch in range(epochs):
    # ... training code ...
    writer.add_scalar('Loss/train', loss.item(), epoch)
    writer.add_scalar('mIoU/val', miou, epoch)
```

Recommended metrics to monitor:
- Training loss curve (should decrease smoothly)
- Validation mIoU (reflects generalization ability)
- Learning rate curve
- GPU memory usage
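Since validation mIoU is the headline metric, here is a minimal sketch of computing it from an accumulated confusion matrix (assuming predictions and labels are flat integer-valued NumPy arrays):

```python
import numpy as np

def update_confusion(conf, pred, label, num_classes):
    """Accumulate one batch of flat predictions/labels into a confusion matrix."""
    valid = (label >= 0) & (label < num_classes)
    idx = num_classes * label[valid].astype(int) + pred[valid].astype(int)
    conf += np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return conf

def compute_miou(conf):
    """Mean IoU over classes from a (num_classes, num_classes) confusion matrix."""
    intersection = np.diag(conf)
    union = conf.sum(0) + conf.sum(1) - intersection
    iou = intersection / np.maximum(union, 1)  # avoid division by zero for absent classes
    return iou.mean()
```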
4.2 Prediction Visualization
Implement a visualization function for the semantic segmentation results:
```python
import matplotlib.pyplot as plt

def visualize_prediction(image, mask, pred, save_path):
    """
    image: original image (H, W, 3)
    mask:  ground-truth annotation (H, W)
    pred:  model prediction (H, W)
    """
    plt.figure(figsize=(18, 6))

    # Original image
    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title("Input Image")

    # Ground truth
    plt.subplot(1, 3, 2)
    plt.imshow(mask, cmap='jet')
    plt.title("Ground Truth")

    # Prediction
    plt.subplot(1, 3, 3)
    plt.imshow(pred, cmap='jet')
    plt.title("Prediction")

    plt.savefig(save_path)
    plt.close()
```
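A hedged usage sketch, assuming a trained `model` and a preprocessed `input_tensor` of shape (1, 3, H, W):

```python
model.eval()
with torch.no_grad():
    logits = model(input_tensor)                   # (1, num_classes, H, W)
    pred = logits.argmax(dim=1)[0].cpu().numpy()   # (H, W) class map
visualize_prediction(image, mask, pred, "result.png")
```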
Common problems and solutions:

- Out of GPU memory: reduce the batch size or use gradient accumulation (see the sketch after this list)
- Class imbalance: add class weights to the loss function
- Training oscillation: lower the initial learning rate
- Loss of fine detail: increase the channel count of the FCM module
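A minimal gradient accumulation sketch (`accum_steps` of 4 is illustrative; `model`, `criterion`, `optimizer`, and `train_loader` are placeholders):

```python
accum_steps = 4  # effective batch size = batch_size * accum_steps

optimizer.zero_grad()
for i, (images, targets) in enumerate(train_loader):
    loss = criterion(model(images), targets) / accum_steps  # normalize so gradients match a full batch
    loss.backward()                                         # gradients accumulate across steps
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```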
Expected performance on the Vaihingen dataset:
- mIoU: ≥78%
- Mean F1 score: ≥85%
- Inference speed: ~15 fps (RTX 2080 Ti)