With the BERT model implemented in Section 15.8 and the pretraining examples generated from the WikiText-2 dataset in Section 15.9, we will pretrain BERT on WikiText-2 in this section.
To start, we load the WikiText-2 dataset as minibatches of pretraining examples for masked language modeling and next sentence prediction. The batch size is 512 and the maximum length of a BERT input sequence is 64. Note that in the original BERT model, the maximum length is 512.
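A minimal sketch of this loading step, assuming the load_data_wiki helper that Section 15.9 saves into the d2l package (the MXNet flavor of d2l is shown to match the Gluon model definition below; the PyTorch flavor provides the same function):

from mxnet import gluon, init, npx
from d2l import mxnet as d2l

npx.set_np()

# Minibatches of pretraining examples for masked language modeling and next
# sentence prediction: batch size 512, maximum BERT input sequence length 64.
# (gluon and init are imported here for the Gluon model definition below.)
batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)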
Downloading ../data/wikitext-2-v1.zip from https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip...
15.10.1. Pretraining BERT
The original BERT comes in two versions of different model sizes (Devlin et al., 2018). The base model (BERT-Base) uses 12 layers (Transformer encoder blocks) with 768 hidden units (hidden size) and 12 self-attention heads. The large model (BERT-Large) uses 24 layers with 1024 hidden units and 16 self-attention heads. Notably, the former has 110 million parameters while the latter has 340 million. For demonstration with ease, we define a small BERT, using 2 layers, 128 hidden units, and 2 self-attention heads.
# Define the small BERT and its loss (MXNet/Gluon implementation)
net = d2l.BERTModel(len(vocab), num_hiddens=128, ffn_num_hiddens=256,
                    num_heads=2, num_blks=2, dropout=0.2)
devices = d2l.try_all_gpus()
net.initialize(init.Xavier(), ctx=devices)
loss = gluon.loss.SoftmaxCELoss()
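The setup above uses the MXNet/Gluon flavor of the d2l package. Because the training loop later in this section is written for PyTorch, a rough PyTorch counterpart is sketched below; it assumes the PyTorch d2l.BERTModel accepts the same arguments, which is not shown in this section:

import torch
from torch import nn
from d2l import torch as d2l

# Assumed PyTorch counterpart of the Gluon setup above: the same small BERT
# configuration, with cross-entropy loss in place of Gluon's SoftmaxCELoss.
net = d2l.BERTModel(len(vocab), num_hiddens=128, ffn_num_hiddens=256,
                    num_heads=2, num_blks=2, dropout=0.2)
devices = d2l.try_all_gpus()
loss = nn.CrossEntropyLoss()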
Before defining the training loop, we define the helper function _get_batch_loss_bert. Given a shard of training examples, this function computes the loss of both the masked language modeling and next sentence prediction tasks. Note that the final loss of BERT pretraining is just the sum of the masked language modeling loss and the next sentence prediction loss. Two framework-specific versions appear below: a PyTorch version that takes a whole minibatch, followed by a Gluon version that takes per-device shards of a minibatch and returns per-shard losses.
#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
                         segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # Forward pass
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1),
                                  pred_positions_X)
    # Compute masked language model loss
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
    mlm_weights_X.reshape(-1, 1)
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # Compute next sentence prediction loss
    nsp_l = loss(nsp_Y_hat, nsp_y)
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l
# Gluon version of the helper: operates on per-device shards of a minibatch
#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X_shards,
                         segments_X_shards, valid_lens_x_shards,
                         pred_positions_X_shards, mlm_weights_X_shards,
                         mlm_Y_shards, nsp_y_shards):
    mlm_ls, nsp_ls, ls = [], [], []
    for (tokens_X_shard, segments_X_shard, valid_lens_x_shard,
         pred_positions_X_shard, mlm_weights_X_shard, mlm_Y_shard,
         nsp_y_shard) in zip(
        tokens_X_shards, segments_X_shards, valid_lens_x_shards,
        pred_positions_X_shards, mlm_weights_X_shards, mlm_Y_shards,
        nsp_y_shards):
        # Forward pass
        _, mlm_Y_hat, nsp_Y_hat = net(
            tokens_X_shard, segments_X_shard, valid_lens_x_shard.reshape(-1),
            pred_positions_X_shard)
        # Compute masked language model loss
        mlm_l = loss(
            mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y_shard.reshape(-1),
            mlm_weights_X_shard.reshape((-1, 1)))
        mlm_l = mlm_l.sum() / (mlm_weights_X_shard.sum() + 1e-8)
        # Compute next sentence prediction loss
        nsp_l = loss(nsp_Y_hat, nsp_y_shard)
        nsp_l = nsp_l.mean()
        mlm_ls.append(mlm_l)
        nsp_ls.append(nsp_l)
        ls.append(mlm_l + nsp_l)
        npx.waitall()
    return mlm_ls, nsp_ls, ls
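In the corresponding Gluon training loop (not reproduced in full here), each minibatch is first split into one shard per GPU before being passed to this helper. A minimal sketch of that splitting step, using Gluon's split_and_load utility (the split_batch name is illustrative, not from the book):

from mxnet import gluon

def split_batch(batch, devices):
    # Split every field of the minibatch (tokens, segments, valid lengths,
    # prediction positions, MLM weights, MLM labels, NSP labels) into one
    # shard per device so that each GPU processes its own slice.
    return [gluon.utils.split_and_load(elem, devices, even_split=False)
            for elem in batch]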
Invoking the helper function defined above, the following train_bert function defines the procedure to pretrain BERT (net) on the WikiText-2 (train_iter) dataset. Training BERT can take very long. Instead of specifying the number of epochs for training as in the train_ch13 function (see Section 14.1), the input num_steps of the following function specifies the number of iteration steps for training. The implementation below is written for PyTorch: it wraps the model with nn.DataParallel so that each minibatch is split across all available GPUs.
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
    net(*next(iter(train_iter))[:4])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    trainer = torch.optim.Adam(net.parameters(), lr=0.01)
    step, timer = 0, d2l.Timer()
    animator = d2l.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # Sum of masked language modeling losses, sum of next sentence prediction
    # losses, no. of sentence pairs, count
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
            mlm_weights_X, mlm_Y, nsp_y in train_iter:
            tokens_X = tokens_X.to(devices[0])
            segments_X = segments_X.to(devices[0])
            valid_lens_x = valid_lens_x.to(devices[0])
            pred_positions_X = pred_positions_X.to(devices[0])
            mlm_weights_X = mlm_weights_X.to(devices[0])
            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
            trainer.zero_grad()
            timer.start()
            mlm_l, nsp_l, l = _get_batch_loss_bert(
                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
            l.backward()
            trainer.step()
            metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
            timer.stop()
            animator.add(step + 1,
                         (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break
    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')
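Pretraining can then be launched for a fixed number of iteration steps. The call below follows the book's pattern; 50 steps is an illustrative setting for the small BERT, not a tuned value:

# Pretrain the small BERT for 50 iteration steps (illustrative only).
train_bert(train_iter, net, loss, len(vocab), devices, 50)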