🚀 30+款热门AI模型一站整合,DeepSeek/GLM/Qwen 随心用,限时 5 折。 👉 点击领海量免费额度
1. 背景与核心概念
Transformer 架构自 2017 年由 Google 团队在论文《Attention Is All You Need》中提出以来,已经彻底改变了自然语言处理乃至整个深度学习领域。它摒弃了传统的循环神经网络和卷积神经网络在处理序列数据时的固有缺陷,通过一种全新的、完全基于注意力机制的架构,实现了前所未有的并行化能力和对长距离依赖关系的强大建模能力。如今,从 ChatGPT、GPT-4 这样的对话模型,到 BERT、T5 这样的理解模型,再到 Stable Diffusion、Sora 这样的多模态生成模型,其核心都离不开 Transformer。
对于许多开发者而言,Transformer 的原理常常被其复杂的数学公式和架构图所掩盖,感觉“高深莫测”。实际上,其核心思想非常直观。本文将彻底拆解 Transformer 的每一个组件,从最基础的注意力机制开始,逐步构建起完整的编码器-解码器架构,并结合代码示例和实际应用场景,让你不仅能理解其原理,更能掌握其实现细节。无论你是希望深入理解大模型背后的技术,还是计划在自己的项目中应用 Transformer 架构,这篇文章都将为你提供一个清晰、系统、可实践的指南。
2. 从序列建模的困境到注意力机制
在 Transformer 出现之前,序列建模(如机器翻译、文本生成)的主流是循环神经网络及其变体 LSTM 和 GRU。这些模型按顺序处理输入序列,将之前步骤的信息保存在一个“隐藏状态”中。然而,这种顺序处理方式存在两个根本性瓶颈:
- 并行化困难:由于每一步的计算都依赖于上一步的隐藏状态,模型无法充分利用现代 GPU 的并行计算能力,训练速度慢。
- 长距离依赖遗忘:对于较长的序列,早期输入的信息在传递过程中会逐渐衰减或丢失,模型难以捕捉序列开头和结尾之间的关联。
注意力机制的引入是解决第二个问题的关键一步。它允许模型在生成输出序列的每一个词时,直接“查看”输入序列中的所有词,并动态地为每个输入词分配一个“关注度”权重。这就像人在翻译句子时,会反复回看原文的不同部分一样。
最初的注意力机制被用在基于 RNN 的编码器-解码器模型中,但它仍然依赖于 RNN 来生成编码表示,因此第一个瓶颈(并行化困难)依然存在。
Transformer 的革命性在于:它完全抛弃了循环结构,仅使用注意力机制来构建整个模型。这使得模型可以一次性处理整个输入序列,所有词之间的关联计算都可以并行进行,极大地提升了训练效率。
3. Transformer 核心组件详解
一个标准的 Transformer 模型主要由编码器和解码器堆叠而成。我们先来深入理解构成它们的基础模块。
3.1 输入表示:词嵌入与位置编码
Transformer 的输入是一系列词元。首先,每个词元通过一个词嵌入层被映射为一个高维向量。这个向量捕获了词义的语义信息。
然而,自注意力机制本身是对顺序不敏感的。对于句子 “狗咬人” 和 “人咬狗”,如果不提供位置信息,模型会认为它们是相同的。因此,我们必须注入位置信息。
位置编码为序列中每个位置的词元嵌入向量添加一个独特的向量。原论文使用了一种基于正弦和余弦函数的固定编码方式:
对于位置pos和维度i,其计算公式为:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))其中d_model是嵌入向量的维度。
这种编码方式的优点是:它能产生一种有界、平滑的位置表示,并且模型可以轻松学会关注相对位置(因为PE(pos+k)可以表示为PE(pos)的线性函数)。
import numpy as np import torch import torch.nn as nn class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super(PositionalEncoding, self).__init__() # 创建一个足够长的位置编码矩阵 pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度用sin pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度用cos pe = pe.unsqueeze(0) # 形状: (1, max_len, d_model) self.register_buffer('pe', pe) # 将其注册为缓冲区,不参与梯度更新 def forward(self, x): # x 形状: (batch_size, seq_len, d_model) x = x + self.pe[:, :x.size(1)] return x # 示例 d_model = 512 seq_len = 10 batch_size = 2 embedding = torch.randn(batch_size, seq_len, d_model) pos_encoder = PositionalEncoding(d_model) output_with_pos = pos_encoder(embedding) print(f"输入嵌入形状: {embedding.shape}") print(f"加入位置编码后形状: {output_with_pos.shape}")3.2 缩放点积注意力
这是 Transformer 的灵魂。其核心思想是:对于序列中的每一个元素(查询 Query),计算它与序列中所有元素(键 Key)的相关性,然后用这个相关性权重对对应的值(Value)进行加权求和,从而得到一个融合了全局上下文信息的表示。
计算步骤:
- 线性变换:将输入序列通过三个不同的权重矩阵
W_Q,W_K,W_V投影,得到查询矩阵Q、键矩阵K和值矩阵V。 - 计算注意力分数:计算
Q和K的点积,度量查询和键之间的相关性。分数越高,表示相关性越强。 - 缩放:将点积结果除以
sqrt(d_k),其中d_k是键向量的维度。这一步是为了防止点积结果过大,导致经过 softmax 后梯度消失。 - 归一化:对缩放后的分数应用 softmax 函数,将其转化为概率分布(权重和为1)。
- 加权求和:用 softmax 得到的权重对
V进行加权求和,得到最终的输出。
公式表示:
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * Vimport torch import torch.nn.functional as F def scaled_dot_product_attention(query, key, value, mask=None): """ query: 形状 (..., seq_len_q, d_k) key: 形状 (..., seq_len_k, d_k) value: 形状 (..., seq_len_v, d_v) ,通常 seq_len_k == seq_len_v mask: 形状 (..., seq_len_q, seq_len_k) """ d_k = query.size(-1) # 计算点积并缩放 scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32)) if mask is not None: # 将 mask 中为 True 的位置替换为一个非常大的负数,这样 softmax 后权重接近 0 scores = scores.masked_fill(mask == 0, -1e9) # 计算注意力权重 attention_weights = F.softmax(scores, dim=-1) # 加权求和 output = torch.matmul(attention_weights, value) return output, attention_weights # 示例:自注意力 (seq_len_q = seq_len_k = seq_len_v) batch_size = 2 seq_len = 5 d_model = 64 d_k = d_v = d_model query = torch.randn(batch_size, seq_len, d_k) key = torch.randn(batch_size, seq_len, d_k) value = torch.randn(batch_size, seq_len, d_v) output, attn_weights = scaled_dot_product_attention(query, key, value) print(f"注意力输出形状: {output.shape}") # (2, 5, 64) print(f"注意力权重形状: {attn_weights.shape}") # (2, 5, 5) # 对于第0个样本,第0个词元对其他所有词元的注意力权重 print(f"示例注意力权重(第一个样本,第一个词元): {attn_weights[0, 0]}")3.3 多头注意力
单一的注意力机制可能只关注到一种类型的语义关系。为了让模型能够同时关注来自不同表示子空间的信息,Transformer 引入了多头注意力。
其做法是:
- 将
Q,K,V通过h个不同的线性投影矩阵,分别投影到h个更小的空间(d_k,d_v维度)。 - 在每个投影后的子空间上独立执行缩放点积注意力,得到
h个输出。 - 将这
h个输出拼接起来,再通过一个最终的线性投影层W_O融合信息。
公式表示:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O where head_i = Attention(Q * W_Q_i, K * W_K_i, V * W_V_i)import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super(MultiHeadAttention, self).__init__() assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除" self.d_model = d_model self.num_heads = num_heads self.d_k = d_model // num_heads # 定义投影矩阵 self.W_q = nn.Linear(d_model, d_model) # 实际实现中,通常先投影到 d_model self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) def split_heads(self, x): """将输入张量从 (batch_size, seq_len, d_model) 重塑为 (batch_size, num_heads, seq_len, d_k)""" batch_size, seq_len, _ = x.size() return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2) def forward(self, query, key, value, mask=None): batch_size = query.size(0) # 1. 线性投影 Q = self.W_q(query) # (batch_size, seq_len_q, d_model) K = self.W_k(key) # (batch_size, seq_len_k, d_model) V = self.W_v(value) # (batch_size, seq_len_v, d_model) # 2. 分割成多个头 Q = self.split_heads(Q) # (batch_size, num_heads, seq_len_q, d_k) K = self.split_heads(K) # (batch_size, num_heads, seq_len_k, d_k) V = self.split_heads(V) # (batch_size, num_heads, seq_len_v, d_k) # 3. 为每个头计算缩放点积注意力 # 我们需要调整 mask 的维度以匹配多头 if mask is not None: mask = mask.unsqueeze(1) # (batch_size, 1, seq_len_q, seq_len_k) -> 广播到每个头 # 计算注意力,这里调用之前定义的函数,但需要处理多头维度 # 简便起见,我们重塑张量,将 num_heads 视为 batch 维度的一部分 Q_reshaped = Q.transpose(1, 2).contiguous().view(batch_size * self.num_heads, -1, self.d_k) K_reshaped = K.transpose(1, 2).contiguous().view(batch_size * self.num_heads, -1, self.d_k) V_reshaped = V.transpose(1, 2).contiguous().view(batch_size * self.num_heads, -1, self.d_k) if mask is not None: mask_reshaped = mask.repeat(1, self.num_heads, 1, 1).view(batch_size * self.num_heads, -1, mask.size(-1)) attn_output, _ = scaled_dot_product_attention(Q_reshaped, K_reshaped, V_reshaped, mask_reshaped if mask is not None else None) # 4. 合并多头输出 attn_output = attn_output.view(batch_size, self.num_heads, -1, self.d_k).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model) # 5. 最终线性投影 output = self.W_o(attn_output) return output # 示例 d_model = 512 num_heads = 8 mha = MultiHeadAttention(d_model, num_heads) seq_len = 10 batch_size = 4 x = torch.randn(batch_size, seq_len, d_model) # 假设是自注意力,Q=K=V=x output = mha(x, x, x) print(f"多头注意力输入形状: {x.shape}") print(f"多头注意力输出形状: {output.shape}") # 应保持 (4, 10, 512)3.4 前馈网络
每个编码器和解码器层中的注意力子层后面都跟着一个前馈网络。这是一个简单的两层全连接神经网络,通常中间层的维度更大(例如,d_ff = 4 * d_model),并带有 ReLU 激活函数。
它的作用是对每个位置的表示进行独立、相同的非线性变换,增加模型的表达能力。
FFN(x) = max(0, x * W1 + b1) * W2 + b2class PositionwiseFeedForward(nn.Module): def __init__(self, d_model, d_ff, dropout=0.1): super(PositionwiseFeedForward, self).__init__() self.linear1 = nn.Linear(d_model, d_ff) self.linear2 = nn.Linear(d_ff, d_model) self.dropout = nn.Dropout(dropout) self.activation = nn.ReLU() def forward(self, x): return self.linear2(self.dropout(self.activation(self.linear1(x)))) # 示例 d_model = 512 d_ff = 2048 ffn = PositionwiseFeedForward(d_model, d_ff) x = torch.randn(4, 10, d_model) output = ffn(x) print(f"前馈网络输出形状: {output.shape}") # (4, 10, 512)3.5 残差连接与层归一化
为了缓解深层网络中的梯度消失问题,并稳定训练,Transformer 在每个子层(自注意力、前馈网络)周围都使用了残差连接,并在其后进行层归一化。
残差连接:将子层的输入直接加到其输出上,即Output = LayerNorm(x + Sublayer(x))。这确保了信息可以更直接地向前传播。
层归一化:对单个样本的所有特征维度进行归一化,使其均值为0,方差为1。这有助于稳定训练过程,加速收敛。现代实现更常用Pre-LN结构,即先做层归一化,再进入子层:Output = x + Sublayer(LayerNorm(x)),这被证明训练更稳定。
class SublayerConnection(nn.Module): """一个残差连接,后接层归一化。注意为了简化,我们使用 Post-LN。""" def __init__(self, size, dropout): super(SublayerConnection, self).__init__() self.norm = nn.LayerNorm(size) self.dropout = nn.Dropout(dropout) def forward(self, x, sublayer): """应用残差连接到任何与 x 相同形状的子层。""" # Post-LN: LayerNorm(x + Sublayer(x)) return x + self.dropout(sublayer(self.norm(x))) # 如果是 Pre-LN,则应为:return x + self.dropout(sublayer(self.norm(x))) # 注意:Pre-LN 中 norm 在 sublayer 内部调用,这里仅为示意。4. 编码器与解码器架构
4.1 编码器层
一个编码器层由两个主要子层构成:
- 多头自注意力层:输入序列自己对自己做注意力,让每个词元都能关注到序列中所有其他词元,从而获得包含全局上下文的表示。
- 前馈网络层:对每个位置的表示进行独立变换。
每个子层周围都有残差连接和层归一化。
class EncoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff, dropout=0.1): super(EncoderLayer, self).__init__() self.self_attn = MultiHeadAttention(d_model, num_heads) self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) self.sublayer = nn.ModuleList([SublayerConnection(d_model, dropout) for _ in range(2)]) self.size = d_model def forward(self, x, mask): """ x: 输入张量,形状 (batch_size, seq_len, d_model) mask: 用于自注意力的 mask,形状 (batch_size, 1, seq_len, seq_len) 或 (batch_size, seq_len, seq_len) """ # 第一个子层:多头自注意力 (带残差和归一化) x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) # 第二个子层:前馈网络 (带残差和归一化) x = self.sublayer[1](x, self.feed_forward) return x4.2 解码器层
解码器层比编码器层多一个子层,共三个:
- 掩码多头自注意力层:防止解码器在预测当前位置时“偷看”未来的信息。这是通过一个因果掩码实现的,该掩码将注意力权重矩阵右上三角部分(未来位置)设置为负无穷大。
- 编码器-解码器注意力层(交叉注意力):让解码器能够关注编码器的最终输出。其中,查询
Q来自解码器的上一子层输出,而键K和值V来自编码器的输出。 - 前馈网络层:与编码器相同。
class DecoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff, dropout=0.1): super(DecoderLayer, self).__init__() self.self_attn = MultiHeadAttention(d_model, num_heads) self.cross_attn = MultiHeadAttention(d_model, num_heads) self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) self.sublayer = nn.ModuleList([SublayerConnection(d_model, dropout) for _ in range(3)]) self.size = d_model def forward(self, x, encoder_output, src_mask, tgt_mask): """ x: 解码器输入 (或上一层的输出),形状 (batch_size, tgt_seq_len, d_model) encoder_output: 编码器输出,形状 (batch_size, src_seq_len, d_model) src_mask: 源序列 mask (用于编码器-解码器注意力,可选) tgt_mask: 目标序列 mask (用于解码器自注意力,因果掩码) """ # 第一子层:掩码自注意力 x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) # 第二子层:编码器-解码器注意力 # Q 来自解码器,K, V 来自编码器 x = self.sublayer[1](x, lambda x: self.cross_attn(x, encoder_output, encoder_output, src_mask)) # 第三子层:前馈网络 x = self.sublayer[2](x, self.feed_forward) return x4.3 构建完整的 Transformer
现在我们可以将编码器层和解码器层堆叠起来,并加上嵌入层和最后的线性输出层,构建一个完整的 Transformer 模型。
class Transformer(nn.Module): def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8, num_encoder_layers=6, num_decoder_layers=6, d_ff=2048, max_seq_len=5000, dropout=0.1): super(Transformer, self).__init__() self.encoder_embedding = nn.Embedding(src_vocab_size, d_model) self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model) self.positional_encoding = PositionalEncoding(d_model, max_seq_len) self.encoder_layers = nn.ModuleList([ EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_encoder_layers) ]) self.decoder_layers = nn.ModuleList([ DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_decoder_layers) ]) self.final_linear = nn.Linear(d_model, tgt_vocab_size) self.dropout = nn.Dropout(dropout) self.d_model = d_model def generate_square_subsequent_mask(self, sz): """生成因果掩码 (下三角为 True,上三角为 False)""" mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) return mask def forward(self, src, tgt, src_mask=None, tgt_mask=None): """ src: 源语言序列索引,形状 (batch_size, src_len) tgt: 目标语言序列索引,形状 (batch_size, tgt_len) src_mask: 源序列 padding mask (可选) tgt_mask: 目标序列因果掩码 + padding mask """ # 1. 编码器 src_emb = self.dropout(self.positional_encoding(self.encoder_embedding(src) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32)))) memory = src_emb for layer in self.encoder_layers: memory = layer(memory, src_mask) # 2. 解码器 if tgt_mask is None: tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device) tgt_emb = self.dropout(self.positional_encoding(self.decoder_embedding(tgt) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32)))) output = tgt_emb for layer in self.decoder_layers: output = layer(output, memory, src_mask, tgt_mask) # 3. 输出投影 logits = self.final_linear(output) return logits # 示例:定义一个微型 Transformer src_vocab_size = 10000 tgt_vocab_size = 10000 model = Transformer(src_vocab_size, tgt_vocab_size, d_model=128, num_heads=4, num_encoder_layers=2, num_decoder_layers=2, d_ff=512) batch_size = 4 src_len = 10 tgt_len = 12 src = torch.randint(0, src_vocab_size, (batch_size, src_len)) tgt = torch.randint(0, tgt_vocab_size, (batch_size, tgt_len)) logits = model(src, tgt) print(f"模型输出 logits 形状: {logits.shape}") # (batch_size, tgt_len, tgt_vocab_size) # 这代表了在目标序列每个位置,对目标词汇表中所有词的概率预测5. 训练与推理流程
5.1 训练任务:掩码语言建模与自回归语言建模
Transformer 的训练方式决定了其最终用途:
- 编码器-解码器架构(如原始 Transformer、T5):通常用于序列到序列任务,如翻译、摘要。训练时,编码器接收源序列,解码器以自回归方式(使用因果掩码)预测目标序列。
- 仅编码器架构(如 BERT):用于理解任务。采用掩码语言建模:随机遮盖输入序列中的一些词元,让模型根据上下文预测被遮盖的词。
- 仅解码器架构(如 GPT 系列):用于生成任务。采用自回归语言建模:给定前文,预测下一个词。训练时,整个序列作为输入,但使用因果掩码确保预测位置
i时只能看到位置< i的信息。
5.2 推理流程(以自回归生成为例)
仅解码器模型(如 GPT)的推理是一个循环过程:
- 给定一个起始标记(如
<bos>),输入模型。 - 模型输出下一个词的概率分布。
- 根据某种策略(如贪婪搜索、束搜索、采样)从分布中选择一个词。
- 将选中的词追加到输入序列末尾,作为新的输入。
- 重复步骤 2-4,直到生成结束标记(如
<eos>)或达到最大长度。
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol, device): """贪婪解码示例""" model.eval() src = src.to(device) src_mask = src_mask.to(device) # 编码器前向传播 memory = model.encode(src, src_mask) # 初始化解码器输入为起始符号 ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device) for i in range(max_len-1): # 为当前生成的序列生成因果掩码 tgt_mask = model.generate_square_subsequent_mask(ys.size(1)).to(device) # 解码器前向传播 out = model.decode(ys, memory, src_mask, tgt_mask) # 获取最后一个位置的 logits 并预测下一个词 prob = model.generator(out[:, -1]) _, next_word = torch.max(prob, dim=1) next_word = next_word.item() ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1) if next_word == end_symbol: break return ys # 注意:此处的 model.encode, model.decode, model.generator 需要在 Transformer 类中实现相应方法。 # 实际中,我们通常直接调用 model(src, tgt) 并手动管理推理循环。6. 关键变体与优化
原始的 Transformer 架构是基石,后续研究提出了许多重要的变体和优化。
6.1 位置编码的演进
- 相对位置编码:原始正弦编码是绝对位置编码。相对位置编码(如 Transformer-XL、T5 使用的)让模型更关注词元之间的相对距离,而非绝对位置,在处理长文本时泛化能力更强。
- 旋转位置编码:RoPE 将位置信息通过旋转矩阵融入查询和键向量中,在保持相对位置信息的同时,被证明对长上下文扩展更友好,被 LLaMA、GPT-NeoX 等模型广泛采用。
6.2 注意力机制的优化
- 稀疏注意力:计算所有词元对之间的注意力复杂度是
O(n^2),对于长序列开销巨大。稀疏注意力(如 Longformer、BigBird)只计算每个词元与局部窗口内或少数全局词元之间的注意力,将复杂度降低到O(n)或O(n log n)。 - 线性注意力:通过核函数近似将 softmax 注意力转化为线性复杂度,如 Linformer、Performer。
- FlashAttention:一种 IO 感知的精确注意力算法,通过分块计算和重计算,显著减少 GPU 高带宽内存与片上 SRAM 之间的数据移动,极大提升了长序列注意力计算的速度和内存效率,已成为训练大模型的事实标准。
6.3 模型架构变体
- 仅编码器:如 BERT,适用于文本分类、命名实体识别等理解任务。
- 仅解码器:如 GPT 系列,适用于文本生成、代码生成等任务。
- 编码器-解码器:如 T5、BART,适用于翻译、摘要等 seq2seq 任务。
- 前缀语言模型:一种介于仅解码器和编码器-解码器之间的架构,将输入作为前缀,后续部分自回归生成。
6.4 推理优化技术
- KV 缓存:在自回归生成时,键
K和值V对于已经生成的 token 是固定不变的。KV 缓存将这些中间结果存储起来,避免在生成每个新 token 时重复计算,大幅提升推理速度。 - 多查询注意力 / 分组查询注意力:让多个注意力头共享同一套
K和V的投影权重,减少了推理时 KV 缓存的大小,从而支持更长的上下文或更大的批次,对推理速度有显著提升。 - 推测解码:使用一个更小、更快的“草稿模型”先生成多个候选 token,然后用原始大模型一次性并行验证这些候选。如果验证通过,则一次性接受多个 token,从而减少大模型的调用次数,加速生成。
7. 实战:使用 PyTorch 构建一个简化的 GPT 模型
让我们动手实现一个极简版的 GPT(仅解码器)来巩固理解。这个模型将包含词嵌入、位置编码、多个解码器层(带掩码自注意力)和一个输出层。
import torch import torch.nn as nn import torch.nn.functional as F import math class GPTDecoderLayer(nn.Module): """简化的 GPT 解码器层,只有掩码多头自注意力和前馈网络""" def __init__(self, d_model, num_heads, d_ff, dropout=0.1): super().__init__() self.ln1 = nn.LayerNorm(d_model) self.self_attn = MultiHeadAttention(d_model, num_heads) # 复用之前定义的多头注意力 self.dropout1 = nn.Dropout(dropout) self.ln2 = nn.LayerNorm(d_model) self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout) self.dropout2 = nn.Dropout(dropout) def forward(self, x, mask): # Pre-LN 结构 attn_output = self.self_attn(self.ln1(x), self.ln1(x), self.ln1(x), mask) x = x + self.dropout1(attn_output) ffn_output = self.ffn(self.ln2(x)) x = x + self.dropout2(ffn_output) return x class SimpleGPT(nn.Module): """一个极简的 GPT 模型""" def __init__(self, vocab_size, d_model=256, num_heads=8, num_layers=6, d_ff=1024, max_seq_len=512, dropout=0.1): super().__init__() self.token_embedding = nn.Embedding(vocab_size, d_model) self.position_embedding = nn.Embedding(max_seq_len, d_model) # 使用可学习的位置嵌入 self.dropout = nn.Dropout(dropout) self.layers = nn.ModuleList([ GPTDecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers) ]) self.ln_f = nn.LayerNorm(d_model) # 最后的层归一化 self.lm_head = nn.Linear(d_model, vocab_size, bias=False) # 语言模型头 # 权重绑定:语言模型头的权重与词嵌入层共享(常见做法,可减少参数) self.lm_head.weight = self.token_embedding.weight self.max_seq_len = max_seq_len self.d_model = d_model self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, idx, targets=None): """ idx: 输入 token 索引,形状 (batch_size, seq_len) targets: 目标 token 索引,用于计算损失,形状同 idx """ device = idx.device b, t = idx.size() assert t <= self.max_seq_len, f"序列长度 {t} 超过了最大长度 {self.max_seq_len}" # 1. 词嵌入 + 位置嵌入 tok_emb = self.token_embedding(idx) # (b, t, d_model) pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # (1, t) pos_emb = self.position_embedding(pos) # (1, t, d_model) x = self.dropout(tok_emb + pos_emb) # 2. 生成因果掩码 causal_mask = torch.tril(torch.ones(t, t, device=device)).view(1, 1, t, t) # (1, 1, t, t) # 3. 通过所有解码器层 for layer in self.layers: x = layer(x, causal_mask) # 4. 最终层归一化和投影到词汇表 x = self.ln_f(x) logits = self.lm_head(x) # (b, t, vocab_size) loss = None if targets is not None: # 计算交叉熵损失,忽略 padding 等操作此处省略 loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) return logits, loss def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): """自回归生成文本""" self.eval() for _ in range(max_new_tokens): # 如果上下文太长,裁剪到最大长度(一种简单的处理方式) idx_cond = idx if idx.size(1) <= self.max_seq_len else idx[:, -self.max_seq_len:] # 前向传播,获取最后一个时间步的 logits logits, _ = self(idx_cond) logits = logits[:, -1, :] / temperature # (batch_size, vocab_size) # 可选:top-k 采样 if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') # 从分布中采样 probs = F.softmax(logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1) # 将采样结果拼接到序列中 idx = torch.cat((idx, idx_next), dim=1) return idx # 示例:初始化模型并尝试生成 vocab_size = 1000 # 假设词汇表大小 model = SimpleGPT(vocab_size=vocab_size, d_model=128, num_heads=4, num_layers=4, d_ff=512, max_seq_len=256) print(f"模型参数量: {sum(p.numel() for p in model.parameters())/1e6:.2f} M") # 模拟一个 batch 的数据 batch_size = 2 seq_len = 10 input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) logits, loss = model(input_ids) print(f"输入形状: {input_ids.shape}") print(f"输出 logits 形状: {logits.shape}") # (2, 10, 1000) # 尝试生成(由于是随机初始化的模型,输出无意义) start_tokens = torch.randint(0, vocab_size, (batch_size, 1)) generated = model.generate(start_tokens, max_new_tokens=5, temperature=1.0, top_k=50) print(f"生成结果形状: {generated.shape}") print(f"示例生成序列: {generated[0]}")8. 常见问题与排查思路
在理解和实现 Transformer 时,你可能会遇到以下问题:
| 问题现象 | 可能原因 | 解决思路 |
|---|---|---|
| 训练不稳定,损失 NaN | 学习率过高;梯度爆炸;层归一化或残差连接实现有误;激活函数问题。 | 使用学习率预热;使用梯度裁剪;检查 Pre-LN/Post-LN 实现是否正确;尝试 GELU/SiLU 等更平滑的激活函数。 |
| 模型无法收敛或收敛慢 | 学习率不合适;模型初始化不当;数据预处理有问题;任务过于复杂。 | 进行学习率搜索;使用 Xavier/Glorot 或 Kaiming/He 初始化;检查数据标签和分词是否正确;尝试更简单的任务或增加数据。 |
| 推理时生成重复或无意义内容 | 采样策略问题(温度过低导致贪婪,温度过高导致随机);模型训练不足;存在重复性惩罚未设置。 | 调整temperature和top_p/top_k参数;增加训练步数;在生成时加入重复惩罚(repetition_penalty)。 |
| 处理长序列时内存溢出 (OOM) | 注意力矩阵(seq_len, seq_len)过大,消耗O(n^2)内存。 | 使用稀疏注意力(如 Longformer);使用 FlashAttention(如果框架支持);增加梯度检查点;减少批次大小或序列长度。 |
| 位置编码外推性差 | 使用绝对正弦位置编码的模型,在推理时遇到比训练时更长的序列,性能下降。 | 使用相对位置编码(如 RoPE、ALiBi);在训练时使用更长的上下文进行微调。 |
| KV 缓存导致推理错误 | 缓存未正确更新或重置;在生成不同序列时缓存混用。 | 确保在开始生成新序列时清空 KV 缓存;检查缓存张量的形状与当前生成步数是否匹配。 |
9. 最佳实践与工程建议
- 从预训练模型开始:除非有特定研究目的,否则不要从头开始训练大型 Transformer。利用 Hugging Face
transformers库加载 BERT、GPT-2、T5 等预训练模型进行微调,这是最高效的方式。 - 注意计算资源:Transformer 模型,尤其是大模型,对 GPU 显存要求很高。训练时注意使用混合精度训练、梯度累积、模型并行、数据并行等技术来优化资源使用。
- 使用现代库和优化器:优先使用 PyTorch 或 TensorFlow 等成熟框架,并搭配优化器如 AdamW,并配合学习率调度器(如带热身的余弦衰减)。
- 监控训练过程:密切关注训练损失和验证损失曲线,使用 WandB 或 TensorBoard 等工具进行可视化。早停法可以防止过拟合。
- 理解你的数据:Tokenizer 的选择(如 BPE、WordPiece、SentencePiece)对模型性能影响巨大。确保分词方式与你的任务和语言匹配。
- 生产环境部署优化:推理时,利用模型量化、动态批处理、持续批处理、FlashAttention、PagedAttention(vLLM)等技术来降低延迟、提高吞吐量。
- 安全与伦理:Transformer 模型可能生成有偏见、有害或不准确的内容。在部署前,必须进行全面的评估、红队测试,并考虑加入内容过滤和安全层。
Transformer 的原理虽然源于一篇学术论文,但其影响早已遍及工业界的每一个角落。从理解它的核心——自注意力机制开始,到掌握其完整的编码器-解码器架构,再到熟悉各种高效的变体和优化技术,这条学习路径将为你打开通往现代人工智能核心的大门。希望这篇近万字的详解能成为你探索 Transformer 世界的一块坚实基石。动手运行文中的代码,修改参数,观察输出变化,是理解这一切的最佳方式。
🚀 30+款热门AI模型一站整合,DeepSeek/GLM/Qwen 随心用,限时 5 折。 👉 点击领海量免费额度