残差连接 + 层归一化:Transformer的核心“配方”
去掉它们,Transformer可能只是一堆层叠的线性变换
在Transformer架构中,多头自注意力和前馈网络通常是大家关注的明星组件。但真正让训练数十层甚至上百层网络成为可能的,却是两个看似简单的操作:残差连接和层归一化。
本文将深入浅出地解释这对组合是如何工作的,以及为什么它们对Transformer的成功至关重要。
1. 问题的来源:深度网络的训练困境
当神经网络变深时,会出现两个经典问题:
梯度消失/爆炸:反向传播时,梯度逐层相乘,要么趋近于0,要么爆炸性增长。
表示退化:网络难以学习恒等映射(即“这层什么都不做”),导致增加层数反而降低性能。
Transformer(如BERT、GPT)动辄12、24、48层,如果没有特殊设计,根本训练不动。
2. 残差连接(Residual Connection)
2.1 原始公式
残差连接的写法非常简单:
output=Layer(x)+xoutput=Layer(x)+x
其中 xx 是输入,Layer(⋅)Layer(⋅) 可以是自注意力或前馈网络。
2.2 为什么有效?
梯度高速通道:反向传播时,梯度可以直接通过“+x+x”这条路径跳过变换层,避免逐层衰减。
学习残差函数:让网络只需学习输入与输出的差异(残差),而非完整映射。如果最佳映射就是恒等,网络只需把残差部分推为零,这比直接学习恒等容易得多。
缓解退化:即使新增层暂时学不到有用信息,残差连接也能保证性能至少不下降。
2.3 直观类比
想象你在画一幅画,但每一笔都不直接覆盖原图,而是画在一个透明图层上,最后与原始图层叠加。如果你画坏了,原始内容仍在;画好了,效果增强。残差连接就是这种“叠加层”思想。
3. 层归一化(Layer Normalization)
3.1 计算方式
对于单个样本的一个特征向量,层归一化如下计算:
μ=1H∑i=1Hai,σ2=1H∑i=1H(ai−μ)2μ=H1i=1∑Hai,σ2=H1i=1∑H(ai−μ)2a^i=ai−μσ2+ϵ,outputi=γa^i+βa^i=σ2+ϵai−μ,outputi=γa^i+β
其中 HH 是特征维度(例如512或768),γγ 和 ββ 是可学习的缩放与偏移参数。
3.2 与批归一化(Batch Normalization)的区别
| 批归一化 | 层归一化 | |
|---|---|---|
| 归一化维度 | 批次维度 | 特征维度 |
| 对batch size依赖 | 强,小batch不稳定 | 无依赖 |
| 适用场景 | CNN | RNN、Transformer |
Transformer选择层归一化的关键原因:序列长度可变,且不同样本间统计量差异大,不适合共享batch统计信息。
3.3 为什么Transformer需要它?
稳定梯度:把每层输入的分布拉回均值为0、方差为1的范围,避免激活值落入饱和区。
加速收敛:降低对学习率的敏感性,允许更大学习率。
适应不同序列长度:每个样本独立归一化,自然支持变长输入。
4. 经典组合方式:Post-LN vs Pre-LN
在原始Transformer论文中(Vaswani et al., 2017),顺序是:
Output=LayerNorm(x+Sublayer(x))Output=LayerNorm(x+Sublayer(x))
这称为Post-LN(先残差后归一化)。但在深层Transformer(如BERT-large)中,Post-LN容易导致训练不稳定或梯度消失。
现代实践(GPT-2、GPT-3、大多数开源实现)改为Pre-LN:
Output=x+Sublayer(LayerNorm(x))Output=x+Sublayer(LayerNorm(x))
4.1 对比
Pre-LN(先归一化,再变换,最后残差):
残差路径上的信号几乎无缩放,梯度流更顺畅。
对学习率和初始化鲁棒性更强。
训练更深(几十层)时更稳定。
Post-LN(先残差,再归一化):
原始论文设计,对学习率敏感,需要warmup。
深层时容易在初始化阶段梯度爆炸。
结论:现代Transformer几乎都默认使用Pre-LN。
5. 完整代码示例(PyTorch)
python
import torch import torch.nn as nn class TransformerBlock(nn.Module): def __init__(self, d_model, n_heads, dropout=0.1): super().__init__() self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) self.ffn = nn.Sequential( nn.Linear(d_model, 4*d_model), nn.ReLU(), nn.Linear(4*d_model, d_model) ) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x): # Pre-LN 风格 # 1. 自注意力 + 残差 attn_out, _ = self.attention(self.norm1(x), self.norm1(x), self.norm1(x)) x = x + self.dropout(attn_out) # 2. 前馈网络 + 残差 ffn_out = self.ffn(self.norm2(x)) x = x + self.dropout(ffn_out) return x6. 直观理解:配方而非零件
如果把Transformer比作一道菜:
注意力机制= 主料(比如牛肉)
前馈网络= 配菜(比如青椒)
残差连接= 保留原汁原味的“不破坏食材”
层归一化= 每一步调味(让味道均匀)
没有后两者,食材堆叠再多也只是混乱,无法做出稳定、可口的深层网络。
7. 小结
| 组件 | 核心作用 | 解决什么问题 |
|---|---|---|
| 残差连接 | 梯度高速路 + 学习残差 | 梯度消失、网络退化 |
| 层归一化 | 稳定分布、加速收敛 | 内部协变量偏移、训练不稳定 |
它们结合在一起,使Transformer可以轻松扩展到数百层,成为现代大语言模型的基础构建块。下次你阅读Transformer或BERT的代码时,请多留意这两个简单却关键的组件——它们是整座大厦的地基。