0%

BERT研究

简介

本文简要解析 BERT 的预训练辅助任务(Auxiliary Task)。

预训练辅助任务

BERT 通过两个辅助任务训练语言模型:Masked LM(MLM)与 Next Sentence Prediction(NSP)。

  • MLM:随机 mask 15% 的输入(token),模型需要通过 context 信息还原被 masked 的输入。
  • NSP:随机生成句子对,模型需要判断句子对是否连续(next sentence)。

在训练过程中,MLM 与 NSP 的 loss 是同时计算的,属于多任务学习。

MLM

在 BERT 之前,LM 通常是单向的,常见做法是分别训练正向与反向的 LM,然后再做一个 ensemble 得到上下文相关表征(context dependent representation)。这种做法会有信息缺失与标注偏差的问题 。MLM 的意义在于,可以使 BERT 作为单模型学习到上下文相关的表征,并能更充分地利用双向的信息。

论文里强调了设计 MLM 任务需要注意的问题:

The first is that we are creating a mismatch between pre-training and fine-tuning, since the [MASK] token is never seen during fine-tuning

为了解决这个问题,MLM 采取以下策略 mask 15% 的输入(token):

  • 80% 的概率,把输入替换为 [MASK]
  • 10% 的概率,把输入替换为随机的 token。
  • 10% 的概率,维持输入不变。

这篇博文 提供了一个解释:

  • 如果把 100% 的输入替换为 [MASK]:模型会偏向为 [MASK] 输入建模,而不会学习到 non-masked 输入的表征。
  • 如果把 90% 的输入替换为 [MASK]、10% 的输入替换为随机 token:模型会偏向认为 non-masked 输入是错的。
  • 如果把 90% 的输入替换为 [MASK]、维持 10% 的输入不变:模型会偏向直接复制 non-masked 输入的上下文无关表征。
  • 所以,为了使模型可以学习到相对有效的上下文相关表征,需要以 1:1 的比例使用两种策略处理 non-masked 输入。论文提及,随机替换的输入只占整体的 1.5%,似乎不会对最终效果有影响(模型有足够的容错余量)。

NSP

句子级别表征(sentence-level representation)对于某些下游任务是很有用的。NSP 使 BERT 可以从大规模语料中学习句子级别表征、句子关系的知识。NSP 的做法是:

  1. 从语料中提取两个句子 AB ,50% 的概率 BA 的下一个句子,50% 的概率 B 是一个随机选取的句子,以此为标注训练分类器。
  2. AB 打包成一个序列(sequence):[CLS] A [SEP] B [SEP]
  3. 生成区间标识(segment labels),标识序列中 AB 的位置。[CLS] A [SEP] 的区域设为 0B [SEP] 的区域设为 10, 0..., 0, 1..., 1
  4. 将序列与区间标识输入到模型,取 [CLS] 的表征训练 NSP 分类器。

预处理流程

官方实现 的 Pre-training with BERT 小节简述了预处理的执行方式。通过阅读源码,可以将预处理逻辑划分为以下阶段:

  1. 数据准备

  2. 读取句子与 Tokenization

  3. 构建序列

  4. Masking

  5. 构建 Instance 与导出

数据准备

  1. 准备一个或多个 txt 文本,满足以下格式(示例 ):

    1. 每一行,如果非空,存储一个句子。每个句子使用 whitespace 切词。
    2. 使用空行表示文档结尾。空行的目的是禁止生成跨文档的序列。
  2. 准备 WordPiece 词典 。通常情况下,可以直接复用预训练模型的词典。

读取句子与 Tokenization

实现见 create_training_instances 函数 。

逻辑:

  1. 读取一个文档的所有句子,whitespace 切分后再基于词典做 WordPiece 切分 。
  2. 基于 (1) 构建嵌套 list all_documents,结构是 [doc_num, sent_num, token_num]
  3. Shuffle all_documents

构建序列

“序列(Sequence)”是 BERT 论文定义的概念,包含一个句子对(sentence pair)。基于论文描述与代码逻辑,我认为使用 “句子集合对” 描述更加合适:

To generate each training input sequence, we sample two spans of text from the corpus, which we refer to as “sentences” even though they are typically much longer than single sentences (but can be shorter also). The first sentence receives the A embedding and the second receives the B embedding. 50% of the time B is the actual next sentence that follows A and 50% of the time it is a random sentence, which is done for the “next sentence prediction” task. They are sampled such that the combined length is ≤ 512 tokens.

为了防止误解,后续我将使用 “句子集合 A” 与 “句子集合 B” 描述序列中这两个区域。

逻辑:

  1. create_instances_from_document 函数 包含了构建序列的所有逻辑。序列的构建是以文档为单位的,脚本仅会对每篇文档执行一次此函数,从中生成若干序列。
  2. 序列长度 target_seq_lengthmax_seq_length [default: 128]short_seq_prob [default: 0.1] 决定:
    1. max_num_tokens = max_seq_length - 3 (考虑特殊 token [CLS], 2 * [SEP])。
    2. 1 - short_seq_prob 的概率,将序列长度设为 max_num_tokens
    3. short_seq_prob 的概率,随机选取 [2, max_num_tokens] 为序列长度,目的是降低预训练与 fine-tuning 阶段序列长度不一致的问题 。
    4. 需要注意:一是这个序列长度的生效区域是单个文档;二是为了加速收敛 ,训练初期可能会选取较小的序列长度(如 128),训练后期再选取较大的序列长度(如 512),这种场景需要生成多批训练数据。
  3. 基于(2)的序列长度,生成序列:
    1. 维护一个全局 index i 标识当前文档中尚未处理的句子 。
    2. i 开始收集若干句子加入 current_chunk ,直到 tokens 的数目大于等于序列长度,或已收集到最后一个句子。
    3. 随机选取 current_chunk 的前 [1, len(current_chunk) - 1] 个句子作为句子集合 A
    4. 选取 句子集合 B
      1. 以 50% 的概率,选取非连续(non-next sentence)的句子集合 B。选取方式:随机选择一个除当前文档以外的文档,然后从中选取若干连续的句子,使 A 与 B 的长度恰好超过 target_seq_length
      2. 以 50% 的概率,选取连续(non-next sentence)的句子集合 B。选取方式:直接选取 current_chunk 的剩余部分。
      3. is_random_next 标记连续、非连续的随机结果,用于训练 NSP。
      4. 特殊情况:current_chunk 仅包含一个句子,这种情况强制选取非连续句子集合。
    5. 使用 truncate_seq_pair 裁剪句子集合 A 与 B,保证最终序列长度小于等于 max_num_tokens
    6. 合并句子集合 A 与 B 构建序列 [CLS] A [SEP] B [SEP] ,同时生成区间标识(segment id)0, 0..., 0, 1..., 1

Masking

在构建完序列之后,我们需要随机 mask 输入用于训练 MLM。

逻辑:

  1. create_masked_lm_predictions 函数 包含了 Masking 的所有逻辑。Masking 是以序列为单位的,每个序列只会被 masked 一次。
  2. Masking 的数目 num_to_predictmasked_lm_prob [default: 0.15]max_predictions_per_seq [default: 20] 决定 ,即随机选取 masked_lm_prob 占比的输入 mask,如果输入超过 max_predictions_per_seq 则只 mask max_predictions_per_seq 个 tokens。
  3. 对于每个需要 mask 的输入:
    1. 80% 的概率,把输入替换为 [MASK]
    2. 10% 的概率,把输入替换为随机的 token 。
    3. 10% 的概率,维持输入不变 。

构建 Instance 与导出

最后将上述处理结果打包,为每一个序列生成一个 Instance ,并导出到文件 。