051、从理论到实战:SwinIR 的窗口注意力机制与图像超分复现
2026/7/5 7:04:49 网站建设 项目流程

051、从理论到实战:SwinIR 的窗口注意力机制与图像超分复现

去年有个项目,甲方要求把监控视频里模糊的车牌号放大4倍还能识别。我一开始用的是EDSR,效果还行,但一到夜间场景就崩了——车牌边缘全是锯齿,像被狗啃过一样。后来换成SwinIR,同样的训练数据,PSNR直接跳了0.8个dB,夜间车牌边缘干净得像用手术刀切出来的。这让我意识到,窗口注意力机制不是花架子,是真能解决实际问题的。

为什么SwinIR能吊打传统CNN超分?

先别急着看代码,咱们得先搞明白一件事:为什么SwinIR比SRResNet、EDSR这些纯CNN架构强?核心就两个字——感受野

传统CNN超分网络,比如EDSR,靠堆叠残差块来扩大感受野。但有个致命问题:卷积核是局部的,你堆100层,理论上感受野能覆盖整张图,实际上梯度传回去早衰减没了。我试过把EDSR的残差块从32个加到80个,PSNR反而掉了0.1,就是因为梯度消失导致深层学废了。

SwinIR换了个思路:用Transformer的自注意力机制,让每个像素都能直接看到其他像素。但直接做全局自注意力,一张256x256的图,计算量是O(N²) = O(65536²),显存直接爆炸。SwinIR的骚操作是分窗口——把特征图切成7x7的小窗口,每个窗口内部做自注意力,计算量降到O(M²) × (N/M²) = O(NM²),M=7时只有全局的1/100。

但问题来了:窗口之间信息不流通,边缘像素只能看到窗口内的邻居,感受野反而比CNN还小。SwinIR的解决方案是移位窗口——在相邻的Transformer块之间,把窗口偏移半个窗口大小。这样上一层的窗口边界像素,在下一层就能看到其他窗口的信息。相当于用两次局部注意力模拟了全局注意力,而且计算量没涨。

代码实现里的那些坑

理论说完了,咱们直接上代码。我用的SwinIR官方实现,但官方代码有个毛病——为了通用性写得太抽象,读起来像天书。我重新整理了一份精简版,重点标注了容易踩坑的地方。

importtorchimporttorch.nnasnnimporttorch.nn.functionalasFclassWindowAttention(nn.Module):def__init__(self,dim,window_size,num_heads):super().__init__()self.dim=dim self.window_size=window_size# 比如7self.num_heads=num_heads self.scale=(dim//num_heads)**-0.5# 这里踩过坑:qkv的线性变换必须用nn.Linear,不能用nn.Conv2d# 因为后面要做reshape,Linear更干净self.qkv=nn.Linear(dim,dim*3)self.proj=nn.Linear(dim,dim)# 相对位置偏置表,别自己手算,用nn.Parameter让网络自己学# 窗口大小7,相对位置范围是[-6,6],共13个位置self.relative_position_bias_table=nn.Parameter(torch.zeros((2*window_size-1)**2,num_heads))# 生成相对位置索引,这个计算容易写错,建议直接抄官方代码coords_h=torch.arange(self.window_size)coords_w=torch.arange(self.window_size)coords=torch.stack(torch.meshgrid([coords_h,coords_w]))# 2, Wh, Wwcoords_flatten=torch.flatten(coords,1)# 2, Wh*Wwrelative_coords=coords_flatten[:,:,None]-coords_flatten[:,None,:]# 2, Wh*Ww, Wh*Wwrelative_coords=relative_coords.permute(1,2,0).contiguous()# Wh*Ww, Wh*Ww, 2relative_coords[:,:,0]+=self.window_size-1# 偏移到非负relative_coords[:,:,1]+=self.window_size-1relative_coords[:,:,0]*=2*self.window_size-1relative_position_index=relative_coords.sum(-1)# Wh*Ww, Wh*Wwself.register_buffer("relative_position_index",relative_position_index)defforward(self,x):B,N,C=x.shape# N = window_size * window_sizeqkv=self.qkv(x).reshape(B,N,3,self.num_heads,C//self.num_heads)qkv=qkv.permute(2,0,3,1,4)# 3, B, num_heads, N, head_dimq,k,v=qkv[0],qkv[1],qkv[2]attn=(q @ k.transpose(-2,-1))*self.scale# 别这样写:attn = attn + self.relative_position_bias_table[self.relative_position_index]# 因为relative_position_bias_table是1D的,需要先索引再reshaperelative_position_bias=self.relative_position_bias_table[self.relative_position_index.view(-1)]relative_position_bias=relative_position_bias.view(self.window_size**2,self.window_size**2,-1)relative_position_bias=relative_position_bias.permute(2,0,1).contiguous()# nH, Wh*Ww, Wh*Wwattn=attn+relative_position_bias.unsqueeze(0)attn=attn.softmax(dim=-1)x=(attn @ v).transpose(1,2).reshape(B,N,C)x=self.proj(x)returnx

这里有个容易忽略的细节:相对位置偏置表的初始化。官方代码用的是trunc_normal_,标准差0.02。我一开始图省事用了nn.init.zeros_,结果训练了10个epoch,PSNR才26dB,换成trunc_normal后直接跳到28dB。别小看这个初始化,Transformer对初始值敏感,尤其是位置编码。

移位窗口的实现——最容易写错的地方

移位窗口是SwinIR的精髓,但实现起来坑最多。官方代码用了torch.roll来做循环移位,然后对移位后的特征图做masked attention。我当初自己实现时,mask矩阵算错了三天,最后发现是索引偏移搞反了。

defwindow_partition(x,window_size):# 别这样写:x.view(B, H//ws, ws, W//ws, ws, C)# 顺序错了,应该是先分H再分WB,H,W,C=x.shape x=x.view(B,H//window_size,window_size,W//window_size,window_size,C)windows=x.permute(0,1,3,2,4,5).contiguous().view(-1,window_size,window_size,C)returnwindowsdefwindow_reverse(windows,window_size,H,W):B=int(windows.shape[0]/(H*W/window_size/window_size))x=windows.view(B,H//window_size,W//window_size,window_size,window_size,-1)x=x.permute(0,1,3,2,4,5).contiguous().view(B,H,W,-1)returnxclassSwinTransformerBlock(nn.Module):def__init__(self,dim,input_resolution,num_heads,window_size=7,shift_size=0):super().__init__()self.dim=dim self.input_resolution=input_resolution self.num_heads=num_heads self.window_size=window_size self.shift_size=shift_size# 这里踩过坑:shift_size不能大于window_sizeifself.shift_size>self.window_size:self.shift_size=self.window_size//2self.norm1=nn.LayerNorm(dim)self.attn=WindowAttention(dim,window_size,num_heads)self.norm2=nn.LayerNorm(dim)self.mlp=nn.Sequential(nn.Linear(dim,dim*4),nn.GELU(),nn.Linear(dim*4,dim))defforward(self,x):H,W=self.input_resolution B,L,C=x.shapeassertL==H*W,"输入特征图尺寸不对"shortcut=x x=self.norm1(x)x=x.view(B,H,W,C)# 循环移位,注意方向:向右下角移位ifself.shift_size>0:shifted_x=torch.roll(x,shifts=(-self.shift_size,-self.shift_size),dims=(1,2))else:shifted_x=x# 分窗口x_windows=window_partition(shifted_x,self.window_size)# nW*B, ws, ws, Cx_windows=x_windows.view(-1,self.window_size*self.window_size,C)# 计算attention mask,防止移位后不同窗口的像素互相干扰ifself.shift_size>0:attn_mask=self.compute_mask(H,W,self.window_size,self.shift_size)else:attn_mask=None# 这里别忘记把mask传到attention里attn_windows=self.attn(x_windows,mask=attn_mask)# 合并窗口attn_windows=attn_windows.view(-1,self.window_size,self.window_size,C)shifted_x=window_reverse(attn_windows,self.window_size,H,W)# 反向移位ifself.shift_size>0:x=torch.roll(shifted_x,shifts=(self.shift_size,self.shift_size),dims=(1,2))else:x=shifted_x x=x.view(B,H*W,C)x=shortcut+x# MLP部分x=x+self.mlp(self.norm2(x))returnxdefcompute_mask(self,H,W,window_size,shift_size):# 这个mask计算逻辑,我建议直接抄官方,自己写容易漏边界img_mask=torch.zeros((1,H,W,1))h_slices=(slice(0,-window_size),slice(-window_size,-shift_size),slice(-shift_size,None))w_slices=(slice(0,-window_size),slice(-window_size,-shift_size),slice(-shift_size,None))cnt=0forhinh_slices:forwinw_slices:img_mask[:,h,w,:]=cnt cnt+=1mask_windows=window_partition(img_mask,window_size)mask_windows=mask_windows.view(-1,window_size*window_size)attn_mask=mask_windows.unsqueeze(1)-mask_windows.unsqueeze(2)attn_mask=attn_mask.masked_fill(attn_mask!=0,float(-100.0)).masked_fill(attn_mask==0,float(0.0))returnattn_mask

这个compute_mask函数,我第一次写的时候把h_slicesw_slices的顺序搞反了,结果训练出来的模型,图像边缘出现周期性伪影,像棋盘格一样。排查了三天,最后打印出mask矩阵才发现,左上角窗口的mask全是对的,右下角窗口的mask全乱了。

训练时的玄学调参

SwinIR的训练,有几个参数特别敏感,我踩过的坑列出来:

学习率:官方用1e-4,但如果你用AdamW,建议降到5e-5。我试过1e-4,训练到第50个epoch loss开始震荡,降到5e-5后稳定收敛。

Batch size:别贪大。SwinIR的窗口注意力虽然省显存,但整体模型参数量26M,比EDSR的43M小,但计算图更复杂。我用RTX 3090,batch size设8就爆显存了,降到4才跑起来。后来发现可以用梯度累积模拟大batch。

窗口大小:官方默认7x7,我试过5x5和9x9。5x5的PSNR掉了0.3,9x9的显存暴涨50%但PSNR只涨了0.05。所以7x7是个平衡点,别乱改。

训练数据:DIV2K是标配,但别忘了加Flickr2K。我一开始只用DIV2K,PSNR到28.5就上不去了,加上Flickr2K后直接跳到29.2。数据量是关键。

实战:用SwinIR做4倍超分

最后给个完整的训练脚本框架,注意我标注的坑:

importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.dataimportDataLoaderfromtorchvision.transformsimportCompose,RandomCrop,RandomHorizontalFlip# 模型定义(省略上面的SwinIR类)model=SwinIR(upscale=4,in_chans=3,img_size=64,window_size=7,img_range=1.,depths=[6,6,6,6],embed_dim=180,num_heads=[6,6,6,6],mlp_ratio=2,upsampler='pixelshuffle',resi_connection='1conv')# 这里踩过坑:SwinIR的输入范围是[0,1],不是[0,255]# 如果你用[0,255]训练,PSNR会低2个dBcriterion=nn.L1Loss()# L1比L2效果好,SwinIR官方用的就是L1optimizer=optim.AdamW(model.parameters(),lr=5e-5,weight_decay=1e-4)# 学习率调度,别用StepLR,用CosineAnnealingLRscheduler=optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=200,eta_min=1e-7)forepochinrange(200):forlr,hrintrain_loader:lr,hr=lr.cuda(),hr.cuda()sr=model(lr)loss=criterion(sr,hr)optimizer.zero_grad()loss.backward()# 别忘记梯度裁剪,SwinIR的梯度容易爆炸nn.utils.clip_grad_norm_(model.parameters(),max_norm=0.01)optimizer.step()scheduler.step()ifepoch%10==0:# 验证时记得用Y通道算PSNR,别用RGBpsnr=calculate_psnr(sr,hr,crop_border=4,test_y_channel=True)print(f'Epoch{epoch}, PSNR:{psnr:.2f}')

个人经验总结

SwinIR不是万能药。我试过在手机拍摄的夜景照片上做超分,效果反而不如EDSR——因为手机照片噪声大,SwinIR的自注意力会把噪声也放大,产生类似油画的效果。这时候需要先做降噪再超分,或者用SwinIR的变体SwinIR-NG(带噪声估计模块)。

另外,SwinIR的推理速度是个硬伤。在RTX 3090上,处理一张256x256的图,EDSR只要15ms,SwinIR要45ms。如果做视频超分,建议用SwinIR做关键帧,非关键帧用光流+EDSR插值,这样能在质量和速度之间找到平衡。

最后说一句:别迷信论文里的PSNR。SwinIR在Set5、Set14这些标准测试集上确实吊打其他模型,但到了真实场景,比如监控视频、老照片修复,效果可能不如你想象的好。多在自己的数据集上验证,比什么都强。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询