用PyTorch Lightning和LaMa搞定图像修复:我的big-lama自定义数据集训练全记录
在数字图像处理领域,图像修复技术正经历着革命性的变化。作为一名长期关注计算机视觉发展的技术实践者,我最近成功将LaMa框架的big-lama模型应用于自定义数据集训练,实现了对特定类型图像(如家族老照片和动漫截图)的高质量修复。这篇文章将完整分享从环境准备到模型调优的全流程实战经验,特别适合那些希望将前沿AI模型落地到具体应用场景的开发者和研究者。
1. 环境搭建与数据准备
1.1 基础环境配置
开始之前,需要确保具备以下基础环境:
- 硬件要求:建议至少配备16GB内存的NVIDIA GPU(如RTX 3080及以上),因为big-lama模型对显存要求较高
- Python环境:推荐使用Python 3.8+,并创建独立的conda环境:
conda create -n lama_env python=3.8 conda activate lama_env- 核心依赖安装:
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 pip install pytorch-lightning==1.5.0 pip install omegaconf opencv-python注意:PyTorch Lightning的版本需要严格匹配,不同版本间的API变化可能导致训练失败
1.2 自定义数据集构建
对于图像修复任务,数据集的质量直接影响最终效果。我的自定义数据集包含约5000张老照片和2000张动漫截图,按以下结构组织:
my_dataset/ ├── train/ │ ├── images/ # 原始图像 │ └── masks/ # 对应掩码 └── validation/ ├── images/ └── masks/关键数据准备要点:
图像预处理:
- 统一调整为512x512分辨率
- 转换为RGB格式(即使原始是灰度图)
- 标准化像素值到[0,1]范围
掩码生成规则:
- 破损区域用白色(1,1,1)表示
- 完好区域用黑色(0,0,0)表示
- 使用
cv2.threshold确保二值化准确
import cv2 def generate_mask(image_path): img = cv2.imread(image_path) gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) _, mask = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY) return mask2. 模型配置与关键修改
2.1 源码调整与补丁应用
直接从GitHub克隆LaMa官方仓库后,需要进行几处关键修改才能适配自定义训练:
修改点一:Checkpoint加载容错
# pytorch_lightning/trainer/connectors/checkpoint_connector.py try: self.restore_training_state(checkpoint) except KeyError: rank_zero_warn("Checkpoint仅包含模型参数,无法恢复训练状态")修改点二:损失函数调整
# saicinpainting/training/trainers/base.py if self.config.losses.get("sege_pl", {"weight": 0})['weight'] > 0: self.loss_sege_pl = ResNetPL(**self.config.losses.sege_pl)提示:这些修改主要解决预训练模型与自定义训练间的兼容性问题
2.2 配置文件详解
big-lama的配置文件(big-lama.yaml)需要重点关注以下参数:
| 参数项 | 推荐值 | 说明 |
|---|---|---|
data.batch_size | 8-12 | 根据GPU显存调整 |
trainer.max_epochs | 100 | 足够收敛的epoch数 |
optimizer.lr | 3e-5 | 学习率不宜过大 |
losses.sege_pl.weight | 0.1 | 感知损失的权重系数 |
我的实际启动命令如下:
python bin/train.py -cn big-lama location=my_dataset \ data.batch_size=10 \ +trainer.kwargs.resume_from_checkpoint=./big-lama-with-discr-remove-loss_segm_pl.ckpt3. 训练过程监控与调优
3.1 训练指标解读
使用TensorBoard监控训练过程时,需要特别关注以下指标:
- val/loss_total:验证集总损失,判断模型收敛情况
- train/psnr:峰值信噪比,衡量修复质量
- grad_norm:梯度范数,检测训练稳定性
典型训练曲线应呈现:
- 前20个epoch快速下降
- 中间阶段平稳波动
- 后期微调收敛
3.2 常见问题解决
在实际训练中遇到的主要挑战及解决方案:
显存不足(OOM)
- 降低
batch_size到8 - 启用
gradient_checkpointing - 使用
amp_level=O2混合精度
- 降低
损失震荡
- 调整学习率到1e-5
- 增加
warmup_steps到1000 - 检查数据标注质量
过拟合
- 添加更强的数据增强:
A.Compose([ A.RandomRotate90(), A.ColorJitter(0.2, 0.2, 0.2), A.GaussianBlur(blur_limit=(3,7)) ])
- 添加更强的数据增强:
4. 模型应用与效果评估
4.1 推理接口封装
为方便实际应用,我封装了以下预测函数:
def predict(image_path, mask_path, model, device='cuda'): image = cv2.imread(image_path) mask = cv2.imread(mask_path) # 预处理 image = (image.astype(np.float32) / 255).transpose(2,0,1) mask = (mask.astype(np.float32) / 255).transpose(2,0,1) with torch.no_grad(): inpainted = model(torch.from_numpy(image).to(device), torch.from_numpy(mask).to(device)) return inpainted.cpu().numpy().transpose(1,2,0)4.2 效果对比分析
在不同类型图像上的修复效果评估:
| 图像类型 | PSNR | SSIM | 主观评价 |
|---|---|---|---|
| 老照片 | 28.7 | 0.91 | 纹理保持良好 |
| 动漫图 | 31.2 | 0.95 | 边缘锐利清晰 |
| 自然风景 | 26.5 | 0.88 | 色彩过渡自然 |
实际案例展示(文字描述):
- 一张1940年的家族合影,右上角有约30%的霉斑损坏,修复后细节层次分明
- 动漫角色左眼部分完全缺失,模型成功重建了符合原画风的眼部特征
- 风景照中的大面积水渍被自然去除,天空云层过渡毫无违和
整个项目最耗时的部分其实是数据准备阶段,约占全部时间的60%。而模型训练本身在单卡3090上大约需要36小时完成100个epoch。最终的模型大小约1.2GB,可以部署在消费级GPU上实时运行。