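1. Why implement a Transformer by hand?
When you first meet the Transformer, you may wonder: with so many ready-made deep learning libraries available (for example HuggingFace's transformers), why build one from scratch? Here is a real case. Last year our team was deploying a translation model and found that calling the pretrained model directly leaked memory on long inputs. Because we did not know the underlying implementation, debugging took a full two weeks. A colleague who had implemented a Transformer by hand needed only two days to trace the problem to how the attention matrices were managed in memory.
Three core benefits of implementing a Transformer by hand:
- Dimension bookkeeping: you truly understand how tensors of shape `(batch_size, seq_len, d_model)` are transformed at each layer. Take the scaling in multi-head attention: only after implementing it yourself does it become clear why the scores are divided by $\sqrt{d_k}$.
- Debugging ability: when the model's output looks wrong, you can quickly tell whether the mask logic or the residual connections are at fault. I once hit a bug where decoding kept generating the same token; the cause turned out to be an incorrectly set decoder self-attention mask.
- Customization: want to add a sparsity constraint to attention, or try a new positional encoding? Only with a grasp of the underlying implementation can you reshape the model freely.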
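The table below compares what each way of learning gives you:
| Learning approach | Theoretical understanding | Debugging ability | Flexibility to modify | Time cost |
|---|---|---|---|---|
| Calling the API directly | ★★☆ | ★☆☆ | ★☆☆ | Low |
| Reading the paper | ★★★ | ★☆☆ | ★★☆ | Medium |
| Implementing by hand | ★★★ | ★★★ | ★★★ | High |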
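2. Environment setup and data preprocessing
2.1 A minimal PyTorch environment
A clean conda environment is recommended:

    conda create -n transformer python=3.8
    conda activate transformer
    pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
    pip install numpy matplotlib tqdm

Pitfalls to avoid:
- The CUDA build of PyTorch must be compatible with your GPU driver. `nvidia-smi` reports the highest CUDA version the installed driver supports.
- Pin the PyTorch version; tensor operations can differ subtly between releases.
- If you hit `RuntimeError: CUDA out of memory`, try a smaller `batch_size` or use gradient accumulation.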
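Gradient accumulation, mentioned in the last tip, can be sketched as follows. This is a minimal illustration, not part of the tutorial's training loop: the linear model, the random data, and the names `accum_steps` and `micro_batch` are placeholders chosen for the example. It compares the gradient accumulated over four micro-batches of 2 with the gradient of one full batch of 8.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (hypothetical): a linear model and random data.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

# Reference: gradient from a single full batch of 8 samples.
criterion(model(inputs), targets).backward()
full_grad = model.weight.grad.clone()
optimizer.zero_grad()

accum_steps = 4   # micro-batches per optimizer update
micro_batch = 2   # 4 x 2 = effective batch size of 8

for i in range(accum_steps):
    x = inputs[i * micro_batch:(i + 1) * micro_batch]
    y = targets[i * micro_batch:(i + 1) * micro_batch]
    # Divide so the accumulated gradient is the mean over all 8 samples.
    loss = criterion(model(x), y) / accum_steps
    loss.backward()   # gradients add up in .grad across micro-batches
accum_grad = model.weight.grad.clone()
optimizer.step()      # one weight update for the whole effective batch

print(torch.allclose(full_grad, accum_grad, atol=1e-5))
```

Only one micro-batch of activations is held in memory at a time, which is why this helps with out-of-memory errors at the cost of more forward and backward passes per update.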
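2.2 Building a toy dataset
To keep the focus on the model itself, we build a tiny German→English translation dataset:

    # Special tokens
    PAD = 0  # padding
    SOS = 1  # start of sentence
    EOS = 2  # end of sentence

    sentences = [
        # German source,        decoder input,     decoder target
        ['ich mochte ein bier', 'S i want a beer', 'i want a beer E'],
        ['ich mochte ein cola', 'S i want a coke', 'i want a coke E'],
    ]

    # Word ids are not contiguous per language, so an embedding table must
    # cover the largest id: at least 8 rows for 'de' and 13 rows for 'en'.
    vocab = {
        'de': {'P': PAD, 'ich': 3, 'mochte': 4, 'ein': 5, 'bier': 6, 'cola': 7},
        'en': {'P': PAD, 'S': SOS, 'E': EOS, 'i': 8, 'want': 9, 'a': 10, 'beer': 11, 'coke': 12},
    }

Data-processing tips: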
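- Sequence padding: `torch.nn.utils.rnn.pad_sequence` pads variable-length sequences to a common length.
- Batch assembly: the `collate_fn` argument of `DataLoader` lets you customize how a batch is put together.
- Device placement: use `.to(device)` to manage where the data lives in one consistent way.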
The complete data pipeline:

    import torch
    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import DataLoader, Dataset

    class TranslationDataset(Dataset):
        def __init__(self, sentences, vocab):
            self.enc_inputs = []
            self.dec_inputs = []
            self.dec_outputs = []
            for de, en_in, en_out in sentences:
                # Map each whitespace-separated word to its vocabulary id
                self.enc_inputs.append([vocab['de'][word] for word in de.split()])
                self.dec_inputs.append([vocab['en'][word] for word in en_in.split()])
                self.dec_outputs.append([vocab['en'][word] for word in en_out.split()])

        def __len__(self):
            # DataLoader needs the dataset size to build its sampler
            return len(self.enc_inputs)

        def __getitem__(self, idx):
            return (
                torch.LongTensor(self.enc_inputs[idx]),
                torch.LongTensor(self.dec_inputs[idx]),
                torch.LongTensor(self.dec_outputs[idx])
            )

    def collate_fn(batch):
        # Regroup the per-sample tuples, then pad each group to the batch maximum
        enc_inputs = [item[0] for item in batch]
        dec_inputs = [item[1] for item in batch]
        dec_outputs = [item[2] for item in batch]
        return (
            pad_sequence(enc_inputs, batch_first=True, padding_value=PAD),
            pad_sequence(dec_inputs, batch_first=True, padding_value=PAD),
            pad_sequence(dec_outputs, batch_first=True, padding_value=PAD)
        )

    dataset = TranslationDataset(sentences, vocab)
    loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

3. Implementing the core Transformer components
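3.1 The mathematical elegance of positional encoding
The Transformer's positional encoding uses sine and cosine functions. Two properties make it elegant:
- Relative position information: by the angle-addition identities, for any fixed offset $k$ the encoding of position $pos + k$ is a linear function of the encoding of position $pos$.
- Extrapolation: the formula is defined for every position, so well-formed encodings can be generated even for sequences longer than those seen in training.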
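A PyTorch implementation:

    import math
    import torch
    import torch.nn as nn

    class PositionalEncoding(nn.Module):
        def __init__(self, d_model, max_len=5000):
            super().__init__()
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
            pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
            self.register_buffer('pe', pe)  # stored with the module, not trained

        def forward(self, x):
            # x: (batch_size, seq_len, d_model); pe broadcasts over the batch dimension
            return x + self.pe[:x.size(1)].unsqueeze(0)

Debugging tips:
- Visualize the encoding: plot the `pe` buffer with `plt.imshow(pe.numpy())` and check for a regular pattern of alternating sine and cosine stripes whose wavelength grows with the dimension index.
- Numeric check: make sure the encodings of adjacent positions differ by a moderate amount (differences that are too large or too small both hurt training).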
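The relative-position property can be checked numerically. For each frequency $\omega$, the pair $(\sin(\omega \cdot pos), \cos(\omega \cdot pos))$ is turned into the pair for $pos + k$ by a 2x2 rotation that depends only on $k$. The sketch below uses plain Python rather than PyTorch; the helper names `positional_encoding` and `shift` are made up for this illustration.

```python
import math

def positional_encoding(pos, d_model):
    """Sinusoidal encoding of one position as a list of length d_model."""
    pe = []
    for i in range(0, d_model, 2):
        omega = 10000.0 ** (-i / d_model)
        pe.extend([math.sin(pos * omega), math.cos(pos * omega)])
    return pe

def shift(pe, k, d_model):
    """Map PE(pos) to PE(pos + k) using only k: one rotation per (sin, cos) pair."""
    out = []
    for i in range(0, d_model, 2):
        omega = 10000.0 ** (-i / d_model)
        s, c = pe[i], pe[i + 1]
        # sin(a + b) = sin(a)cos(b) + cos(a)sin(b)
        out.append(s * math.cos(k * omega) + c * math.sin(k * omega))
        # cos(a + b) = cos(a)cos(b) - sin(a)sin(b)
        out.append(c * math.cos(k * omega) - s * math.sin(k * omega))
    return out

d_model, pos, k = 8, 5, 3
predicted = shift(positional_encoding(pos, d_model), k, d_model)
actual = positional_encoding(pos + k, d_model)
max_error = max(abs(p - a) for p, a in zip(predicted, actual))
print(f"max |difference| = {max_error:.2e}")
```

Because `shift` never looks at `pos`, the same linear map works at every position, which is one way to read the claim that the encoding exposes relative offsets.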
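3.2 Three pillars of the attention mechanism
3.2.1 Scaled dot-product attention

    import math
    import torch
    import torch.nn.functional as F

    def scaled_dot_product_attention(Q, K, V, mask=None):
        # Q, K, V: (batch_size, n_heads, seq_len, d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
        if mask is not None:
            # Positions where mask == 0 get a very negative score, so softmax gives them ~0 weight
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V), attn

Key points:
- The scaling factor $\sqrt{d_k}$ keeps the dot products from growing with $d_k$; without it, large scores saturate the softmax and its gradients vanish.
- Masking must be applied before the softmax.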
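The effect of the scaling factor can be seen in a small simulation. This sketch assumes query and key components drawn independently from a standard normal distribution (an idealization, not a property of trained weights) and uses only the standard library. Under that assumption the raw dot products have variance close to $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1.

```python
import math
import random

random.seed(0)

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)

d_k, n_keys, n_trials = 64, 10, 500
raw_scores, scaled_scores = [], []
raw_peak, scaled_peak = 0.0, 0.0
for _ in range(n_trials):
    q = [random.gauss(0, 1) for _ in range(d_k)]
    keys = [[random.gauss(0, 1) for _ in range(d_k)] for _ in range(n_keys)]
    raw = [sum(a * b for a, b in zip(q, k)) for k in keys]
    scaled = [s / math.sqrt(d_k) for s in raw]
    raw_scores += raw
    scaled_scores += scaled
    # Track the average size of the largest attention weight in each row.
    raw_peak += max(softmax(raw)) / n_trials
    scaled_peak += max(softmax(scaled)) / n_trials

print(f"variance of raw scores:    {variance(raw_scores):.1f}")
print(f"variance of scaled scores: {variance(scaled_scores):.2f}")
print(f"mean largest softmax weight: raw {raw_peak:.2f}, scaled {scaled_peak:.2f}")
```

Without scaling, one key tends to take almost all of the weight; that near-one-hot regime is where the softmax gradient is close to zero.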
3.2.2 Multi-head attention

    class MultiHeadAttention(nn.Module):
        def __init__(self, d_model, n_heads):
            super().__init__()
            assert d_model % n_heads == 0
            self.d_k = d_model // n_heads
            self.n_heads = n_heads
            self.W_Q = nn.Linear(d_model, d_model)
            self.W_K = nn.Linear(d_model, d_model)
            self.W_V = nn.Linear(d_model, d_model)
            self.W_O = nn.Linear(d_model, d_model)

        def forward(self, Q, K, V, mask=None):
            batch_size = Q.size(0)
            # Linear projections, then split into heads: (batch, n_heads, seq_len, d_k)
            Q = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
            K = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
            V = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
            # A 3-D mask (batch, q_len, k_len) needs a head dimension. The 4-D masks
            # returned by create_masks already have one, and unsqueezing them again
            # would produce a 5-D mask that no longer matches the 4-D scores.
            if mask is not None and mask.dim() == 3:
                mask = mask.unsqueeze(1)
            x, attn = scaled_dot_product_attention(Q, K, V, mask)
            # Merge the heads back: (batch, seq_len, d_model)
            x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_k)
            return self.W_O(x)

Dimension walkthrough:
- Input: `(batch_size, seq_len, d_model)`
- After the linear projections: `(batch_size, seq_len, d_model)`
- Split into heads: `(batch_size, seq_len, n_heads, d_k)` → transposed to `(batch_size, n_heads, seq_len, d_k)`
- After attention: shape unchanged
- Merged output: `(batch_size, seq_len, d_model)`
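The walkthrough above can be reproduced on a random tensor. The sizes below are arbitrary small values picked for the illustration, and the linear projections are left out because they do not change the shape.

```python
import torch

batch_size, seq_len, d_model, n_heads = 2, 5, 16, 4
d_k = d_model // n_heads

x = torch.randn(batch_size, seq_len, d_model)

# Split the last dimension into heads, then move the head axis forward.
heads = x.view(batch_size, seq_len, n_heads, d_k).transpose(1, 2)
print(tuple(heads.shape))   # (batch_size, n_heads, seq_len, d_k)

# Attention scores compare every query position with every key position.
scores = heads @ heads.transpose(-2, -1)
print(tuple(scores.shape))  # (batch_size, n_heads, seq_len, seq_len)

# Undo the split: transpose back, make memory contiguous, flatten the heads.
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
print(tuple(merged.shape))  # (batch_size, seq_len, d_model)

# Splitting and merging only rearrange data, so the tensor round-trips exactly.
print(torch.equal(merged, x))
```

The `.contiguous()` call matters: `transpose` returns a view with permuted strides, and `view` refuses to flatten it until the data is laid out contiguously again.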
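3.2.3 Masking
The Transformer uses two kinds of mask:
- Padding mask: keeps the attention mechanism from attending to padding tokens.
- Sequence (look-ahead) mask: keeps the decoder from seeing future positions.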
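A small worked example shows how the two masks combine for the decoder. It is written with plain Python lists so the resulting matrix can be read directly; the function names are invented for this sketch, and the token ids are an arbitrary example with one trailing `PAD`.

```python
PAD = 0  # same padding id as in the toy dataset

def padding_mask(tokens):
    """1 where the token is real, 0 where it is padding."""
    return [int(t != PAD) for t in tokens]

def look_ahead_mask(size):
    """Lower-triangular matrix: position i may attend to positions 0..i."""
    return [[int(j <= i) for j in range(size)] for i in range(size)]

def decoder_mask(tokens):
    """Combine both masks: a key must be real AND not in the future."""
    pad = padding_mask(tokens)
    causal = look_ahead_mask(len(tokens))
    return [[causal[i][j] & pad[j] for j in range(len(tokens))]
            for i in range(len(tokens))]

# Hypothetical decoder input: three real tokens followed by one PAD.
tokens = [1, 8, 9, PAD]
for row in decoder_mask(tokens):
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 0]
```

Row `i` lists the keys that query position `i` may attend to. The last column is zero in every row because the final token is padding.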
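Both masks can be built in one helper:

    def create_masks(enc_input, dec_input):
        # Encoder mask (padding only): (batch, 1, 1, src_len)
        enc_mask = (enc_input != PAD).unsqueeze(1).unsqueeze(2)
        # Decoder mask (padding + future): (batch, 1, tgt_len, tgt_len)
        dec_pad_mask = (dec_input != PAD).unsqueeze(1).unsqueeze(2)
        seq_len = dec_input.size(1)
        dec_seq_mask = torch.tril(torch.ones(seq_len, seq_len, device=dec_input.device)).bool()
        dec_mask = dec_pad_mask & dec_seq_mask
        # Dimension 1 is a singleton head axis, so both masks broadcast over all heads
        return enc_mask, dec_mask

4. Model training and debugging tips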
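4.1 Learning-rate schedule
The Transformer uses a learning-rate schedule with warmup:

    class TransformerOptimizer:
        def __init__(self, optimizer, d_model, warmup_steps=4000):
            self.optimizer = optimizer
            self.d_model = d_model
            self.warmup_steps = warmup_steps
            self.current_step = 0

        def zero_grad(self):
            # Forward to the wrapped optimizer so the wrapper can be used like one
            self.optimizer.zero_grad()

        def step(self):
            self.current_step += 1
            # lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
            lr = self.d_model ** -0.5 * min(self.current_step ** -0.5,
                                            self.current_step * self.warmup_steps ** -1.5)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            self.optimizer.step()

Reading the learning-rate curve:
- Early: the learning rate grows linearly during warmup, which avoids a cold start.
- Middle: after `warmup_steps`, it decays with the inverse square root of the step count.
- Late: the decay flattens out, so training continues at a small, slowly shrinking learning rate.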
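To make the three phases concrete, the schedule can be evaluated at a few steps. The function below restates the formula from `TransformerOptimizer.step` as a standalone helper (the name `transformer_lr` is chosen for this sketch), using the same defaults of `d_model=512` and `warmup_steps=4000`.

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    """Learning rate at a given step (step >= 1)."""
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for step in (1, 1000, 4000, 16000, 100000):
    print(f"step {step:>6}: lr = {transformer_lr(step):.2e}")

# The two branches of min() meet at step == warmup_steps, where the rate peaks.
peak_step = max(range(1, 20001), key=transformer_lr)
print("peak at step", peak_step)
```

With these defaults the peak value is `(512 * 4000) ** -0.5`, roughly `7.0e-4`, and quadrupling the step count after warmup halves the rate.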
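4.2 Gradient clipping and the loss function

    # `model` is the Transformer assembled in Section 5; d_model must match its width
    criterion = nn.CrossEntropyLoss(ignore_index=PAD)  # padding does not contribute to the loss
    optimizer = TransformerOptimizer(
        torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9),
        d_model=512
    )

    def train_step(batch):
        enc_input, dec_input, dec_output = batch
        enc_mask, dec_mask = create_masks(enc_input, dec_input)
        pred = model(enc_input, dec_input, enc_mask, dec_mask)  # (batch, tgt_len, vocab)
        # Flatten to (batch * tgt_len, vocab) and (batch * tgt_len,) for the loss
        loss = criterion(pred.reshape(-1, pred.size(-1)), dec_output.reshape(-1))
        model.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()  # sets the scheduled learning rate, then updates the weights
        return loss.item()

Troubleshooting common problems:
- NaN loss: check that the attention scores are masked correctly before the softmax.
- Exploding gradients: lower the learning rate or clip gradients more aggressively.
- Underfitting: make the model deeper, or re-check the data preprocessing.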
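One way masking can produce NaN, offered here as a plausible cause rather than the only one: if a row of scores is masked everywhere with `-inf`, the softmax computes `-inf - (-inf)`, which is NaN. Filling with a large finite value such as `-1e9`, as `scaled_dot_product_attention` does, avoids this. The sketch uses a hand-written softmax on plain lists so the arithmetic is visible.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# A query row whose keys are ALL masked, e.g. a row made up entirely of padding.
all_masked_inf = [float("-inf")] * 4
all_masked_big = [-1e9] * 4

# -inf - (-inf) is NaN, so every weight in the row becomes NaN.
print(softmax(all_masked_inf))
# A large finite fill value degrades to uniform weights instead.
print(softmax(all_masked_big))
# A partially masked row is fine either way: masked keys get weight 0.
partial = softmax([0.5, 1.5, float("-inf"), float("-inf")])
print(partial)
```

A single NaN in the attention weights spreads through every later layer and into the loss, so when the loss turns NaN it is worth checking whether any row of the mask is entirely zero.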
5. Assembling the full model
5.1 Encoder

    class EncoderLayer(nn.Module):
        def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
            super().__init__()
            self.self_attn = MultiHeadAttention(d_model, n_heads)
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Linear(d_ff, d_model)
            )
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, mask):
            # Self-attention sub-layer: residual connection, then LayerNorm
            attn_output = self.self_attn(x, x, x, mask)
            x = self.norm1(x + self.dropout(attn_output))
            # Position-wise feed-forward sub-layer, same residual pattern
            ffn_output = self.ffn(x)
            return self.norm2(x + self.dropout(ffn_output))

    class Encoder(nn.Module):
        def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, dropout=0.1):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, d_model)
            self.pos_encoding = PositionalEncoding(d_model)
            self.layers = nn.ModuleList([
                EncoderLayer(d_model, n_heads, d_ff, dropout)
                for _ in range(n_layers)
            ])

        def forward(self, x, mask):
            # Token ids → embeddings + positional encoding → stacked encoder layers
            x = self.pos_encoding(self.embedding(x))
            for layer in self.layers:
                x = layer(x, mask)
            return x

5.2 Decoder
    class DecoderLayer(nn.Module):
        def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
            super().__init__()
            self.self_attn = MultiHeadAttention(d_model, n_heads)
            self.enc_attn = MultiHeadAttention(d_model, n_heads)
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Linear(d_ff, d_model)
            )
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.norm3 = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, enc_output, src_mask, tgt_mask):
            # Self-attention (tgt_mask hides padding and future positions)
            attn_output = self.self_attn(x, x, x, tgt_mask)
            x = self.norm1(x + self.dropout(attn_output))
            # Encoder-decoder attention: queries from the decoder, keys/values from the encoder
            attn_output = self.enc_attn(x, enc_output, enc_output, src_mask)
            x = self.norm2(x + self.dropout(attn_output))
            ffn_output = self.ffn(x)
            return self.norm3(x + self.dropout(ffn_output))

    class Decoder(nn.Module):
        def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, dropout=0.1):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, d_model)
            self.pos_encoding = PositionalEncoding(d_model)
            self.layers = nn.ModuleList([
                DecoderLayer(d_model, n_heads, d_ff, dropout)
                for _ in range(n_layers)
            ])

        def forward(self, x, enc_output, src_mask, tgt_mask):
            x = self.pos_encoding(self.embedding(x))
            for layer in self.layers:
                x = layer(x, enc_output, src_mask, tgt_mask)
            return x

5.3 The complete Transformer
    class Transformer(nn.Module):
        def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512,
                     n_layers=6, n_heads=8, d_ff=2048, dropout=0.1):
            super().__init__()
            self.encoder = Encoder(src_vocab_size, d_model, n_layers, n_heads, d_ff, dropout)
            self.decoder = Decoder(tgt_vocab_size, d_model, n_layers, n_heads, d_ff, dropout)
            self.fc = nn.Linear(d_model, tgt_vocab_size)  # projects to vocabulary logits

        def forward(self, src, tgt, src_mask, tgt_mask):
            enc_output = self.encoder(src, src_mask)
            dec_output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
            return self.fc(dec_output)  # (batch, tgt_len, tgt_vocab_size)

6. Model deployment and inference
6.1 Greedy decoding
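    def greedy_decode(model, src, src_mask, max_len=20, start_symbol=SOS):
        # Decodes one source sentence: src has shape (1, src_len)
        model.eval()  # disable dropout during inference
        with torch.no_grad():
            memory = model.encoder(src, src_mask)
            ys = torch.full((1, 1), start_symbol, dtype=src.dtype, device=src.device)
            for _ in range(max_len - 1):
                # Reuse create_masks to build the look-ahead mask for the tokens so far
                _, tgt_mask = create_masks(src, ys)
                out = model.decoder(ys, memory, src_mask, tgt_mask)
                logits = model.fc(out[:, -1])      # scores for the next token only
                next_word = logits.argmax(dim=-1)  # shape (1,)
                ys = torch.cat([ys, next_word.unsqueeze(0)], dim=1)
                if next_word.item() == EOS:
                    break
        return ys

6.2 Visualizing attention weights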
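    import matplotlib.pyplot as plt

    def plot_attention(attention, src_sentence, tgt_sentence):
        # attention: 2-D tensor (tgt_len, src_len) for one head of one sentence
        src_tokens, tgt_tokens = src_sentence.split(), tgt_sentence.split()
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111)
        cax = ax.matshow(attention.detach().cpu().numpy(), cmap='bone')
        fig.colorbar(cax)
        # Pin one tick per token so the labels line up with the matrix cells
        ax.set_xticks(range(len(src_tokens)))
        ax.set_yticks(range(len(tgt_tokens)))
        ax.set_xticklabels(src_tokens, rotation=90)
        ax.set_yticklabels(tgt_tokens)
        return fig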
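To read the result of `greedy_decode`, the token ids have to be mapped back to words. The helper below is a sketch, not part of the original pipeline: `ids_to_sentence` and the `'<unk>'` fallback are names invented here, and the English vocabulary is copied from the toy dataset so the block runs on its own.

```python
EOS = 2  # end-of-sentence id, as in the toy vocabulary

# English side of the toy vocabulary from the dataset section.
en_vocab = {'P': 0, 'S': 1, 'E': 2, 'i': 8, 'want': 9, 'a': 10, 'beer': 11, 'coke': 12}
id_to_word = {idx: word for word, idx in en_vocab.items()}

def ids_to_sentence(ids):
    """Turn decoded ids into text, skipping 'S'/'P' and stopping at 'E'."""
    words = []
    for idx in ids:
        if idx == EOS:
            break
        word = id_to_word.get(idx, '<unk>')  # '<unk>' is a hypothetical fallback
        if word not in ('S', 'P'):
            words.append(word)
    return ' '.join(words)

# Hypothetical decoder output, e.g. greedy_decode(...)[0].tolist()
print(ids_to_sentence([1, 8, 9, 10, 11, 2]))  # i want a beer
```

The `'<unk>'` fallback matters here because the vocabulary ids are sparse: an embedding table sized for id 12 also has rows for ids 3 to 7, which the model could in principle emit.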