1. 前言
上一篇我们已经把BERT 预训练数据这部分理顺了,知道了一条完整样本通常会包含:
token_idssegmentsvalid_lenpred_positionsmlm_weightsmlm_labelsnsp_label
到这里,模型也有了,数据也有了,
接下来就到了真正把两者接起来的时候:
BERT 预训练代码
这一节的核心就是把下面这几件事完整串起来:
BERT 模型前向传播
MLM 损失怎么计算
NSP 损失怎么计算
两个损失怎么合并
训练循环怎么写
如果一句话概括这一节的灵魂,那就是:
让 BERT 同时学会“填空”和“判断句子关系”。
2. BERT 预训练到底在训练什么
BERT 预训练不是单一目标,
而是两个任务一起训练:
第一,MLM
让模型预测被遮住的 token。
第二,NSP
让模型判断第二句是不是第一句的真实后续。
所以 BERT 预训练代码的核心,不是“写一个 loss”,
而是:
同时计算两个任务的 loss,再联合优化。
这也是它和很多普通 NLP 训练代码最明显的区别之一。
3. 训练输入通常长什么样
在进入代码前,先把 batch 输入想清楚。
一个 batch 里通常会拿到:
tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y它们分别表示:
tokens_X
真正送进 BERT 的 token id 序列。
segments_X
句子 A/B 的 segment 标记。
valid_lens_x
每条样本的有效长度,用于 mask padding。
pred_positions_X
被选中的 MLM 预测位置。
mlm_weights_X
哪些 MLM 位置真实有效,哪些只是 pad 出来的占位。
mlm_Y
MLM 的真实标签。
nsp_y
NSP 的二分类标签。
所以你会发现,BERT 训练一个 batch 的输入,远比普通分类任务复杂。
4. BERT 前向传播时会输出什么
前面在BERT代码那一节里,我们已经搭好了总模型,
它的前向传播通常会返回:
encoded_X, mlm_Y_hat, nsp_Y_hat这里:
encoded_X
整段输入序列的上下文化表示。
mlm_Y_hat
MLM 任务在 mask 位置上的预测结果。
nsp_Y_hat
NSP 任务的分类预测结果。
而在预训练阶段,我们真正关心的主要就是后两个:
mlm_Y_hatnsp_Y_hat
因为这两个才直接参与 loss 计算。
5. MLM 损失为什么不能直接普通算
因为 MLM 不是对整条序列所有位置都计算损失,
它只对:
被选中的 mask 位置
计算损失。
同时,由于不同样本被 mask 的 token 数量可能不同,
为了 batch 对齐,pred_positions和mlm_labels常常被 pad 到同样长度。
这就意味着:
某些位置是真实 MLM 目标
某些位置只是补齐占位
所以 MLM loss 不能简单一股脑全算,
而必须借助:
mlm_weights来屏蔽无效位置。
6. MLM 损失通常怎么写
李沐这里常见会写一个辅助函数,例如:
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y): _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X, valid_lens_x.reshape(-1), pred_positions_X)这里先调用模型,得到 MLM 和 NSP 的预测输出。
然后再分别计算两个损失。
7. 为什么valid_lens_x.reshape(-1)常出现
因为有时valid_lens_x在 batch 中的形状不是最理想的一维向量,
而模型内部注意力 mask 通常希望拿到的是:
(batch_size,)的一维有效长度。
所以常见会写:
valid_lens_x.reshape(-1)确保它变成一维。
这属于一个很常见的张量整理细节。
8. MLM loss 的核心计算怎么写
继续往下,常见写法类似:
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) mlm_l = (mlm_l * mlm_weights_X.reshape(-1, 1)).sum() / (mlm_weights_X.sum() + 1e-8)这两行特别关键。
我们逐步拆开看。
9. 为什么要reshape(-1, vocab_size)
因为交叉熵损失通常要求输入预测形状是:
(样本数, 类别数)而 MLM 预测mlm_Y_hat原始通常是:
(batch_size, num_pred_positions, vocab_size)所以要先把前两个维度合并,变成:
(batch_size * num_pred_positions, vocab_size)这样每个被预测位置,都被当成一个独立分类样本。
同理,标签mlm_Y也会 reshape 成一维:
(batch_size * num_pred_positions,)这就方便统一计算交叉熵。
10. 为什么还要乘mlm_weights_X
因为不是所有num_pred_positions都是真实 mask 位置。
有些只是为了对齐 batch 长度而 pad 出来的占位。
这些位置不该对损失产生影响。
所以:
mlm_l * mlm_weights_X本质上是在做:
保留真实 MLM 目标位置的损失,屏蔽 pad 出来的无效位置
这和前面序列任务里用 valid length 屏蔽<pad>的思路完全一致。
11. 为什么最后要除以mlm_weights_X.sum()
因为我们想要的是:
有效 MLM 位置上的平均损失
而不是简单把所有位置损失加起来。
所以通常会写成:
sum(valid losses) / number_of_valid_positions也就是:
(mlm_l * weights).sum() / weights.sum()这样不同 batch 即使有效 mask 数量不同,loss 规模也更稳定。
12. NSP 损失为什么更简单
相比 MLM,NSP 是一个很标准的二分类任务。
所以它的 loss 通常直接写成:
nsp_l = loss(nsp_Y_hat, nsp_y)如果loss设置为不做 reduction 的交叉熵,
这里最终通常再求个均值:
nsp_l = nsp_l.mean()因为 NSP 不需要像 MLM 那样按位置 mask。
每条样本就是一个标准二分类样本:
是下一句
不是下一句
所以计算起来简单很多。
13. 为什么最终总损失是两个任务 loss 相加
常见写法如下:
l = mlm_l + nsp_l原因很自然:
BERT 预训练本来就是一个多任务学习问题。
模型共享同一个编码器,同时服务于:
MLM
NSP
所以训练时就把两个任务的损失都算上,
一起反向传播。
这等价于告诉模型:
你既要学会利用上下文填空,也要学会判断句子关系。
这种联合训练正是 BERT 预训练的核心。
14. 为什么两个任务要共享同一个编码器
因为 BERT 想学到的是:
通用语言表示
而不是:
一个专门为 MLM 服务的编码器
一个专门为 NSP 服务的编码器
共享编码器的好处在于:
第一,语言知识能统一沉淀
同一套表示同时吸收词级和句级监督信号。
第二,参数更高效
不需要为每个任务单独训练一大套模型。
第三,更符合预训练目标
预训练就是希望学一个可以迁移到很多任务上的公共底座。
所以 MLM 和 NSP 本质上是:
两个任务头,共同打磨一个共享语言编码器
15. 一个完整的辅助函数通常怎么返回
前面那个_get_batch_loss_bert函数,常见最终写法类似:
return mlm_l, nsp_l, l也就是同时返回:
MLM loss
NSP loss
总 loss
为什么要分开返回?
因为训练过程中我们不仅想优化总损失,
往往还想监控:
MLM 学得怎么样
NSP 学得怎么样
这有助于观察训练是否正常,例如:
MLM loss 是否在下降
NSP 是否过快饱和
两个任务是否失衡
所以分别记录是很有必要的。
16. BERT 训练循环通常怎么写
训练循环的主线其实并不神秘,
和普通深度学习训练大体一样:
取一个 batch
前向传播
计算 MLM / NSP loss
反向传播
更新参数
记录指标
常见伪代码可以写成:
for tokens_X, segments_X, valid_lens_x, pred_positions_X, \ mlm_weights_X, mlm_Y, nsp_y in train_iter: trainer.zero_grad() mlm_l, nsp_l, l = _get_batch_loss_bert( net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y ) l.backward() trainer.step()所以你会发现:
BERT 训练循环和普通训练循环最大的区别,不在外壳,而在 batch 更复杂、loss 更复杂。
17. 为什么 BERT 预训练也常常需要梯度裁剪
虽然这一节里不一定每份代码都写得很展开,
但在实践里,BERT 预训练同样可能需要:
梯度裁剪
学习率 warmup
权重衰减
更稳定的优化器(如 Adam)
原因很简单:
模型大
层数深
自注意力结构复杂
训练目标多
所以 BERT 训练虽然思路清晰,
但工程上往往比前面 RNN、Seq2Seq 更讲究训练细节。
18. 训练时通常怎么监控指标
BERT 预训练里,一般至少会监控这几个量:
mlm_loss
表示填空任务学得怎么样。
nsp_loss
表示句对判断学得怎么样。
total_loss
总体优化目标。
有时还会进一步看:
MLM 准确率
NSP 准确率
不过在教学实现里,loss 通常是最基础也最重要的指标。
19. 为什么 BERT 预训练通常很耗资源
这一点值得顺手说明一下。
因为 BERT 预训练同时具备这些特点:
输入是整段序列
主体是多层 Transformer Encoder
自注意力复杂度随序列长度平方增长
还要做两个预训练任务
数据规模通常很大
所以它比前面很多小模型训练都要重得多。
也正因为如此,教学代码里通常会用:
较小模型
较短序列
较小语料
来帮助你先理解流程。
这不是“BERT 很简单”,
而是为了让你先看懂机制。
20. 这一节最该掌握什么
如果从学习重点来看,最关键的是下面几件事。
20.1 明白 BERT 预训练是双任务联合训练
不是只算一个 MLM loss。
20.2 看懂 MLM loss 为什么需要mlm_weights
这是处理 pad mask 位置的关键。
20.3 看懂 NSP loss 为什么更像标准分类任务
因为它本来就是句级二分类。
20.4 明白总 loss 为什么是两者相加
这是共享编码器多任务学习的核心。
20.5 理解训练循环本身并不神秘
难点主要在样本结构和 loss 组织方式。
21. 这一节和前后内容怎么衔接
这一节刚好把 BERT 这一段的前几节完整串起来了。
前面:BERT代码
已经有模型主体。
前面:BERT预训练数据代码
已经有训练样本组织方式。
这一节:BERT预训练代码
把模型和数据真正接起来训练。
而后面接着就是:
BERT微调
自然语言推理数据集
BERT微调代码
也就是说,预训练完成后,下一步就是:
如何把预训练好的 BERT 用到具体下游任务上。
这正是现代 NLP 的完整主线。
22. 本节总结
这一节我们学习了 BERT 预训练代码,核心内容可以总结为以下几点。
22.1 BERT 预训练同时优化 MLM 和 NSP 两个任务
这是原始 BERT 预训练范式的核心。
22.2 MLM loss 只在被选中的 mask 位置上计算
因此需要借助pred_positions和mlm_weights。
22.3 NSP loss 是标准句级二分类损失
通常基于[CLS]表示进行判断。
22.4 总损失通常是 MLM loss 和 NSP loss 的和
通过共享编码器实现多任务联合训练。
22.5 训练循环本质和普通深度学习一致
只是输入结构和 loss 组织更复杂。
23. 学习感悟
这一节特别有价值,因为它会让你真正看到:
BERT 的强大,不只是模型结构先进,
还在于它把“训练目标”和“数据构造”设计得非常系统。
很多时候,大家谈 BERT 会只盯着 Transformer,
但真正把代码串起来之后你会发现:
模型
数据
目标
这三件事是紧密耦合的。
也正因为它们配合得好,BERT 才能在大规模无标注语料上学出这么强的通用表示。