动手学深度学习——BERT预训练代码
2026/4/15 13:33:29 网站建设 项目流程

1. 前言

上一篇我们已经把BERT 预训练数据这部分理顺了,知道了一条完整样本通常会包含:

  • token_ids

  • segments

  • valid_len

  • pred_positions

  • mlm_weights

  • mlm_labels

  • nsp_label

到这里,模型也有了,数据也有了,
接下来就到了真正把两者接起来的时候:

BERT 预训练代码

这一节的核心就是把下面这几件事完整串起来:

  • BERT 模型前向传播

  • MLM 损失怎么计算

  • NSP 损失怎么计算

  • 两个损失怎么合并

  • 训练循环怎么写

如果一句话概括这一节的灵魂,那就是:

让 BERT 同时学会“填空”和“判断句子关系”。


2. BERT 预训练到底在训练什么

BERT 预训练不是单一目标,
而是两个任务一起训练:

第一,MLM

让模型预测被遮住的 token。

第二,NSP

让模型判断第二句是不是第一句的真实后续。

所以 BERT 预训练代码的核心,不是“写一个 loss”,
而是:

同时计算两个任务的 loss,再联合优化。

这也是它和很多普通 NLP 训练代码最明显的区别之一。


3. 训练输入通常长什么样

在进入代码前,先把 batch 输入想清楚。

一个 batch 里通常会拿到:

tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y

它们分别表示:

tokens_X

真正送进 BERT 的 token id 序列。

segments_X

句子 A/B 的 segment 标记。

valid_lens_x

每条样本的有效长度,用于 mask padding。

pred_positions_X

被选中的 MLM 预测位置。

mlm_weights_X

哪些 MLM 位置真实有效,哪些只是 pad 出来的占位。

mlm_Y

MLM 的真实标签。

nsp_y

NSP 的二分类标签。

所以你会发现,BERT 训练一个 batch 的输入,远比普通分类任务复杂。


4. BERT 前向传播时会输出什么

前面在BERT代码那一节里,我们已经搭好了总模型,
它的前向传播通常会返回:

encoded_X, mlm_Y_hat, nsp_Y_hat

这里:

encoded_X

整段输入序列的上下文化表示。

mlm_Y_hat

MLM 任务在 mask 位置上的预测结果。

nsp_Y_hat

NSP 任务的分类预测结果。

而在预训练阶段,我们真正关心的主要就是后两个:

  • mlm_Y_hat

  • nsp_Y_hat

因为这两个才直接参与 loss 计算。


5. MLM 损失为什么不能直接普通算

因为 MLM 不是对整条序列所有位置都计算损失,
它只对:

被选中的 mask 位置

计算损失。

同时,由于不同样本被 mask 的 token 数量可能不同,
为了 batch 对齐,pred_positionsmlm_labels常常被 pad 到同样长度。

这就意味着:

  • 某些位置是真实 MLM 目标

  • 某些位置只是补齐占位

所以 MLM loss 不能简单一股脑全算,
而必须借助:

mlm_weights

来屏蔽无效位置。


6. MLM 损失通常怎么写

李沐这里常见会写一个辅助函数,例如:

def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y): _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X, valid_lens_x.reshape(-1), pred_positions_X)

这里先调用模型,得到 MLM 和 NSP 的预测输出。

然后再分别计算两个损失。


7. 为什么valid_lens_x.reshape(-1)常出现

因为有时valid_lens_x在 batch 中的形状不是最理想的一维向量,
而模型内部注意力 mask 通常希望拿到的是:

(batch_size,)

的一维有效长度。

所以常见会写:

valid_lens_x.reshape(-1)

确保它变成一维。

这属于一个很常见的张量整理细节。


8. MLM loss 的核心计算怎么写

继续往下,常见写法类似:

mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) mlm_l = (mlm_l * mlm_weights_X.reshape(-1, 1)).sum() / (mlm_weights_X.sum() + 1e-8)

这两行特别关键。

我们逐步拆开看。


9. 为什么要reshape(-1, vocab_size)

因为交叉熵损失通常要求输入预测形状是:

(样本数, 类别数)

而 MLM 预测mlm_Y_hat原始通常是:

(batch_size, num_pred_positions, vocab_size)

所以要先把前两个维度合并,变成:

(batch_size * num_pred_positions, vocab_size)

这样每个被预测位置,都被当成一个独立分类样本。

同理,标签mlm_Y也会 reshape 成一维:

(batch_size * num_pred_positions,)

这就方便统一计算交叉熵。


10. 为什么还要乘mlm_weights_X

因为不是所有num_pred_positions都是真实 mask 位置。

有些只是为了对齐 batch 长度而 pad 出来的占位。
这些位置不该对损失产生影响。

所以:

mlm_l * mlm_weights_X

本质上是在做:

保留真实 MLM 目标位置的损失,屏蔽 pad 出来的无效位置

这和前面序列任务里用 valid length 屏蔽<pad>的思路完全一致。


11. 为什么最后要除以mlm_weights_X.sum()

因为我们想要的是:

有效 MLM 位置上的平均损失

而不是简单把所有位置损失加起来。

所以通常会写成:

sum(valid losses) / number_of_valid_positions

也就是:

(mlm_l * weights).sum() / weights.sum()

这样不同 batch 即使有效 mask 数量不同,loss 规模也更稳定。


12. NSP 损失为什么更简单

相比 MLM,NSP 是一个很标准的二分类任务。

所以它的 loss 通常直接写成:

nsp_l = loss(nsp_Y_hat, nsp_y)

如果loss设置为不做 reduction 的交叉熵,
这里最终通常再求个均值:

nsp_l = nsp_l.mean()

因为 NSP 不需要像 MLM 那样按位置 mask。
每条样本就是一个标准二分类样本:

  • 是下一句

  • 不是下一句

所以计算起来简单很多。


13. 为什么最终总损失是两个任务 loss 相加

常见写法如下:

l = mlm_l + nsp_l

原因很自然:

BERT 预训练本来就是一个多任务学习问题。
模型共享同一个编码器,同时服务于:

  • MLM

  • NSP

所以训练时就把两个任务的损失都算上,
一起反向传播。

这等价于告诉模型:

你既要学会利用上下文填空,也要学会判断句子关系。

这种联合训练正是 BERT 预训练的核心。


14. 为什么两个任务要共享同一个编码器

因为 BERT 想学到的是:

通用语言表示

而不是:

  • 一个专门为 MLM 服务的编码器

  • 一个专门为 NSP 服务的编码器

共享编码器的好处在于:

第一,语言知识能统一沉淀

同一套表示同时吸收词级和句级监督信号。

第二,参数更高效

不需要为每个任务单独训练一大套模型。

第三,更符合预训练目标

预训练就是希望学一个可以迁移到很多任务上的公共底座。

所以 MLM 和 NSP 本质上是:

两个任务头,共同打磨一个共享语言编码器


15. 一个完整的辅助函数通常怎么返回

前面那个_get_batch_loss_bert函数,常见最终写法类似:

return mlm_l, nsp_l, l

也就是同时返回:

  • MLM loss

  • NSP loss

  • 总 loss

为什么要分开返回?

因为训练过程中我们不仅想优化总损失,
往往还想监控:

  • MLM 学得怎么样

  • NSP 学得怎么样

这有助于观察训练是否正常,例如:

  • MLM loss 是否在下降

  • NSP 是否过快饱和

  • 两个任务是否失衡

所以分别记录是很有必要的。


16. BERT 训练循环通常怎么写

训练循环的主线其实并不神秘,
和普通深度学习训练大体一样:

  1. 取一个 batch

  2. 前向传播

  3. 计算 MLM / NSP loss

  4. 反向传播

  5. 更新参数

  6. 记录指标

常见伪代码可以写成:

for tokens_X, segments_X, valid_lens_x, pred_positions_X, \ mlm_weights_X, mlm_Y, nsp_y in train_iter: trainer.zero_grad() mlm_l, nsp_l, l = _get_batch_loss_bert( net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y ) l.backward() trainer.step()

所以你会发现:

BERT 训练循环和普通训练循环最大的区别,不在外壳,而在 batch 更复杂、loss 更复杂。


17. 为什么 BERT 预训练也常常需要梯度裁剪

虽然这一节里不一定每份代码都写得很展开,
但在实践里,BERT 预训练同样可能需要:

  • 梯度裁剪

  • 学习率 warmup

  • 权重衰减

  • 更稳定的优化器(如 Adam)

原因很简单:

  • 模型大

  • 层数深

  • 自注意力结构复杂

  • 训练目标多

所以 BERT 训练虽然思路清晰,
但工程上往往比前面 RNN、Seq2Seq 更讲究训练细节。


18. 训练时通常怎么监控指标

BERT 预训练里,一般至少会监控这几个量:

mlm_loss

表示填空任务学得怎么样。

nsp_loss

表示句对判断学得怎么样。

total_loss

总体优化目标。

有时还会进一步看:

  • MLM 准确率

  • NSP 准确率

不过在教学实现里,loss 通常是最基础也最重要的指标。


19. 为什么 BERT 预训练通常很耗资源

这一点值得顺手说明一下。

因为 BERT 预训练同时具备这些特点:

  • 输入是整段序列

  • 主体是多层 Transformer Encoder

  • 自注意力复杂度随序列长度平方增长

  • 还要做两个预训练任务

  • 数据规模通常很大

所以它比前面很多小模型训练都要重得多。

也正因为如此,教学代码里通常会用:

  • 较小模型

  • 较短序列

  • 较小语料

来帮助你先理解流程。

这不是“BERT 很简单”,
而是为了让你先看懂机制。


20. 这一节最该掌握什么

如果从学习重点来看,最关键的是下面几件事。

20.1 明白 BERT 预训练是双任务联合训练

不是只算一个 MLM loss。

20.2 看懂 MLM loss 为什么需要mlm_weights

这是处理 pad mask 位置的关键。

20.3 看懂 NSP loss 为什么更像标准分类任务

因为它本来就是句级二分类。

20.4 明白总 loss 为什么是两者相加

这是共享编码器多任务学习的核心。

20.5 理解训练循环本身并不神秘

难点主要在样本结构和 loss 组织方式。


21. 这一节和前后内容怎么衔接

这一节刚好把 BERT 这一段的前几节完整串起来了。

前面:BERT代码

已经有模型主体。

前面:BERT预训练数据代码

已经有训练样本组织方式。

这一节:BERT预训练代码

把模型和数据真正接起来训练。

而后面接着就是:

  • BERT微调

  • 自然语言推理数据集

  • BERT微调代码

也就是说,预训练完成后,下一步就是:

如何把预训练好的 BERT 用到具体下游任务上。

这正是现代 NLP 的完整主线。


22. 本节总结

这一节我们学习了 BERT 预训练代码,核心内容可以总结为以下几点。

22.1 BERT 预训练同时优化 MLM 和 NSP 两个任务

这是原始 BERT 预训练范式的核心。

22.2 MLM loss 只在被选中的 mask 位置上计算

因此需要借助pred_positionsmlm_weights

22.3 NSP loss 是标准句级二分类损失

通常基于[CLS]表示进行判断。

22.4 总损失通常是 MLM loss 和 NSP loss 的和

通过共享编码器实现多任务联合训练。

22.5 训练循环本质和普通深度学习一致

只是输入结构和 loss 组织更复杂。


23. 学习感悟

这一节特别有价值,因为它会让你真正看到:

BERT 的强大,不只是模型结构先进,
还在于它把“训练目标”和“数据构造”设计得非常系统。

很多时候,大家谈 BERT 会只盯着 Transformer,
但真正把代码串起来之后你会发现:

  • 模型

  • 数据

  • 目标

这三件事是紧密耦合的。

也正因为它们配合得好,BERT 才能在大规模无标注语料上学出这么强的通用表示。

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询