電子發(fā)燒友App

硬聲App

0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內不再提示
創(chuàng)作
電子發(fā)燒友網(wǎng)>電子資料下載>電子資料>PyTorch教程15.10之預訓練BERT

PyTorch教程15.10之預訓練BERT

2023-06-05 | pdf | 0.15 MB | 次下載 | 免費

資料介紹

借助15.8 節(jié)中實現(xiàn)的 BERT 模型和15.9 節(jié)中從 WikiText-2 數(shù)據(jù)集生成的預訓練示例 ,我們將在本節(jié)中在 WikiText-2 數(shù)據(jù)集上預訓練 BERT。

import torch
from torch import nn
from d2l import torch as d2l
from mxnet import autograd, gluon, init, np, npx
from d2l import mxnet as d2l

npx.set_np()

首先,我們將 WikiText-2 數(shù)據(jù)集加載為用于屏蔽語言建模和下一句預測的小批量預訓練示例。批量大小為 512,BERT 輸入序列的最大長度為 64。請注意,在原始 BERT 模型中,最大長度為 512。

batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)
batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)
Downloading ../data/wikitext-2-v1.zip from https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip...

15.10.1。預訓練 BERT

原始 BERT 有兩個不同模型大小的版本 Devlin et al. , 2018。基礎模型(BERTBASE) 使用 12 層(Transformer 編碼器塊),具有 768 個隱藏單元(隱藏大?。┖?12 個自注意力頭。大模型(BERTLARGE) 使用 24 層,有 1024 個隱藏單元和 16 個自注意力頭。值得注意的是,前者有 1.1 億個參數(shù),而后者有 3.4 億個參數(shù)。為了便于演示,我們定義了一個小型 BERT,使用 2 層、128 個隱藏單元和 2 個自注意力頭。

net = d2l.BERTModel(len(vocab), num_hiddens=128,
          ffn_num_hiddens=256, num_heads=2, num_blks=2, dropout=0.2)
devices = d2l.try_all_gpus()
loss = nn.CrossEntropyLoss()
net = d2l.BERTModel(len(vocab), num_hiddens=128, ffn_num_hiddens=256,
          num_heads=2, num_blks=2, dropout=0.2)
devices = d2l.try_all_gpus()
net.initialize(init.Xavier(), ctx=devices)
loss = gluon.loss.SoftmaxCELoss()

在定義訓練循環(huán)之前,我們定義了一個輔助函數(shù) _get_batch_loss_bert。給定訓練示例的碎片,此函數(shù)計算掩碼語言建模和下一句預測任務的損失。請注意,BERT 預訓練的最終損失只是掩碼語言建模損失和下一句預測損失的總和。

#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
             segments_X, valid_lens_x,
             pred_positions_X, mlm_weights_X,
             mlm_Y, nsp_y):
  # Forward pass
  _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                 valid_lens_x.reshape(-1),
                 pred_positions_X)
  # Compute masked language model loss
  mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
  mlm_weights_X.reshape(-1, 1)
  mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
  # Compute next sentence prediction loss
  nsp_l = loss(nsp_Y_hat, nsp_y)
  l = mlm_l + nsp_l
  return mlm_l, nsp_l, l
#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X_shards,
             segments_X_shards, valid_lens_x_shards,
             pred_positions_X_shards, mlm_weights_X_shards,
             mlm_Y_shards, nsp_y_shards):
  mlm_ls, nsp_ls, ls = [], [], []
  for (tokens_X_shard, segments_X_shard, valid_lens_x_shard,
     pred_positions_X_shard, mlm_weights_X_shard, mlm_Y_shard,
     nsp_y_shard) in zip(
    tokens_X_shards, segments_X_shards, valid_lens_x_shards,
    pred_positions_X_shards, mlm_weights_X_shards, mlm_Y_shards,
    nsp_y_shards):
    # Forward pass
    _, mlm_Y_hat, nsp_Y_hat = net(
      tokens_X_shard, segments_X_shard, valid_lens_x_shard.reshape(-1),
      pred_positions_X_shard)
    # Compute masked language model loss
    mlm_l = loss(
      mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y_shard.reshape(-1),
      mlm_weights_X_shard.reshape((-1, 1)))
    mlm_l = mlm_l.sum() / (mlm_weights_X_shard.sum() + 1e-8)
    # Compute next sentence prediction loss
    nsp_l = loss(nsp_Y_hat, nsp_y_shard)
    nsp_l = nsp_l.mean()
    mlm_ls.append(mlm_l)
    nsp_ls.append(nsp_l)
    ls.append(mlm_l + nsp_l)
    npx.waitall()
  return mlm_ls, nsp_ls, ls

調用上述兩個輔助函數(shù),以下 函數(shù)定義了在 WikiText-2 ( ) 數(shù)據(jù)集上train_bert預訓練 BERT ( ) 的過程訓練 BERT 可能需要很長時間。與在函數(shù)中指定訓練的時期數(shù)不同 (參見第 14.1 節(jié)),以下函數(shù)的輸入指定訓練的迭代步數(shù)。nettrain_itertrain_ch13num_steps

def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
  net(*next(iter(train_iter))[:4])
  net = nn.DataParallel(net, device_ids=devices).to(devices[0])
  trainer = torch.optim.Adam(net.parameters(), lr=0.01)
  step, timer = 0, d2l.Timer()
  animator = d2l.Animator(xlabel='step', ylabel='loss',
              xlim=[1, num_steps], legend=['mlm', 'nsp'])
  # Sum of masked language modeling losses, sum of next sentence prediction
  # losses, no. of sentence pairs, count
  metric = d2l.Accumulator(4)
  num_steps_reached = False
  while step < num_steps and not num_steps_reached:
    for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
      mlm_weights_X, mlm_Y, nsp_y in train_iter:
      tokens_X = tokens_X.to(devices[0])
      segments_X = segments_X.to(devices[0])
      valid_lens_x = valid_lens_x.to(devices[0])
      pred_positions_X = pred_positions_X.to(devices[0])
      mlm_weights_X = mlm_weights_X.to(devices[0])
      mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
      trainer.zero_grad()
      timer.start()
      mlm_l, nsp_l, l = _get_batch_loss_bert(
        net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
        pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
      l.backward()
      trainer.step()
      metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
      timer.stop()
      animator.add(step + 1,
             (metric[0] / metric[3], metric[1] / metric[3]))
      step += 1
      if step == num_steps:
        num_steps_reached = True
        break

  print(f'MLM loss {metric[0] / metric[3]:.3f}, '
     f'NSP loss {metric[1] / metric[3]:.3f}')
  print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
     f'{str(devices)}')
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
  trainer = gluon.Trainer(net.collect_params(), 'adam',
              {'learning_rate': 0.01})
  step, timer = 0, d2l.Timer()
  animator = d2l.Animator(xlabel='step', ylabel='loss',
              xlim=[1, num_steps], legend=['mlm', 'nsp'])
  # Sum of masked language modeling losses, sum of next sentence prediction
  # losses, no. of sentence pairs, count
  metric = d2l.

下載該資料的人也在下載 下載該資料的人還在閱讀
更多 >

評論

查看更多

下載排行

本周

  1. 1山景DSP芯片AP8248A2數(shù)據(jù)手冊
  2. 1.06 MB  |  532次下載  |  免費
  3. 2RK3399完整板原理圖(支持平板,盒子VR)
  4. 3.28 MB  |  339次下載  |  免費
  5. 3TC358743XBG評估板參考手冊
  6. 1.36 MB  |  330次下載  |  免費
  7. 4DFM軟件使用教程
  8. 0.84 MB  |  295次下載  |  免費
  9. 5元宇宙深度解析—未來的未來-風口還是泡沫
  10. 6.40 MB  |  227次下載  |  免費
  11. 6迪文DGUS開發(fā)指南
  12. 31.67 MB  |  194次下載  |  免費
  13. 7元宇宙底層硬件系列報告
  14. 13.42 MB  |  182次下載  |  免費
  15. 8FP5207XR-G1中文應用手冊
  16. 1.09 MB  |  178次下載  |  免費

本月

  1. 1OrCAD10.5下載OrCAD10.5中文版軟件
  2. 0.00 MB  |  234315次下載  |  免費
  3. 2555集成電路應用800例(新編版)
  4. 0.00 MB  |  33566次下載  |  免費
  5. 3接口電路圖大全
  6. 未知  |  30323次下載  |  免費
  7. 4開關電源設計實例指南
  8. 未知  |  21549次下載  |  免費
  9. 5電氣工程師手冊免費下載(新編第二版pdf電子書)
  10. 0.00 MB  |  15349次下載  |  免費
  11. 6數(shù)字電路基礎pdf(下載)
  12. 未知  |  13750次下載  |  免費
  13. 7電子制作實例集錦 下載
  14. 未知  |  8113次下載  |  免費
  15. 8《LED驅動電路設計》 溫德爾著
  16. 0.00 MB  |  6656次下載  |  免費

總榜

  1. 1matlab軟件下載入口
  2. 未知  |  935054次下載  |  免費
  3. 2protel99se軟件下載(可英文版轉中文版)
  4. 78.1 MB  |  537798次下載  |  免費
  5. 3MATLAB 7.1 下載 (含軟件介紹)
  6. 未知  |  420027次下載  |  免費
  7. 4OrCAD10.5下載OrCAD10.5中文版軟件
  8. 0.00 MB  |  234315次下載  |  免費
  9. 5Altium DXP2002下載入口
  10. 未知  |  233046次下載  |  免費
  11. 6電路仿真軟件multisim 10.0免費下載
  12. 340992  |  191187次下載  |  免費
  13. 7十天學會AVR單片機與C語言視頻教程 下載
  14. 158M  |  183279次下載  |  免費
  15. 8proe5.0野火版下載(中文版免費下載)
  16. 未知  |  138040次下載  |  免費