别再死记硬背公式了!用Python手写一个Self-Attention层,带你彻底搞懂Transformer核心
2026/5/31 3:09:00 网站建设 项目流程

从零实现Self-Attention:用Python揭开Transformer核心机制的神秘面纱

当我在第一次接触Transformer模型时,那些复杂的矩阵运算和注意力权重图让我望而生畏。直到有一天,我决定亲手用代码实现一个Self-Attention层,那些抽象的概念突然变得清晰可见。本文将带你体验这段认知突破的旅程,通过Python代码实现,让你真正理解Self-Attention的内在机制。

1. 环境准备与基础概念

在开始编码之前,我们需要明确几个关键概念。Self-Attention机制的核心在于让模型能够动态地关注输入序列中不同位置的信息,而不是像RNN那样固定地处理序列。

首先创建一个新的Python环境并安装必要依赖:

conda create -n transformer python=3.8 conda activate transformer pip install torch numpy matplotlib

Self-Attention涉及三个核心矩阵:

  • Query(Q):表示当前需要关注的内容
  • Key(K):表示可供关注的内容
  • Value(V):实际被提取的信息

这三个矩阵都来自同一个输入,通过不同的权重矩阵变换得到。这种设计使得模型能够灵活地建立输入序列内部各元素间的关系。

2. 单头注意力实现

让我们从最基本的单头注意力开始。创建一个新的Python文件self_attention.py,首先实现核心的缩放点积注意力:

import torch import torch.nn as nn import torch.nn.functional as F class ScaledDotProductAttention(nn.Module): def __init__(self, temperature, attn_dropout=0.1): super().__init__() self.temperature = temperature self.dropout = nn.Dropout(attn_dropout) def forward(self, q, k, v, mask=None): # q, k, v的形状: [batch_size, seq_len, d_k] attn = torch.matmul(q, k.transpose(1, 2)) / self.temperature if mask is not None: attn = attn.masked_fill(mask == 0, -1e9) attn = self.dropout(F.softmax(attn, dim=-1)) output = torch.matmul(attn, v) return output, attn

这段代码实现了注意力机制的核心计算:

  1. 计算Q和K的点积,得到原始注意力分数
  2. 用温度参数(√d_k)缩放这些分数
  3. 应用softmax归一化得到注意力权重
  4. 用这些权重对V进行加权求和

温度参数的作用是防止点积结果过大导致softmax进入梯度饱和区。我们可以通过一个简单的例子来验证这个实现:

d_k = 64 # 假设维度为64 attn = ScaledDotProductAttention(temperature=d_k**0.5) # 生成随机输入 (batch_size=1, seq_len=5, d_k=64) q = torch.randn(1, 5, d_k) k = torch.randn(1, 5, d_k) v = torch.randn(1, 5, d_k) output, attn_weights = attn(q, k, v) print(f"注意力权重形状: {attn_weights.shape}") print(f"输出形状: {output.shape}")

3. 完整Self-Attention层实现

现在我们将上面的核心注意力机制包装成一个完整的Self-Attention层:

class SelfAttention(nn.Module): def __init__(self, d_model, d_k, d_v, dropout=0.1): super().__init__() self.w_qs = nn.Linear(d_model, d_k, bias=False) self.w_ks = nn.Linear(d_model, d_k, bias=False) self.w_vs = nn.Linear(d_model, d_v, bias=False) self.attention = ScaledDotProductAttention(temperature=d_k**0.5) self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(d_model) def forward(self, x, mask=None): d_k, d_v = self.w_qs.out_features, self.w_vs.out_features batch_size, seq_len, _ = x.size() # 保存残差连接 residual = x # 计算Q, K, V q = self.w_qs(x) k = self.w_ks(x) v = self.w_vs(x) # 通过注意力机制 x, attn = self.attention(q, k, v, mask=mask) x = self.dropout(x) # 残差连接和层归一化 x += residual x = self.layer_norm(x) return x, attn

这个实现包含了几个关键设计:

  1. 三个独立的线性变换层分别生成Q、K、V
  2. 缩放点积注意力机制
  3. 残差连接和层归一化,这是Transformer架构稳定训练的关键

我们可以这样测试这个完整的Self-Attention层:

d_model = 512 # 模型维度 d_k = d_v = 64 # 通常key和value维度相同 sa = SelfAttention(d_model, d_k, d_v) x = torch.randn(1, 10, d_model) # batch_size=1, seq_len=10, d_model=512 output, attn = sa(x) print(f"输入形状: {x.shape}") print(f"输出形状: {output.shape}") print(f"注意力矩阵形状: {attn.shape}")

4. 多头注意力机制

单一注意力头只能学习到一种关注模式,多头注意力允许模型同时关注来自不同位置的不同表示子空间的信息。下面是多头注意力的实现:

class MultiHeadAttention(nn.Module): def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): super().__init__() self.n_head = n_head self.d_k = d_k self.d_v = d_v # 确保d_model可以被n_head整除 assert d_model % n_head == 0 self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) self.fc = nn.Linear(n_head * d_v, d_model, bias=False) self.attention = ScaledDotProductAttention(temperature=d_k**0.5) self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(d_model) def forward(self, x, mask=None): d_k, d_v, n_head = self.d_k, self.d_v, self.n_head batch_size, seq_len, _ = x.size() residual = x # 通过线性层并分割成多头 q = self.w_qs(x).view(batch_size, seq_len, n_head, d_k) k = self.w_ks(x).view(batch_size, seq_len, n_head, d_k) v = self.w_vs(x).view(batch_size, seq_len, n_head, d_v) # 转置以获得形状 [batch_size, n_head, seq_len, d_k/d_v] q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) if mask is not None: mask = mask.unsqueeze(1) # 为头维度添加维度 # 通过注意力机制 x, attn = self.attention(q, k, v, mask=mask) # 转置回 [batch_size, seq_len, n_head, d_v] x = x.transpose(1, 2).contiguous() x = x.view(batch_size, seq_len, -1) # 合并最后两个维度 # 通过最终的线性层 x = self.dropout(self.fc(x)) x += residual x = self.layer_norm(x) return x, attn

多头注意力的关键步骤:

  1. 将Q、K、V通过更大的线性层投影到n_head * d_k/v维度
  2. 将结果分割成n_head个头
  3. 每个头独立计算注意力
  4. 将结果拼接并通过最终线性层

测试多头注意力:

n_head = 8 d_model = 512 d_k = d_v = 64 mha = MultiHeadAttention(n_head, d_model, d_k, d_v) x = torch.randn(1, 10, d_model) # batch_size=1, seq_len=10, d_model=512 output, attn = mha(x) print(f"输入形状: {x.shape}") print(f"输出形状: {output.shape}") print(f"注意力矩阵形状: {attn.shape}") # 应为 [1, 8, 10, 10]

5. 位置编码实现

由于Self-Attention不包含任何顺序信息,我们需要添加位置编码来注入序列的位置信息。以下是Transformer原论文中的正弦位置编码实现:

class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() position = torch.arange(max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe = torch.zeros(max_len, d_model) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) # 添加batch维度 self.register_buffer('pe', pe) def forward(self, x): # x的形状: [batch_size, seq_len, d_model] x = x + self.pe[:, :x.size(1)] return x

位置编码的关键特性:

  • 每个位置有唯一的编码
  • 编码是确定性的而非学习得到的
  • 可以处理比训练时更长的序列
  • 相对位置信息可以通过线性变换表示

我们可以可视化位置编码来理解其模式:

import matplotlib.pyplot as plt d_model = 512 max_len = 100 pe = PositionalEncoding(d_model, max_len) # 创建虚拟输入 x = torch.zeros(1, max_len, d_model) x_pe = pe(x) plt.figure(figsize=(12, 6)) plt.imshow(x_pe[0], cmap='hot', aspect='auto') plt.colorbar() plt.title("位置编码热图") plt.xlabel("维度") plt.ylabel("位置") plt.show()

6. 完整Self-Attention层的应用示例

现在我们将所有组件组合起来,展示如何在实践中使用Self-Attention层。以下是一个简单的文本处理示例:

# 假设我们有一些文本数据 texts = ["这是一个Self-Attention的实现示例", "我们将展示如何计算注意力权重"] # 简单的词汇表和嵌入层 vocab = {word: idx for idx, word in enumerate(set(" ".join(texts).split()))} vocab_size = len(vocab) d_model = 64 # 创建嵌入层 embedding = nn.Embedding(vocab_size, d_model) # 将文本转换为索引 inputs = [] for text in texts: words = text.split() indices = [vocab[word] for word in words] inputs.append(indices) # 填充序列到相同长度 max_len = max(len(seq) for seq in inputs) padded_inputs = [seq + [0]*(max_len - len(seq)) for seq in inputs] input_tensor = torch.tensor(padded_inputs) # 获取词嵌入 embeddings = embedding(input_tensor) # [batch_size, seq_len, d_model] # 添加位置编码 pe = PositionalEncoding(d_model) embeddings = pe(embeddings) # 通过Self-Attention层 sa = SelfAttention(d_model, d_k=32, d_v=32) output, attn_weights = sa(embeddings) print("输入序列形状:", embeddings.shape) print("输出序列形状:", output.shape) print("注意力权重形状:", attn_weights.shape)

这个示例展示了从原始文本到Self-Attention输出的完整流程。在实际应用中,你可能会使用更复杂的嵌入方法(如BERT)和更大的模型。

7. 注意力机制的可视化与分析

理解注意力权重是掌握Self-Attention机制的关键。让我们可视化前面示例中的注意力权重:

import seaborn as sns # 获取第一个样本的第一个头的注意力权重 attn_matrix = attn_weights[0].detach().numpy() # 获取对应的单词 words = texts[0].split() + [""]*(max_len - len(texts[0].split())) plt.figure(figsize=(10, 8)) sns.heatmap(attn_matrix, xticklabels=words, yticklabels=words, cmap="YlGnBu") plt.title("注意力权重可视化") plt.show()

通过分析注意力权重,我们可以发现:

  • 某些词对自身的注意力最强(对角线元素)
  • 语义相关的词之间会有较强的注意力连接
  • 不同头可能学习到不同的关注模式

在实际项目中,这种可视化是调试和理解模型行为的重要工具。例如,如果你发现模型总是忽略某些关键信息,可能需要调整注意力机制或添加额外的监督信号。

8. 性能优化与实用技巧

在实现和生产环境中使用Self-Attention时,有几个重要的性能考虑因素:

1. 计算复杂度优化

原始Self-Attention的计算复杂度是O(n²),对于长序列这会成为瓶颈。以下是一些优化策略:

# 内存高效的注意力实现 def memory_efficient_attention(q, k, v): # 分块计算注意力 chunk_size = 128 # 根据GPU内存调整 scores = torch.einsum('bhid,bhjd->bhij', q, k) scores = scores / (k.size(-1) ** 0.5) attn = torch.softmax(scores, dim=-1) output = torch.einsum('bhij,bhjd->bhid', attn, v) return output

2. 混合精度训练

使用混合精度可以显著减少内存占用并加速训练:

from torch.cuda.amp import autocast mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64).cuda() optimizer = torch.optim.Adam(mha.parameters()) with autocast(): output, attn = mha(x.cuda()) loss = output.mean() optimizer.step()

3. 关键超参数选择

参数推荐值说明
d_model512模型维度,通常选择2的幂次
n_head8注意力头数,d_model应能被n_head整除
d_k, d_v64每个头的维度,通常d_k=d_v=d_model/n_head
dropout0.1用于注意力权重和输出的dropout率

4. 批处理技巧

对于变长序列,使用填充和掩码:

from torch.nn.utils.rnn import pad_sequence # 创建变长序列 sequences = [torch.randn(3, d_model), torch.randn(5, d_model), torch.randn(2, d_model)] # 填充序列 padded = pad_sequence(sequences, batch_first=True) # 创建掩码 mask = (padded != 0).all(dim=-1).unsqueeze(1).unsqueeze(2) output, attn = mha(padded, mask=mask)

9. 常见问题与调试技巧

在实现和使用Self-Attention时,可能会遇到以下常见问题:

问题1:注意力权重过于均匀或过于集中

解决方案

  • 检查温度参数是否正确实现
  • 尝试调整初始化方式
  • 添加小的随机噪声打破对称性
# 添加噪声的注意力计算 attn = torch.matmul(q, k.transpose(-2, -1)) / (d_k**0.5) attn = attn + torch.randn_like(attn) * 0.01 # 添加小噪声 attn = F.softmax(attn, dim=-1)

问题2:梯度消失或爆炸

解决方案

  • 确保使用了残差连接和层归一化
  • 监控梯度范数
  • 使用梯度裁剪
# 梯度裁剪示例 optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step()

问题3:长序列处理效率低下

解决方案

  • 考虑使用稀疏注意力或局部注意力
  • 尝试线性注意力变体
  • 使用内存高效的实现
# 局部注意力实现示例 def local_attention(q, k, v, window_size=32): seq_len = q.size(1) output = torch.zeros_like(v) for i in range(0, seq_len, window_size): start = max(0, i - window_size//2) end = min(seq_len, i + window_size//2) # 计算局部注意力 attn = torch.matmul(q[:, i:i+1], k[:, start:end].transpose(-2, -1)) attn = F.softmax(attn / (q.size(-1)**0.5), dim=-1) output[:, i:i+1] = torch.matmul(attn, v[:, start:end]) return output

10. 扩展应用与进阶思考

Self-Attention机制的应用远不止于Transformer模型。以下是一些值得探索的扩展方向:

1. 计算机视觉中的应用

class VisionSelfAttention(nn.Module): """适用于图像的Self-Attention实现""" def __init__(self, in_channels): super().__init__() self.query = nn.Conv2d(in_channels, in_channels//8, 1) self.key = nn.Conv2d(in_channels, in_channels//8, 1) self.value = nn.Conv2d(in_channels, in_channels, 1) self.gamma = nn.Parameter(torch.zeros(1)) def forward(self, x): batch_size, C, H, W = x.size() # 投影到query, key, value空间 q = self.query(x).view(batch_size, -1, H*W).permute(0, 2, 1) k = self.key(x).view(batch_size, -1, H*W) v = self.value(x).view(batch_size, -1, H*W) # 计算注意力 attn = torch.bmm(q, k) # [batch_size, H*W, H*W] attn = F.softmax(attn, dim=-1) # 应用注意力 out = torch.bmm(v, attn.permute(0, 2, 1)) out = out.view(batch_size, C, H, W) return self.gamma * out + x

2. 图数据处理

Self-Attention可以自然地应用于图数据,其中每个节点可以与图中的所有其他节点交互:

class GraphSelfAttention(nn.Module): """图数据的Self-Attention实现""" def __init__(self, node_dim): super().__init__() self.node_dim = node_dim self.q_proj = nn.Linear(node_dim, node_dim) self.k_proj = nn.Linear(node_dim, node_dim) self.v_proj = nn.Linear(node_dim, node_dim) def forward(self, nodes, adj_matrix=None): # nodes: [batch_size, num_nodes, node_dim] q = self.q_proj(nodes) k = self.k_proj(nodes) v = self.v_proj(nodes) # 计算注意力分数 attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.node_dim**0.5) # 如果提供了邻接矩阵,可以用它来mask注意力 if adj_matrix is not None: attn_scores = attn_scores.masked_fill(adj_matrix == 0, -1e9) attn_weights = F.softmax(attn_scores, dim=-1) output = torch.matmul(attn_weights, v) return output, attn_weights

3. 跨模态应用

Self-Attention特别适合处理多模态数据,例如同时处理图像和文本:

class CrossModalAttention(nn.Module): """跨模态注意力实现""" def __init__(self, dim1, dim2): super().__init__() self.dim1 = dim1 self.dim2 = dim2 self.q_proj = nn.Linear(dim1, dim2) self.k_proj = nn.Linear(dim2, dim2) self.v_proj = nn.Linear(dim2, dim2) def forward(self, modality1, modality2): # modality1: [batch_size, len1, dim1] # modality2: [batch_size, len2, dim2] q = self.q_proj(modality1) # 投影到dim2空间 k = self.k_proj(modality2) v = self.v_proj(modality2) attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.dim2**0.5) attn_weights = F.softmax(attn_scores, dim=-1) output = torch.matmul(attn_weights, v) return output, attn_weights

在实现这些扩展应用时,关键是要理解Self-Attention的核心思想:通过可学习的、数据驱动的权重来决定不同部分信息的重要性,而不是依赖于固定的架构假设。这种灵活性正是Self-Attention机制强大之处。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询