用Python手写Self-Attention层:5行代码拆解Transformer核心
当你第一次看到Transformer模型中的Self-Attention公式时,是否被那些矩阵运算和归一化步骤绕晕了?今天我们不谈数学推导,直接动手用Python实现一个完整的Self-Attention层。通过这段可运行的代码,你将亲眼看到:
- 输入序列如何神奇地"自我关注"
- Q、K、V矩阵的实际计算过程
- 注意力权重如何动态分配
- 最终输出如何捕捉上下文关系
1. 环境准备与输入构造
在开始之前,确保你的Python环境已安装NumPy库。我们将用这个科学计算库来处理所有矩阵运算。先构造一个简单的输入序列 - 假设我们有两个词向量,每个向量的维度是4:
import numpy as np # 构造输入序列 (2个token,每个token的embedding维度为4) X = np.array([ [1.0, 0.5, -1.2, 2.3], # 第一个词向量 [0.8, -1.3, 0.6, 1.1] # 第二个词向量 ])为什么选择这样的维度?在实际应用中:
- 词向量维度通常在几十到几百之间(如BERT-base使用768维)
- 序列长度可以是任意数量的token
- 这里的小规模数据方便我们逐步验证计算过程
2. 初始化权重矩阵
Self-Attention的核心是三个可学习的权重矩阵:WQ、WK和WV。它们将原始输入投影到不同的表示空间:
# 初始化权重矩阵 (输入维度4,输出维度3) np.random.seed(42) # 固定随机种子便于复现 WQ = np.random.randn(4, 3) * 0.1 WK = np.random.randn(4, 3) * 0.1 WV = np.random.randn(4, 3) * 0.1 print("WQ:\n", WQ) print("WK:\n", WK) print("WV:\n", WV)这三个矩阵的作用分别是:
| 矩阵 | 功能描述 | 输出形状 |
|---|---|---|
| WQ | 生成查询向量 | (4,3) |
| WK | 生成键向量 | (4,3) |
| WV | 生成值向量 | (4,3) |
在实际训练中,这些权重会通过反向传播不断调整,使模型学会关注输入序列中最重要的部分。
3. 计算Q、K、V矩阵
现在让我们计算查询(Query)、键(Key)和值(Value)矩阵:
# 计算Q、K、V Q = X @ WQ # (2,4) @ (4,3) -> (2,3) K = X @ WK # (2,4) @ (4,3) -> (2,3) V = X @ WV # (2,4) @ (4,3) -> (2,3) print("\nQ矩阵:\n", Q) print("K矩阵:\n", K) print("V矩阵:\n", V)观察这些矩阵,你会发现:
- 每个token现在有三个不同的表示
- Q向量决定"要查找什么"
- K向量决定"如何被查找"
- V向量是实际被加权的信息
4. 注意力得分计算与归一化
接下来计算注意力得分,这是Self-Attention最核心的部分:
# 计算注意力得分 d_k = K.shape[1] # 键向量的维度(3) scores = Q @ K.T / np.sqrt(d_k) # (2,3) @ (3,2) -> (2,2) attention_weights = np.exp(scores) / np.sum(np.exp(scores), axis=1, keepdims=True) print("\n原始注意力得分:\n", scores) print("Softmax后的注意力权重:\n", attention_weights)关键点解析:
Q @ K.T计算每对token之间的相关性√d_k缩放防止梯度消失(当d_k较大时点积结果可能过大)- Softmax确保每行的权重和为1
5. 加权求和生成最终输出
最后一步是用注意力权重对V矩阵进行加权求和:
# 计算加权输出 output = attention_weights @ V # (2,2) @ (2,3) -> (2,3) print("\n最终输出:\n", output)这个输出矩阵的神奇之处在于:
- 每个token的新表示都融合了整个序列的信息
- 权重动态分配,无需人工设定规则
- 模型自动学习哪些token关系更密切
6. 完整代码实现
将上述步骤整合成一个完整的Self-Attention函数:
def self_attention(X, WQ, WK, WV): # 计算Q,K,V Q = X @ WQ K = X @ WK V = X @ WV # 计算注意力得分 d_k = K.shape[1] scores = Q @ K.T / np.sqrt(d_k) weights = np.exp(scores) / np.sum(np.exp(scores), axis=1, keepdims=True) # 加权求和 return weights @ V # 测试函数 output = self_attention(X, WQ, WK, WV) print("完整Self-Attention输出:\n", output)7. 多头注意力机制简介
单头注意力有时难以捕捉复杂的上下文关系。多头注意力通过并行运行多组注意力机制来增强模型能力:
- 定义多组WQ、WK、WV
- 分别计算每组注意力
- 拼接所有结果并通过线性变换合并
# 假设我们有两组注意力头 WQ1, WK1, WV1 = np.random.randn(4,3), np.random.randn(4,3), np.random.randn(4,3) WQ2, WK2, WV2 = np.random.randn(4,3), np.random.randn(4,3), np.random.randn(4,3) # 计算两组注意力 head1 = self_attention(X, WQ1, WK1, WV1) head2 = self_attention(X, WQ2, WK2, WV2) # 拼接结果 multi_head_output = np.concatenate([head1, head2], axis=1) print("\n多头注意力输出(拼接后):\n", multi_head_output)8. 实际应用中的优化技巧
在真实场景中,Self-Attention还会加入以下改进:
- 位置编码:为输入序列添加位置信息
- 残差连接:缓解梯度消失问题
- 层归一化:稳定训练过程
- 掩码机制:处理变长序列
例如,Transformer中的完整注意力层可能这样实现:
class MultiHeadAttention: def __init__(self, d_model, num_heads): self.d_model = d_model self.num_heads = num_heads self.depth = d_model // num_heads # 初始化权重矩阵 self.WQ = np.random.randn(d_model, d_model) self.WK = np.random.randn(d_model, d_model) self.WV = np.random.randn(d_model, d_model) self.WO = np.random.randn(d_model, d_model) def split_heads(self, x): # 将输入拆分为多头 return x.reshape(x.shape[0], self.num_heads, -1) def call(self, X): Q = X @ self.WQ K = X @ self.WK V = X @ self.WV # 拆分多头 Q = self.split_heads(Q) K = self.split_heads(K) V = self.split_heads(V) # 计算缩放点积注意力 scores = Q @ K.transpose(0,2,1) / np.sqrt(self.depth) weights = np.exp(scores) / np.sum(np.exp(scores), axis=-1, keepdims=True) output = weights @ V # 合并多头并线性变换 output = output.reshape(X.shape[0], -1) return output @ self.WO9. 调试与可视化技巧
为了更好地理解Self-Attention的工作原理,可以尝试以下方法:
- 打印中间结果:观察Q、K、V矩阵的变化
- 可视化注意力权重:用热图显示token间的关系
- 修改输入序列:测试不同输入对注意力的影响
- 梯度检查:验证反向传播的正确性
例如,用matplotlib可视化注意力权重:
import matplotlib.pyplot as plt plt.imshow(attention_weights, cmap='viridis') plt.colorbar() plt.xlabel("Key Positions") plt.ylabel("Query Positions") plt.title("Attention Weights Heatmap") plt.show()10. 从理论到实践的思考
通过这个实现,我深刻体会到Self-Attention的几个关键特性:
- 动态权重分配:不像RNN有固定的计算路径
- 并行计算能力:所有token同时处理
- 长距离依赖:直接建模任意距离的关系
- 可解释性:注意力权重显示模型关注点
在实际项目中,我发现调整这些参数会影响模型表现:
- 维度大小:太小导致信息瓶颈,太大增加计算量
- 头数选择:需要平衡多样性和计算开销
- 缩放因子:对训练稳定性至关重要