从NLP跨界CV:手把手教你用PyTorch复现Vision Transformer (ViT) 图像分类
当Transformer在自然语言处理领域大放异彩时,计算机视觉研究者们开始思考:这种基于自注意力机制的架构能否同样颠覆图像识别领域?2020年,Vision Transformer (ViT) 的出现给出了肯定答案。本文将带你从零开始,用PyTorch实现这一开创性模型,体验如何将图像转化为"视觉词汇"的奇妙过程。
1. ViT核心原理与设计思路
传统卷积神经网络(CNN)通过局部感受野逐层提取特征,而ViT则采用全局视角处理图像——它将输入图片分割为16x16的"视觉词汇块"(patches),每个块经过线性投影后成为Transformer可处理的序列元素。这种设计带来了三大关键创新:
- 图像序列化:将2D图像转换为1D令牌序列
- 位置编码:通过可学习的位置嵌入保留空间信息
- 纯Transformer架构:完全摒弃卷积操作
注意:ViT在中小型数据集上可能不如CNN表现优异,但当训练数据超过1亿张图片时,其性能开始显著超越传统方法。
下表对比了ViT与典型CNN的核心差异:
| 特性 | ViT | CNN |
|---|---|---|
| 特征提取方式 | 全局自注意力 | 局部卷积核 |
| 空间信息处理 | 显式位置编码 | 隐式感受野累积 |
| 数据依赖性 | 需要大量训练数据 | 中等规模数据即可 |
| 计算复杂度 | O(n²) | O(n) |
2. 环境准备与数据预处理
2.1 安装必要依赖
确保你的Python环境包含以下核心库:
pip install torch torchvision pytorch-lightning einops2.2 CIFAR-10数据集处理
我们将使用CIFAR-10作为演示数据集。虽然原始ViT论文使用更大规模的ImageNet,但CIFAR-10更适合快速验证:
from torchvision import datasets, transforms # 定义数据增强策略 train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # 加载数据集 train_data = datasets.CIFAR10('data', train=True, download=True, transform=train_transform) test_data = datasets.CIFAR10('data', train=False, transform=train_transform)3. ViT模型实现详解
3.1 图像分块与线性嵌入
ViT的第一步是将图像分割为固定大小的块并线性投影到特征空间:
import torch import torch.nn as nn from einops import rearrange class PatchEmbedding(nn.Module): def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=64): super().__init__() self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.proj(x) # [B, C, H, W] -> [B, D, H/P, W/P] x = rearrange(x, 'b d h w -> b (h w) d') return x3.2 位置编码与分类令牌
Transformer需要位置信息来理解图像的空间结构:
class ViTEncoder(nn.Module): def __init__(self, num_patches, embed_dim, num_heads, num_layers): super().__init__() self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim)) self.transformer = nn.TransformerEncoder( nn.TransformerEncoderLayer(embed_dim, num_heads), num_layers ) def forward(self, x): cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_tokens, x), dim=1) x += self.pos_embed return self.transformer(x)4. 完整模型组装与训练
4.1 构建端到端ViT模型
整合所有组件形成完整架构:
class VisionTransformer(nn.Module): def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=64, num_heads=4, num_layers=4, num_classes=10): super().__init__() self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim) num_patches = (img_size // patch_size) ** 2 self.encoder = ViTEncoder(num_patches, embed_dim, num_heads, num_layers) self.head = nn.Linear(embed_dim, num_classes) def forward(self, x): x = self.patch_embed(x) x = self.encoder(x) return self.head(x[:, 0]) # 使用分类令牌输出4.2 训练策略与超参数设置
使用PyTorch Lightning简化训练流程:
import pytorch_lightning as pl from torch.utils.data import DataLoader class ViTLightning(pl.LightningModule): def __init__(self, lr=1e-3): super().__init__() self.model = VisionTransformer() self.lr = lr self.criterion = nn.CrossEntropyLoss() def forward(self, x): return self.model(x) def training_step(self, batch, batch_idx): x, y = batch preds = self(x) loss = self.criterion(preds, y) self.log('train_loss', loss) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.lr) # 初始化训练器 trainer = pl.Trainer(max_epochs=50, gpus=1 if torch.cuda.is_available() else 0) model = ViTLightning() # 数据加载器 train_loader = DataLoader(train_data, batch_size=64, shuffle=True) test_loader = DataLoader(test_data, batch_size=64) # 开始训练 trainer.fit(model, train_loader)5. 模型优化与调参技巧
5.1 学习率调度策略
ViT训练对学习率非常敏感,推荐使用warmup策略:
def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) scheduler = torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr=self.lr, total_steps=self.trainer.estimated_stepping_batches ) return [optimizer], [scheduler]5.2 混合精度训练加速
利用NVIDIA GPU的Tensor Core加速训练:
trainer = pl.Trainer( max_epochs=50, precision=16, accelerator='gpu' if torch.cuda.is_available() else 'cpu' )5.3 关键超参数经验值
基于CIFAR-10的实验验证,以下配置表现良好:
| 参数 | 推荐值 | 说明 |
|---|---|---|
| patch_size | 4 | 平衡计算量与局部信息保留 |
| embed_dim | 64-128 | 特征维度 |
| num_heads | 4-8 | 注意力头数 |
| num_layers | 6-12 | Transformer层数 |
| batch_size | 64-128 | 根据GPU内存调整 |
6. 模型评估与结果分析
6.1 测试集性能评估
def test_step(self, batch, batch_idx): x, y = batch preds = self(x) loss = self.criterion(preds, y) acc = (preds.argmax(1) == y).float().mean() self.log('test_loss', loss) self.log('test_acc', acc) return {'loss': loss, 'acc': acc}6.2 可视化注意力机制
理解模型如何关注图像不同区域:
import matplotlib.pyplot as plt def visualize_attention(model, img): model.eval() with torch.no_grad(): patches = model.patch_embed(img.unsqueeze(0)) attns = model.encoder.transformer.layers[0].self_attn( patches, patches, patches )[1] plt.imshow(attns[0, 0, 1:].reshape(8, 8).cpu()) plt.colorbar() plt.show()在CIFAR-10上训练约50个epoch后,预期可以达到75-80%的测试准确率。虽然这低于原始论文在更大数据集上的结果,但足以验证ViT的基本原理。