动手学深度学习——BERT预训练代码
1. 前言上一篇我们已经把BERT 预训练数据这部分理顺了知道了一条完整样本通常会包含token_idssegmentsvalid_lenpred_positionsmlm_weightsmlm_labelsnsp_label到这里模型也有了数据也有了接下来就到了真正把两者接起来的时候BERT 预训练代码这一节的核心就是把下面这几件事完整串起来BERT 模型前向传播MLM 损失怎么计算NSP 损失怎么计算两个损失怎么合并训练循环怎么写如果一句话概括这一节的灵魂那就是让 BERT 同时学会“填空”和“判断句子关系”。2. BERT 预训练到底在训练什么BERT 预训练不是单一目标而是两个任务一起训练第一MLM让模型预测被遮住的 token。第二NSP让模型判断第二句是不是第一句的真实后续。所以 BERT 预训练代码的核心不是“写一个 loss”而是同时计算两个任务的 loss再联合优化。这也是它和很多普通 NLP 训练代码最明显的区别之一。3. 训练输入通常长什么样在进入代码前先把 batch 输入想清楚。一个 batch 里通常会拿到tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y它们分别表示tokens_X真正送进 BERT 的 token id 序列。segments_X句子 A/B 的 segment 标记。valid_lens_x每条样本的有效长度用于 mask padding。pred_positions_X被选中的 MLM 预测位置。mlm_weights_X哪些 MLM 位置真实有效哪些只是 pad 出来的占位。mlm_YMLM 的真实标签。nsp_yNSP 的二分类标签。所以你会发现BERT 训练一个 batch 的输入远比普通分类任务复杂。4. BERT 前向传播时会输出什么前面在BERT代码那一节里我们已经搭好了总模型它的前向传播通常会返回encoded_X, mlm_Y_hat, nsp_Y_hat这里encoded_X整段输入序列的上下文化表示。mlm_Y_hatMLM 任务在 mask 位置上的预测结果。nsp_Y_hatNSP 任务的分类预测结果。而在预训练阶段我们真正关心的主要就是后两个mlm_Y_hatnsp_Y_hat因为这两个才直接参与 loss 计算。5. MLM 损失为什么不能直接普通算因为 MLM 不是对整条序列所有位置都计算损失它只对被选中的 mask 位置计算损失。同时由于不同样本被 mask 的 token 数量可能不同为了 batch 对齐pred_positions和mlm_labels常常被 pad 到同样长度。这就意味着某些位置是真实 MLM 目标某些位置只是补齐占位所以 MLM loss 不能简单一股脑全算而必须借助mlm_weights来屏蔽无效位置。6. MLM 损失通常怎么写李沐这里常见会写一个辅助函数例如def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y): _, mlm_Y_hat, nsp_Y_hat net(tokens_X, segments_X, valid_lens_x.reshape(-1), pred_positions_X)这里先调用模型得到 MLM 和 NSP 的预测输出。然后再分别计算两个损失。7. 为什么valid_lens_x.reshape(-1)常出现因为有时valid_lens_x在 batch 中的形状不是最理想的一维向量而模型内部注意力 mask 通常希望拿到的是(batch_size,)的一维有效长度。所以常见会写valid_lens_x.reshape(-1)确保它变成一维。这属于一个很常见的张量整理细节。8. MLM loss 的核心计算怎么写继续往下常见写法类似mlm_l loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) mlm_l (mlm_l * mlm_weights_X.reshape(-1, 1)).sum() / (mlm_weights_X.sum() 1e-8)这两行特别关键。我们逐步拆开看。9. 为什么要reshape(-1, vocab_size)因为交叉熵损失通常要求输入预测形状是(样本数, 类别数)而 MLM 预测mlm_Y_hat原始通常是(batch_size, num_pred_positions, vocab_size)所以要先把前两个维度合并变成(batch_size * num_pred_positions, vocab_size)这样每个被预测位置都被当成一个独立分类样本。同理标签mlm_Y也会 reshape 成一维(batch_size * num_pred_positions,)这就方便统一计算交叉熵。10. 为什么还要乘mlm_weights_X因为不是所有num_pred_positions都是真实 mask 位置。有些只是为了对齐 batch 长度而 pad 出来的占位。这些位置不该对损失产生影响。所以mlm_l * mlm_weights_X本质上是在做保留真实 MLM 目标位置的损失屏蔽 pad 出来的无效位置这和前面序列任务里用 valid length 屏蔽pad的思路完全一致。11. 为什么最后要除以mlm_weights_X.sum()因为我们想要的是有效 MLM 位置上的平均损失而不是简单把所有位置损失加起来。所以通常会写成sum(valid losses) / number_of_valid_positions也就是(mlm_l * weights).sum() / weights.sum()这样不同 batch 即使有效 mask 数量不同loss 规模也更稳定。12. NSP 损失为什么更简单相比 MLMNSP 是一个很标准的二分类任务。所以它的 loss 通常直接写成nsp_l loss(nsp_Y_hat, nsp_y)如果loss设置为不做 reduction 的交叉熵这里最终通常再求个均值nsp_l nsp_l.mean()因为 NSP 不需要像 MLM 那样按位置 mask。每条样本就是一个标准二分类样本是下一句不是下一句所以计算起来简单很多。13. 为什么最终总损失是两个任务 loss 相加常见写法如下l mlm_l nsp_l原因很自然BERT 预训练本来就是一个多任务学习问题。模型共享同一个编码器同时服务于MLMNSP所以训练时就把两个任务的损失都算上一起反向传播。这等价于告诉模型你既要学会利用上下文填空也要学会判断句子关系。这种联合训练正是 BERT 预训练的核心。14. 为什么两个任务要共享同一个编码器因为 BERT 想学到的是通用语言表示而不是一个专门为 MLM 服务的编码器一个专门为 NSP 服务的编码器共享编码器的好处在于第一语言知识能统一沉淀同一套表示同时吸收词级和句级监督信号。第二参数更高效不需要为每个任务单独训练一大套模型。第三更符合预训练目标预训练就是希望学一个可以迁移到很多任务上的公共底座。所以 MLM 和 NSP 本质上是两个任务头共同打磨一个共享语言编码器15. 一个完整的辅助函数通常怎么返回前面那个_get_batch_loss_bert函数常见最终写法类似return mlm_l, nsp_l, l也就是同时返回MLM lossNSP loss总 loss为什么要分开返回因为训练过程中我们不仅想优化总损失往往还想监控MLM 学得怎么样NSP 学得怎么样这有助于观察训练是否正常例如MLM loss 是否在下降NSP 是否过快饱和两个任务是否失衡所以分别记录是很有必要的。16. BERT 训练循环通常怎么写训练循环的主线其实并不神秘和普通深度学习训练大体一样取一个 batch前向传播计算 MLM / NSP loss反向传播更新参数记录指标常见伪代码可以写成for tokens_X, segments_X, valid_lens_x, pred_positions_X, \ mlm_weights_X, mlm_Y, nsp_y in train_iter: trainer.zero_grad() mlm_l, nsp_l, l _get_batch_loss_bert( net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y ) l.backward() trainer.step()所以你会发现BERT 训练循环和普通训练循环最大的区别不在外壳而在 batch 更复杂、loss 更复杂。17. 为什么 BERT 预训练也常常需要梯度裁剪虽然这一节里不一定每份代码都写得很展开但在实践里BERT 预训练同样可能需要梯度裁剪学习率 warmup权重衰减更稳定的优化器如 Adam原因很简单模型大层数深自注意力结构复杂训练目标多所以 BERT 训练虽然思路清晰但工程上往往比前面 RNN、Seq2Seq 更讲究训练细节。18. 训练时通常怎么监控指标BERT 预训练里一般至少会监控这几个量mlm_loss表示填空任务学得怎么样。nsp_loss表示句对判断学得怎么样。total_loss总体优化目标。有时还会进一步看MLM 准确率NSP 准确率不过在教学实现里loss 通常是最基础也最重要的指标。19. 为什么 BERT 预训练通常很耗资源这一点值得顺手说明一下。因为 BERT 预训练同时具备这些特点输入是整段序列主体是多层 Transformer Encoder自注意力复杂度随序列长度平方增长还要做两个预训练任务数据规模通常很大所以它比前面很多小模型训练都要重得多。也正因为如此教学代码里通常会用较小模型较短序列较小语料来帮助你先理解流程。这不是“BERT 很简单”而是为了让你先看懂机制。20. 这一节最该掌握什么如果从学习重点来看最关键的是下面几件事。20.1 明白 BERT 预训练是双任务联合训练不是只算一个 MLM loss。20.2 看懂 MLM loss 为什么需要mlm_weights这是处理 pad mask 位置的关键。20.3 看懂 NSP loss 为什么更像标准分类任务因为它本来就是句级二分类。20.4 明白总 loss 为什么是两者相加这是共享编码器多任务学习的核心。20.5 理解训练循环本身并不神秘难点主要在样本结构和 loss 组织方式。21. 这一节和前后内容怎么衔接这一节刚好把 BERT 这一段的前几节完整串起来了。前面BERT代码已经有模型主体。前面BERT预训练数据代码已经有训练样本组织方式。这一节BERT预训练代码把模型和数据真正接起来训练。而后面接着就是BERT微调自然语言推理数据集BERT微调代码也就是说预训练完成后下一步就是如何把预训练好的 BERT 用到具体下游任务上。这正是现代 NLP 的完整主线。22. 本节总结这一节我们学习了 BERT 预训练代码核心内容可以总结为以下几点。22.1 BERT 预训练同时优化 MLM 和 NSP 两个任务这是原始 BERT 预训练范式的核心。22.2 MLM loss 只在被选中的 mask 位置上计算因此需要借助pred_positions和mlm_weights。22.3 NSP loss 是标准句级二分类损失通常基于[CLS]表示进行判断。22.4 总损失通常是 MLM loss 和 NSP loss 的和通过共享编码器实现多任务联合训练。22.5 训练循环本质和普通深度学习一致只是输入结构和 loss 组织更复杂。23. 学习感悟这一节特别有价值因为它会让你真正看到BERT 的强大不只是模型结构先进还在于它把“训练目标”和“数据构造”设计得非常系统。很多时候大家谈 BERT 会只盯着 Transformer但真正把代码串起来之后你会发现模型数据目标这三件事是紧密耦合的。也正因为它们配合得好BERT 才能在大规模无标注语料上学出这么强的通用表示。