Pytorch图像去噪实战(十四):条件扩散模型图像去噪,让Diffusion根据带噪图恢复干净图
一、问题场景:普通Diffusion能生成图,但不能直接修复指定图片
前面我们实现了 DDPM 和 DDIM。
但如果你仔细看,会发现之前的采样方式是:
从纯噪声开始生成图像这更像是生成任务。
而真实图像去噪任务通常是:
给定一张带噪图,输出它对应的干净图也就是说,我们不是要随机生成图片,而是要修复指定图片。
这时普通无条件Diffusion就不够用了,需要引入:
条件扩散模型 Conditional Diffusion
二、条件扩散去噪的核心思想
普通Diffusion输入:
x_t, t条件Diffusion输入:
x_t, noisy_condition, t其中:
- x_t:扩散过程中的 noisy clean image
- noisy_condition:真实带噪图
- t:时间步
模型学习:
predict noise from x_t with condition也就是让模型在反向去噪时参考原始带噪图。
三、为什么需要condition?
如果没有condition,模型生成的是随机干净图,不一定和输入图片内容一致。
加入condition后,模型知道:
- 图像结构是什么
- 边缘在哪里
- 文字位置在哪里
- 物体轮廓在哪里
因此它可以围绕输入图像做恢复,而不是凭空生成。
四、工程结构
conditional_diffusion_denoise/ ├── data/ │ └── train/ ├── models/ │ └── conditional_unet.py ├── diffusion/ │ └── ddpm.py ├── dataset.py ├── train.py ├── infer.py └── utils.py五、数据集构造
训练时我们有 clean 图,然后人工加噪得到 condition。
importosimportrandomimporttorchfromPILimportImagefromtorch.utils.dataimportDatasetimporttorchvision.transformsastransformsclassConditionalDenoiseDataset(Dataset):def__init__(self,root_dir,image_size=64):self.paths=[os.path.join(root_dir,name)fornameinos.listdir(root_dir)ifname.lower().endswith((".jpg",".png",".jpeg"))]self.transform=transforms.Compose([transforms.Resize((image_size,image_size)),transforms.ToTensor()])def__len__(self):returnlen(self.paths)def__getitem__(self,index):clean=Image.open(self.paths[index]).convert("L")clean=self.transform(clean)sigma=random.choice([15,25,35,50])noise=torch.randn_like(clean)*sigma/255.0noisy_condition=torch.clamp(clean+noise,0.0,1.0)returnnoisy_condition,clean六、条件UNet模型
核心改动非常简单:
把 x_t 和 noisy_condition 在通道维度拼接。
如果是灰度图:
x_t: 1通道 condition: 1通道 concat后: 2通道models/conditional_unet.py
importtorchimporttorch.nnasnnclassTimeEmbedding(nn.Module):def__init__(self,dim):super().__init__()self.net=nn.Sequential(nn.Linear(1,dim),nn.SiLU(),nn.Linear(dim,dim))defforward(self,t):t=t.float().view(-1,1)/1000.0returnself.net(t)classResBlock(nn.Module):def__init__(self,in_channels,out_channels,time_dim):super().__init__()self.conv1=nn.Conv2d(in_channels,out_channels,3,padding=1)self.conv2=nn.Conv2d(out_channels,out_channels,3,padding=1)self.time_proj=nn.Linear(time_dim,out_channels)self.shortcut=nn.Identity()ifin_channels!=out_channels:self.shortcut=nn.Conv2d(in_channels,out_channels,1)self.act=nn.SiLU()defforward(self,x,t_emb):h=self.act(self.conv1(x))time=self.time_proj(t_emb).view(x.size(0),-1,1,1)h=h+time h=self.conv2(self.act(h))returnh+self.shortcut(x)classConditionalUNet(nn.Module):def__init__(self,image_channels=1,base=64,time_dim=128):super().__init__()self.time_mlp=TimeEmbedding(time_dim)in_channels=image_channels*2self.down1=ResBlock(in_channels,base,time_dim)self.down2=ResBlock(base,base*2,time_dim)self.pool=nn.MaxPool2d(2)self.mid=ResBlock(base*2,base*2,time_dim)self.up=nn.ConvTranspose2d(base*2,base,2,2)self.up_block=ResBlock(base*2,base,time_dim)self.out=nn.Conv2d(base,image_channels,3,padding=1)defforward(self,xt,condition,t):t_emb=self.time_mlp(t)x=torch.cat([xt,condition],dim=1)d1=self.down1(x,t_emb)d2=self.down2(self.pool(d1),t_emb)mid=self.mid(d2,t_emb)u=self.up(mid)u=torch.cat([u,d1],dim=1)u=self.up_block(u,t_emb)returnself.out(u)七、训练代码
importtorchfromtorch.utils.dataimportDataLoaderfromdatasetimportConditionalDenoiseDatasetfromdiffusion.ddpmimportDDPMfrommodels.conditional_unetimportConditionalUNetdeftrain():device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")dataset=ConditionalDenoiseDataset("data/train",image_size=64)loader=DataLoader(dataset,batch_size=16,shuffle=True,num_workers=4)model=ConditionalUNet().to(device)diffusion=DDPM(timesteps=1000,beta_start=1e-4,beta_end=0.02,device=device)optimizer=torch.optim.AdamW(model.parameters(),lr=2e-4)criterion=torch.nn.MSELoss()forepochinrange(1,101):model.train()total_loss=0forcondition,cleaninloader:condition=condition.to(device)clean=clean.to(device)t=torch.randint(0,diffusion.timesteps,(clean.size(0),),device=device)xt,noise=diffusion.q_sample(clean,t)pred_noise=model(xt,condition,t)loss=criterion(pred_noise,noise)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)optimizer.step()total_loss+=loss.item()print(f"Epoch{epoch}, Loss:{total_loss/len(loader):.6f}")ifepoch%10==0:torch.save(model.state_dict(),f"conditional_diffusion_epoch_{epoch}.pth")if__name__=="__main__":train()八、推理代码
推理时输入一张真实 noisy image 作为 condition。
importtorchfromPILimportImageimporttorchvision.transformsastransformsimporttorchvision.utilsasvutilsfromdiffusion.ddpmimportDDPMfrommodels.conditional_unetimportConditionalUNet@torch.no_grad()definfer():device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")model=ConditionalUNet().to(device)model.load_state_dict(torch.load("conditional_diffusion_epoch_100.pth",map_location=device))model.eval()diffusion=DDPM(timesteps=1000,beta_start=1e-4,beta_end=0.02,device=device)img=Image.open("test_noisy.png").convert("L")transform=transforms.Compose([transforms.Resize((64,64)),transforms.ToTensor()])condition=transform(img).unsqueeze(0).to(device)x=torch.randn_like(condition)fortinreversed(range(diffusion.timesteps)):batch_t=torch.full((1,),t,device=device,dtype=torch.long)pred_noise=model(x,condition,batch_t)beta=diffusion.betas[t]alpha=diffusion.alphas[t]alpha_bar=diffusion.alpha_bars[t]x=(1/torch.sqrt(alpha))*(x-(beta/torch.sqrt(1-alpha_bar))*pred_noise)ift>0:x=x+torch.sqrt(beta)*torch.randn_like(x)x=torch.clamp(x,0.0,1.0)vutils.save_image(x.cpu(),"conditional_denoised.png")if__name__=="__main__":infer()九、为什么条件图不能直接作为初始x?
很多人第一次写条件扩散时,会想:
直接从 noisy image 开始反向去噪不就行了?但标准条件扩散里,反向过程的变量 x 是目标 clean 的扩散状态,而 noisy image 是条件信息。
两者角色不同:
- x:当前正在生成的 clean image 状态
- condition:引导恢复的输入图
如果混在一起,模型训练和推理分布会不一致。
十、和普通UNet去噪相比有什么优势?
普通UNet:
noisy -> clean条件Diffusion:
noise state + noisy condition -> clean distribution优势在于:
- 更适合复杂噪声
- 可以生成更自然细节
- 对强噪声恢复潜力更高
缺点也明显:
- 训练更慢
- 推理更慢
- 工程复杂度更高
十一、踩坑记录
坑1:condition没有拼接进模型
如果模型只输入 xt 和 t,那就是无条件生成,不是图像去噪。
坑2:condition和clean尺寸不一致
训练时 condition 和 clean 必须尺寸一致。
建议在 dataset 中统一 resize。
坑3:采样太慢
条件Diffusion同样有1000步采样问题。
建议后续结合DDIM。
十二、适合收藏总结
条件Diffusion去噪流程
- 从clean构造noisy condition
- 对clean执行扩散加噪
- 模型输入 xt + condition + t
- 模型预测noise
- 推理时用condition引导反向去噪
避坑清单
- condition必须输入模型
- clean和condition尺寸一致
- x和condition角色不要混
- 推理成本较高
- 建议结合DDIM加速
十三、优化建议
可以继续做:
- 条件DDIM采样
- 加强UNet结构
- 使用Restormer作为条件网络
- 支持RGB图像
- 用真实噪声数据微调
结尾总结
条件扩散模型把Diffusion从“随机生成图像”推进到“指定图像恢复”。
它的核心价值是:
既保留扩散模型强大的生成能力,又让模型受输入带噪图约束。
如果你要把Diffusion用于真正的图像去噪任务,条件扩散是必须掌握的一步。
下一篇预告
Pytorch图像去噪实战(十五):彩色RGB图像去噪实战,从灰度模型升级到真实图片处理