0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

PyTorch教程-10.5。機(jī)器翻譯和數(shù)據(jù)集

jf_pJlTbmA9 ? 來源:PyTorch ? 作者:PyTorch ? 2023-06-05 15:44 ? 次閱讀

在引起人們對現(xiàn)代 RNN 廣泛興趣的重大突破中,有一項(xiàng)是統(tǒng)計(jì)機(jī)器翻譯應(yīng)用領(lǐng)域的重大進(jìn)展 。在這里,模型以一種語言的句子呈現(xiàn),并且必須預(yù)測另一種語言的相應(yīng)句子。請注意,由于兩種語言的語法結(jié)構(gòu)不同,這里的句子可能有不同的長度,并且兩個(gè)句子中相應(yīng)的詞可能不會(huì)以相同的順序出現(xiàn)。

許多問題都具有這種在兩個(gè)這樣的“未對齊”序列之間進(jìn)行映射的風(fēng)格。示例包括從對話提示到回復(fù)或從問題到答案的映射。廣義上,此類問題稱為 序列到序列(seq2seq) 問題,它們是本章剩余部分和 第 11 節(jié)大部分內(nèi)容的重點(diǎn)。

在本節(jié)中,我們將介紹機(jī)器翻譯問題和我們將在后續(xù)示例中使用的示例數(shù)據(jù)集。幾十年來,語言間翻譯的統(tǒng)計(jì)公式一直很流行 (Brown等人,1990 年,Brown等人,1988 年),甚至在研究人員使神經(jīng)網(wǎng)絡(luò)方法起作用之前(這些方法通常被統(tǒng)稱為神經(jīng)機(jī)器翻譯)。

首先,我們需要一些新代碼來處理我們的數(shù)據(jù)。與我們在9.3 節(jié)中看到的語言建模不同,這里的每個(gè)示例都包含兩個(gè)單獨(dú)的文本序列,一個(gè)是源語言,另一個(gè)(翻譯)是目標(biāo)語言。以下代碼片段將展示如何將預(yù)處理后的數(shù)據(jù)加載到小批量中進(jìn)行訓(xùn)練。

import os
import torch
from d2l import torch as d2l

import os
from mxnet import np, npx
from d2l import mxnet as d2l

npx.set_np()

import os
from jax import numpy as jnp
from d2l import jax as d2l

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

import os
import tensorflow as tf
from d2l import tensorflow as d2l

10.5.1。下載和預(yù)處理數(shù)據(jù)集

首先,我們 從 Tatoeba Project 下載由雙語句子對組成的英法數(shù)據(jù)集。數(shù)據(jù)集中的每一行都是一個(gè)制表符分隔的對,由一個(gè)英文文本序列和翻譯后的法文文本序列組成。請注意,每個(gè)文本序列可以只是一個(gè)句子,也可以是一段多句。在這個(gè)英語翻譯成法語的機(jī)器翻譯問題中,英語被稱為源語言,法語被稱為目標(biāo)語言。

class MTFraEng(d2l.DataModule): #@save
  """The English-French dataset."""
  def _download(self):
    d2l.extract(d2l.download(
      d2l.DATA_URL+'fra-eng.zip', self.root,
      '94646ad1522d915e7b0f9296181140edcf86a4f5'))
    with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f:
      return f.read()

data = MTFraEng()
raw_text = data._download()
print(raw_text[:75])

Downloading ../data/fra-eng.zip from http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip...
Go. Va !
Hi. Salut !
Run!    Cours?!
Run!    Courez?!
Who?    Qui ?
Wow!    ?a alors?!

class MTFraEng(d2l.DataModule): #@save
  """The English-French dataset."""
  def _download(self):
    d2l.extract(d2l.download(
      d2l.DATA_URL+'fra-eng.zip', self.root,
      '94646ad1522d915e7b0f9296181140edcf86a4f5'))
    with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f:
      return f.read()

data = MTFraEng()
raw_text = data._download()
print(raw_text[:75])

Go. Va !
Hi. Salut !
Run!    Cours?!
Run!    Courez?!
Who?    Qui ?
Wow!    ?a alors?!

class MTFraEng(d2l.DataModule): #@save
  """The English-French dataset."""
  def _download(self):
    d2l.extract(d2l.download(
      d2l.DATA_URL+'fra-eng.zip', self.root,
      '94646ad1522d915e7b0f9296181140edcf86a4f5'))
    with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f:
      return f.read()

data = MTFraEng()
raw_text = data._download()
print(raw_text[:75])

Go. Va !
Hi. Salut !
Run!    Cours?!
Run!    Courez?!
Who?    Qui ?
Wow!    ?a alors?!

class MTFraEng(d2l.DataModule): #@save
  """The English-French dataset."""
  def _download(self):
    d2l.extract(d2l.download(
      d2l.DATA_URL+'fra-eng.zip', self.root,
      '94646ad1522d915e7b0f9296181140edcf86a4f5'))
    with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f:
      return f.read()

data = MTFraEng()
raw_text = data._download()
print(raw_text[:75])

Go. Va !
Hi. Salut !
Run!    Cours?!
Run!    Courez?!
Who?    Qui ?
Wow!    ?a alors?!

下載數(shù)據(jù)集后,我們對原始文本數(shù)據(jù)進(jìn)行幾個(gè)預(yù)處理步驟。例如,我們將不間斷空格替換為空格,將大寫字母轉(zhuǎn)換為小寫字母,在單詞和標(biāo)點(diǎn)符號(hào)之間插入空格。

@d2l.add_to_class(MTFraEng) #@save
def _preprocess(self, text):
  # Replace non-breaking space with space
  text = text.replace('u202f', ' ').replace('xa0', ' ')
  # Insert space between words and punctuation marks
  no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' '
  out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
      for i, char in enumerate(text.lower())]
  return ''.join(out)

text = data._preprocess(raw_text)
print(text[:80])

go .    va !
hi .    salut !
run !    cours !
run !    courez !
who ?    qui ?
wow !    ?a alors !

@d2l.add_to_class(MTFraEng) #@save
def _preprocess(self, text):
  # Replace non-breaking space with space
  text = text.replace('u202f', ' ').replace('xa0', ' ')
  # Insert space between words and punctuation marks
  no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' '
  out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
      for i, char in enumerate(text.lower())]
  return ''.join(out)

text = data._preprocess(raw_text)
print(text[:80])

go .    va !
hi .    salut !
run !    cours !
run !    courez !
who ?    qui ?
wow !    ?a alors !

@d2l.add_to_class(MTFraEng) #@save
def _preprocess(self, text):
  # Replace non-breaking space with space
  text = text.replace('u202f', ' ').replace('xa0', ' ')
  # Insert space between words and punctuation marks
  no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' '
  out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
      for i, char in enumerate(text.lower())]
  return ''.join(out)

text = data._preprocess(raw_text)
print(text[:80])

go .    va !
hi .    salut !
run !    cours !
run !    courez !
who ?    qui ?
wow !    ?a alors !

@d2l.add_to_class(MTFraEng) #@save
def _preprocess(self, text):
  # Replace non-breaking space with space
  text = text.replace('u202f', ' ').replace('xa0', ' ')
  # Insert space between words and punctuation marks
  no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' '
  out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
      for i, char in enumerate(text.lower())]
  return ''.join(out)

text = data._preprocess(raw_text)
print(text[:80])

go .    va !
hi .    salut !
run !    cours !
run !    courez !
who ?    qui ?
wow !    ?a alors !

10.5.2。代幣化

與第 9.3 節(jié)中的字符級標(biāo)記化不同 ,對于機(jī)器翻譯,我們在這里更喜歡單詞級標(biāo)記化(當(dāng)今最先進(jìn)的模型使用更復(fù)雜的標(biāo)記化技術(shù))。以下_tokenize方法對第一個(gè)max_examples文本序列對進(jìn)行分詞,其中每個(gè)分詞要么是一個(gè)單詞,要么是一個(gè)標(biāo)點(diǎn)符號(hào)。我們將特殊的“”標(biāo)記附加到每個(gè)序列的末尾,以指示序列的結(jié)束。當(dāng)模型通過生成一個(gè)接一個(gè)標(biāo)記的序列標(biāo)記進(jìn)行預(yù)測時(shí),“”標(biāo)記的生成表明輸出序列是完整的。最后,下面的方法返回兩個(gè)令牌列表列表:src和tgt。具體來說,src[i]是來自ith源語言(此處為英語)的文本序列和tgt[i]目標(biāo)語言(此處為法語)的文本序列。

@d2l.add_to_class(MTFraEng) #@save
def _tokenize(self, text, max_examples=None):
  src, tgt = [], []
  for i, line in enumerate(text.split('n')):
    if max_examples and i > max_examples: break
    parts = line.split('t')
    if len(parts) == 2:
      # Skip empty tokens
      src.append([t for t in f'{parts[0]} '.split(' ') if t])
      tgt.append([t for t in f'{parts[1]} '.split(' ') if t])
  return src, tgt

src, tgt = data._tokenize(text)
src[:6], tgt[:6]

([['go', '.', ''],
 ['hi', '.', ''],
 ['run', '!', ''],
 ['run', '!', ''],
 ['who', '?', ''],
 ['wow', '!', '']],
 [['va', '!', ''],
 ['salut', '!', ''],
 ['cours', '!', ''],
 ['courez', '!', ''],
 ['qui', '?', ''],
 ['?a', 'alors', '!', '']])

@d2l.add_to_class(MTFraEng) #@save
def _tokenize(self, text, max_examples=None):
  src, tgt = [], []
  for i, line in enumerate(text.split('n')):
    if max_examples and i > max_examples: break
    parts = line.split('t')
    if len(parts) == 2:
      # Skip empty tokens
      src.append([t for t in f'{parts[0]} '.split(' ') if t])
      tgt.append([t for t in f'{parts[1]} '.split(' ') if t])
  return src, tgt

src, tgt = data._tokenize(text)
src[:6], tgt[:6]

([['go', '.', ''],
 ['hi', '.', ''],
 ['run', '!', ''],
 ['run', '!', ''],
 ['who', '?', ''],
 ['wow', '!', '']],
 [['va', '!', ''],
 ['salut', '!', ''],
 ['cours', '!', ''],
 ['courez', '!', ''],
 ['qui', '?', ''],
 ['?a', 'alors', '!', '']])

@d2l.add_to_class(MTFraEng) #@save
def _tokenize(self, text, max_examples=None):
  src, tgt = [], []
  for i, line in enumerate(text.split('n')):
    if max_examples and i > max_examples: break
    parts = line.split('t')
    if len(parts) == 2:
      # Skip empty tokens
      src.append([t for t in f'{parts[0]} '.split(' ') if t])
      tgt.append([t for t in f'{parts[1]} '.split(' ') if t])
  return src, tgt

src, tgt = data._tokenize(text)
src[:6], tgt[:6]

([['go', '.', ''],
 ['hi', '.', ''],
 ['run', '!', ''],
 ['run', '!', ''],
 ['who', '?', ''],
 ['wow', '!', '']],
 [['va', '!', ''],
 ['salut', '!', ''],
 ['cours', '!', ''],
 ['courez', '!', ''],
 ['qui', '?', ''],
 ['?a', 'alors', '!', '']])

@d2l.add_to_class(MTFraEng) #@save
def _tokenize(self, text, max_examples=None):
  src, tgt = [], []
  for i, line in enumerate(text.split('n')):
    if max_examples and i > max_examples: break
    parts = line.split('t')
    if len(parts) == 2:
      # Skip empty tokens
      src.append([t for t in f'{parts[0]} '.split(' ') if t])
      tgt.append([t for t in f'{parts[1]} '.split(' ') if t])
  return src, tgt

src, tgt = data._tokenize(text)
src[:6], tgt[:6]

([['go', '.', ''],
 ['hi', '.', ''],
 ['run', '!', ''],
 ['run', '!', ''],
 ['who', '?', ''],
 ['wow', '!', '']],
 [['va', '!', ''],
 ['salut', '!', ''],
 ['cours', '!', ''],
 ['courez', '!', ''],
 ['qui', '?', ''],
 ['?a', 'alors', '!', '']])

讓我們繪制每個(gè)文本序列的標(biāo)記數(shù)量的直方圖。在這個(gè)簡單的英法數(shù)據(jù)集中,大多數(shù)文本序列的標(biāo)記少于 20 個(gè)。

#@save
def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):
  """Plot the histogram for list length pairs."""
  d2l.set_figsize()
  _, _, patches = d2l.plt.hist(
    [[len(l) for l in xlist], [len(l) for l in ylist]])
  d2l.plt.xlabel(xlabel)
  d2l.plt.ylabel(ylabel)
  for patch in patches[1].patches:
    patch.set_hatch('/')
  d2l.plt.legend(legend)

show_list_len_pair_hist(['source', 'target'], '# tokens per sequence',
            'count', src, tgt);

poYBAGR9N2iABw1YAAF3AJUP1Cw291.svg

#@save
def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):
  """Plot the histogram for list length pairs."""
  d2l.set_figsize()
  _, _, patches = d2l.plt.hist(
    [[len(l) for l in xlist], [len(l) for l in ylist]])
  d2l.plt.xlabel(xlabel)
  d2l.plt.ylabel(ylabel)
  for patch in patches[1].patches:
    patch.set_hatch('/')
  d2l.plt.legend(legend)

show_list_len_pair_hist(['source', 'target'], '# tokens per sequence',
            'count', src, tgt);

poYBAGR9N2iABw1YAAF3AJUP1Cw291.svg

#@save
def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):
  """Plot the histogram for list length pairs."""
  d2l.set_figsize()
  _, _, patches = d2l.plt.hist(
    [[len(l) for l in xlist], [len(l) for l in ylist]])
  d2l.plt.xlabel(xlabel)
  d2l.plt.ylabel(ylabel)
  for patch in patches[1].patches:
    patch.set_hatch('/')
  d2l.plt.legend(legend)

show_list_len_pair_hist(['source', 'target'], '# tokens per sequence',
            'count', src, tgt);

poYBAGR9N2iABw1YAAF3AJUP1Cw291.svg

#@save
def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):
  """Plot the histogram for list length pairs."""
  d2l.set_figsize()
  _, _, patches = d2l.plt.hist(
    [[len(l) for l in xlist], [len(l) for l in ylist]])
  d2l.plt.xlabel(xlabel)
  d2l.plt.ylabel(ylabel)
  for patch in patches[1].patches:
    patch.set_hatch('/')
  d2l.plt.legend(legend)

show_list_len_pair_hist(['source', 'target'], '# tokens per sequence',
            'count', src, tgt);

poYBAGR9N2iABw1YAAF3AJUP1Cw291.svg

10.5.3。固定長度的加載序列

回想一下,在語言建模中,每個(gè)示例序列(一個(gè)句子的一部分或多個(gè)句子的跨度)都有固定的長度。這是由第 9.3 節(jié)num_steps中的(時(shí)間步數(shù)或標(biāo)記數(shù))參數(shù)指定的。在機(jī)器翻譯中,每個(gè)示例都是一對源文本序列和目標(biāo)文本序列,其中這兩個(gè)文本序列可能具有不同的長度。

為了計(jì)算效率,我們?nèi)匀豢梢酝ㄟ^截?cái)嗪吞畛湟淮翁幚硪恍∨谋拘蛄小<僭O(shè)同一個(gè)小批量中的每個(gè)序列都應(yīng)該具有相同的長度 num_steps。如果文本序列少于num_steps標(biāo)記,我們將繼續(xù)在其末尾附加特殊的“”標(biāo)記,直到其長度達(dá)到num_steps. num_steps否則,我們將通過僅獲取其第一個(gè)標(biāo)記并丟棄其余標(biāo)記來截?cái)辔谋拘蛄?。這樣,每個(gè)文本序列將具有相同的長度,以相同形狀的小批量加載。此外,我們還記錄了不包括填充標(biāo)記的源序列長度。我們稍后將介紹的某些模型將需要此信息。

由于機(jī)器翻譯數(shù)據(jù)集由多對語言組成,我們可以分別為源語言和目標(biāo)語言構(gòu)建兩個(gè)詞匯表。使用詞級標(biāo)記化,詞匯量將明顯大于使用字符級標(biāo)記化的詞匯量。為了減輕這一點(diǎn),在這里,我們將看上去不到2倍相同的未知數(shù)(“”)令牌視為罕見的令牌。正如我們稍后將解釋的(圖 10.7.1),當(dāng)使用目標(biāo)序列進(jìn)行訓(xùn)練時(shí),解碼器輸出(標(biāo)簽標(biāo)記)可以是相同的解碼器輸入(目標(biāo)標(biāo)記),移動(dòng)一個(gè)標(biāo)記;特殊的序列開頭“”標(biāo)記將用作預(yù)測目標(biāo)序列的第一個(gè)輸入標(biāo)記(圖 10.7.3)。

@d2l.add_to_class(MTFraEng) #@save
def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128):
  super(MTFraEng, self).__init__()
  self.save_hyperparameters()
  self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays(
    self._download())

@d2l.add_to_class(MTFraEng) #@save
def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None):
  def _build_array(sentences, vocab, is_tgt=False):
    pad_or_trim = lambda seq, t: (
      seq[:t] if len(seq) > t else seq + [''] * (t - len(seq)))
    sentences = [pad_or_trim(s, self.num_steps) for s in sentences]
    if is_tgt:
      sentences = [[''] + s for s in sentences]
    if vocab is None:
      vocab = d2l.Vocab(sentences, min_freq=2)
    array = torch.tensor([vocab[s] for s in sentences])
    valid_len = (array != vocab['']).type(torch.int32).sum(1)
    return array, vocab, valid_len
  src, tgt = self._tokenize(self._preprocess(raw_text),
               self.num_train + self.num_val)
  src_array, src_vocab, src_valid_len = _build_array(src, src_vocab)
  tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True)
  return ((src_array, tgt_array[:,:-1], src_valid_len, tgt_array[:,1:]),
      src_vocab, tgt_vocab)

@d2l.add_to_class(MTFraEng) #@save
def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128):
  super(MTFraEng, self).__init__()
  self.save_hyperparameters()
  self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays(
    self._download())

@d2l.add_to_class(MTFraEng) #@save
def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None):
  def _build_array(sentences, vocab, is_tgt=False):
    pad_or_trim = lambda seq, t: (
      seq[:t] if len(seq) > t else seq + [''] * (t - len(seq)))
    sentences = [pad_or_trim(s, self.num_steps) for s in sentences]
    if is_tgt:
      sentences = [[''] + s for s in sentences]
    if vocab is None:
      vocab = d2l.Vocab(sentences, min_freq=2)
    array = np.array([vocab[s] for s in sentences])
    valid_len = (array != vocab['']).astype(np.int32).sum(1)
    return array, vocab, valid_len
  src, tgt = self._tokenize(self._preprocess(raw_text),
               self.num_train + self.num_val)
  src_array, src_vocab, src_valid_len = _build_array(src, src_vocab)
  tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True)
  return ((src_array, tgt_array[:,:-1], src_valid_len, tgt_array[:,1:]),
      src_vocab, tgt_vocab)

@d2l.add_to_class(MTFraEng) #@save
def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128):
  super(MTFraEng, self).__init__()
  self.save_hyperparameters()
  self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays(
    self._download())

@d2l.add_to_class(MTFraEng) #@save
def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None):
  def _build_array(sentences, vocab, is_tgt=False):
    pad_or_trim = lambda seq, t: (
      seq[:t] if len(seq) > t else seq + [''] * (t - len(seq)))
    sentences = [pad_or_trim(s, self.num_steps) for s in sentences]
    if is_tgt:
      sentences = [[''] + s for s in sentences]
    if vocab is None:
      vocab = d2l.Vocab(sentences, min_freq=2)
    array = jnp.array([vocab[s] for s in sentences])
    valid_len = (array != vocab['']).astype(jnp.int32).sum(1)
    return array, vocab, valid_len
  src, tgt = self._tokenize(self._preprocess(raw_text),
               self.num_train + self.num_val)
  src_array, src_vocab, src_valid_len = _build_array(src, src_vocab)
  tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True)
  return ((src_array, tgt_array[:,:-1], src_valid_len, tgt_array[:,1:]),
      src_vocab, tgt_vocab)

@d2l.add_to_class(MTFraEng) #@save
def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128):
  super(MTFraEng, self).__init__()
  self.save_hyperparameters()
  self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays(
    self._download())

@d2l.add_to_class(MTFraEng) #@save
def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None):
  def _build_array(sentences, vocab, is_tgt=False):
    pad_or_trim = lambda seq, t: (
      seq[:t] if len(seq) > t else seq + [''] * (t - len(seq)))
    sentences = [pad_or_trim(s, self.num_steps) for s in sentences]
    if is_tgt:
      sentences = [[''] + s for s in sentences]
    if vocab is None:
      vocab = d2l.Vocab(sentences, min_freq=2)
    array = tf.constant([vocab[s] for s in sentences])
    valid_len = tf.reduce_sum(
      tf.cast(array != vocab[''], tf.int32), 1)
    return array, vocab, valid_len
  src, tgt = self._tokenize(self._preprocess(raw_text),
               self.num_train + self.num_val)
  src_array, src_vocab, src_valid_len = _build_array(src, src_vocab)
  tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True)
  return ((src_array, tgt_array[:,:-1], src_valid_len, tgt_array[:,1:]),
      src_vocab, tgt_vocab)

10.5.4。讀取數(shù)據(jù)集

最后,我們定義get_dataloader返回?cái)?shù)據(jù)迭代器的方法。

@d2l.add_to_class(MTFraEng) #@save
def get_dataloader(self, train):
  idx = slice(0, self.num_train) if train else slice(self.num_train, None)
  return self.get_tensorloader(self.arrays, train, idx)

讓我們從英法數(shù)據(jù)集中讀取第一個(gè)小批量。

data = MTFraEng(batch_size=3)
src, tgt, src_valid_len, label = next(iter(data.train_dataloader()))
print('source:', src.type(torch.int32))
print('decoder input:', tgt.type(torch.int32))
print('source len excluding pad:', src_valid_len.type(torch.int32))
print('label:', label.type(torch.int32))

source: tensor([[ 83, 174,  2,  3,  4,  4,  4,  4,  4],
    [ 84, 32, 91,  2,  3,  4,  4,  4,  4],
    [144, 174,  0,  3,  4,  4,  4,  4,  4]], dtype=torch.int32)
decoder input: tensor([[ 3,  6,  0,  4,  5,  5,  5,  5,  5],
    [ 3, 108, 112, 84,  2,  4,  5,  5,  5],
    [ 3, 87,  0,  4,  5,  5,  5,  5,  5]], dtype=torch.int32)
source len excluding pad: tensor([4, 5, 4], dtype=torch.int32)
label: tensor([[ 6,  0,  4,  5,  5,  5,  5,  5,  5],
    [108, 112, 84,  2,  4,  5,  5,  5,  5],
    [ 87,  0,  4,  5,  5,  5,  5,  5,  5]], dtype=torch.int32)

data = MTFraEng(batch_size=3)
src, tgt, src_valid_len, label = next(iter(data.train_dataloader()))
print('source:', src.astype(np.int32))
print('decoder input:', tgt.astype(np.int32))
print('source len excluding pad:', src_valid_len.astype(np.int32))
print('label:', label.astype(np.int32))

source: [[84 5 2 3 4 4 4 4 4]
 [84 5 2 3 4 4 4 4 4]
 [59 9 2 3 4 4 4 4 4]]
decoder input: [[ 3 108  6  2  4  5  5  5  5]
 [ 3 105  6  2  4  5  5  5  5]
 [ 3 203  0  4  5  5  5  5  5]]
source len excluding pad: [4 4 4]
label: [[108  6  2  4  5  5  5  5  5]
 [105  6  2  4  5  5  5  5  5]
 [203  0  4  5  5  5  5  5  5]]

data = MTFraEng(batch_size=3)
src, tgt, src_valid_len, label = next(iter(data.train_dataloader()))
print('source:', src.astype(jnp.int32))
print('decoder input:', tgt.astype(jnp.int32))
print('source len excluding pad:', src_valid_len.astype(jnp.int32))
print('label:', label.astype(jnp.int32))

source: [[ 86 43  2  3  4  4  4  4  4]
 [176 165  2  3  4  4  4  4  4]
 [143 111  2  3  4  4  4  4  4]]
decoder input: [[ 3 108 183 98  2  4  5  5  5]
 [ 3  6 42  0  4  5  5  5  5]
 [ 3  6  0  4  5  5  5  5  5]]
source len excluding pad: [4 4 4]
label: [[108 183 98  2  4  5  5  5  5]
 [ 6 42  0  4  5  5  5  5  5]
 [ 6  0  4  5  5  5  5  5  5]]

data = MTFraEng(batch_size=3)
src, tgt, src_valid_len, label = next(iter(data.train_dataloader()))
print('source:', tf.cast(src, tf.int32))
print('decoder input:', tf.cast(tgt, tf.int32))
print('source len excluding pad:', tf.cast(src_valid_len, tf.int32))
print('label:', tf.cast(label, tf.int32))

source: tf.Tensor(
[[ 92 115  2  3  4  4  4  4  4]
 [155 168  2  3  4  4  4  4  4]
 [ 86 121  2  3  4  4  4  4  4]], shape=(3, 9), dtype=int32)
decoder input: tf.Tensor(
[[ 3 37  6  2  4  5  5  5  5]
 [ 3  6 192  2  4  5  5  5  5]
 [ 3 108 202 30  2  4  5  5  5]], shape=(3, 9), dtype=int32)
source len excluding pad: tf.Tensor([4 4 4], shape=(3,), dtype=int32)
label: tf.Tensor(
[[ 37  6  2  4  5  5  5  5  5]
 [ 6 192  2  4  5  5  5  5  5]
 [108 202 30  2  4  5  5  5  5]], shape=(3, 9), dtype=int32)

下面我們展示了一對由上述方法處理的源序列和目標(biāo)序列_build_arrays(以字符串格式)。

@d2l.add_to_class(MTFraEng) #@save
def build(self, src_sentences, tgt_sentences):
  raw_text = 'n'.join([src + 't' + tgt for src, tgt in zip(
    src_sentences, tgt_sentences)])
  arrays, _, _ = self._build_arrays(
    raw_text, self.src_vocab, self.tgt_vocab)
  return arrays

src, tgt, _, _ = data.build(['hi .'], ['salut .'])
print('source:', data.src_vocab.to_tokens(src[0].type(torch.int32)))
print('target:', data.tgt_vocab.to_tokens(tgt[0].type(torch.int32)))

source: ['hi', '.', '', '', '', '', '', '', '']
target: ['', 'salut', '.', '', '', '', '', '', '']

@d2l.add_to_class(MTFraEng) #@save
def build(self, src_sentences, tgt_sentences):
  raw_text = 'n'.join([src + 't' + tgt for src, tgt in zip(
    src_sentences, tgt_sentences)])
  arrays, _, _ = self._build_arrays(
    raw_text, self.src_vocab, self.tgt_vocab)
  return arrays

src, tgt, _, _ = data.build(['hi .'], ['salut .'])
print('source:', data.src_vocab.to_tokens(src[0].astype(np.int32)))
print('target:', data.tgt_vocab.to_tokens(tgt[0].astype(np.int32)))

source: ['hi', '.', '', '', '', '', '', '', '']
target: ['', 'salut', '.', '', '', '', '', '', '']

@d2l.add_to_class(MTFraEng) #@save
def build(self, src_sentences, tgt_sentences):
  raw_text = 'n'.join([src + 't' + tgt for src, tgt in zip(
    src_sentences, tgt_sentences)])
  arrays, _, _ = self._build_arrays(
    raw_text, self.src_vocab, self.tgt_vocab)
  return arrays

src, tgt, _, _ = data.build(['hi .'], ['salut .'])
print('source:', data.src_vocab.to_tokens(src[0].astype(jnp.int32)))
print('target:', data.tgt_vocab.to_tokens(tgt[0].astype(jnp.int32)))

source: ['hi', '.', '', '', '', '', '', '', '']
target: ['', 'salut', '.', '', '', '', '', '', '']

@d2l.add_to_class(MTFraEng) #@save
def build(self, src_sentences, tgt_sentences):
  raw_text = 'n'.join([src + 't' + tgt for src, tgt in zip(
    src_sentences, tgt_sentences)])
  arrays, _, _ = self._build_arrays(
    raw_text, self.src_vocab, self.tgt_vocab)
  return arrays

src, tgt, _, _ = data.build(['hi .'], ['salut .'])
print('source:', data.src_vocab.to_tokens(tf.cast(src[0], tf.int32)))
print('target:', data.tgt_vocab.to_tokens(tf.cast(tgt[0], tf.int32)))

source: ['hi', '.', '', '', '', '', '', '', '']
target: ['', 'salut', '.', '', '', '', '', '', '']

10.5.5。概括

在自然語言處理中,機(jī)器翻譯是指將代表源語言文本字符串的序列自動(dòng)映射到代表目標(biāo)語言中合理翻譯的字符串的任務(wù)。使用詞級標(biāo)記化,詞匯量將明顯大于使用字符級標(biāo)記化,但序列長度會(huì)短得多。為了減輕大詞匯量,我們可以將不常見的標(biāo)記視為一些“未知”標(biāo)記。我們可以截?cái)嗪吞畛湮谋拘蛄?,以便它們都具有相同的長度以加載到小批量中。現(xiàn)代實(shí)現(xiàn)通常將具有相似長度的序列存儲(chǔ)起來,以避免在填充上浪費(fèi)過多的計(jì)算。

10.5.6。練習(xí)

max_examples在 方法中嘗試不同的參數(shù)值_tokenize。這如何影響源語言和目標(biāo)語言的詞匯量?

某些語言(例如中文和日文)中的文本沒有字界指示符(例如,空格)。對于這種情況,詞級標(biāo)記化仍然是一個(gè)好主意嗎?為什么或者為什么不?

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報(bào)投訴
  • 數(shù)據(jù)集
    +關(guān)注

    關(guān)注

    4

    文章

    1209

    瀏覽量

    24772
  • pytorch
    +關(guān)注

    關(guān)注

    2

    文章

    808

    瀏覽量

    13314
收藏 人收藏

    評論

    相關(guān)推薦

    機(jī)器翻譯三大核心技術(shù)原理 | AI知識(shí)科普

    ,應(yīng)用場景越多,需要的規(guī)則也越來越多,規(guī)則之間的沖突也逐漸出現(xiàn)。于是很多科研學(xué)家開始思考,是否能讓機(jī)器自動(dòng)從數(shù)據(jù)庫里學(xué)習(xí)相應(yīng)的規(guī)則,1993年IBM提出基于詞的統(tǒng)計(jì)翻譯模型標(biāo)志著第二代機(jī)器翻譯
    發(fā)表于 07-06 10:30

    神經(jīng)機(jī)器翻譯的方法有哪些?

    目前,神經(jīng)機(jī)器翻譯(NMT)已經(jīng)成為在學(xué)術(shù)界和工業(yè)界最先進(jìn)的機(jī)器翻譯方法。最初的這種基于編碼器-解碼器架構(gòu)的機(jī)器翻譯系統(tǒng)都針對單個(gè)語言對進(jìn)行翻譯。近期的工作開始探索去擴(kuò)展這種辦法以支持
    發(fā)表于 11-23 12:14

    從冷戰(zhàn)到深度學(xué)習(xí),機(jī)器翻譯歷史不簡單!

    深度學(xué)習(xí)機(jī)器翻譯 實(shí)現(xiàn)高質(zhì)量機(jī)器翻譯的夢想已經(jīng)存在了很多年,很多科學(xué)家都為這一夢想貢獻(xiàn)了自己的時(shí)間和心力。從早期的基于規(guī)則的機(jī)器翻譯到如今廣泛應(yīng)用的神經(jīng)機(jī)器翻譯,
    發(fā)表于 09-17 09:23 ?433次閱讀

    機(jī)器翻譯的真實(shí)水平如何,夢想與現(xiàn)實(shí)的距離到底有多遠(yuǎn)?

    近年來,隨著計(jì)算機(jī)性能的提高,云計(jì)算、大數(shù)據(jù)機(jī)器學(xué)習(xí)等相關(guān)技術(shù)迅速發(fā)展,人工智能再度崛起,機(jī)器翻譯重新成為人們關(guān)注的焦點(diǎn)。一時(shí)間,機(jī)器翻譯系統(tǒng)如雨后春筍般涌現(xiàn),各種報(bào)道隨之呈井噴式爆
    的頭像 發(fā)表于 03-22 14:08 ?4143次閱讀

    換個(gè)角度來聊機(jī)器翻譯

    同時(shí)期國內(nèi)科技企業(yè)在機(jī)器翻譯上的進(jìn)展也非常迅速,以語音和語義理解見長的科大訊飛在2014年國際口語翻譯大賽IWSLT上獲得中英和英中兩個(gè)翻譯方向的全球第一名,在2015年又在由美國國家標(biāo)準(zhǔn)技術(shù)研究院組織的
    的頭像 發(fā)表于 04-24 13:55 ?3518次閱讀
    換個(gè)角度來聊<b class='flag-5'>機(jī)器翻譯</b>

    機(jī)器翻譯走紅的背后是什么

    未來需要新的算法和語義層面的綜合性突破,促進(jìn)機(jī)器翻譯產(chǎn)品的迭代和產(chǎn)業(yè)全面升級。
    發(fā)表于 07-14 10:02 ?1034次閱讀

    未來機(jī)器翻譯會(huì)取代人工翻譯

    所謂機(jī)器翻譯,就是利用計(jì)算機(jī)將一種自然語言(源語言)轉(zhuǎn)換為另一種自然語言(目標(biāo)語言)的過程。它是計(jì)算語言學(xué)的一個(gè)分支,是人工智能的終極目標(biāo)之一,具有重要的科學(xué)研究價(jià)值。而且機(jī)器翻譯又具有重要
    的頭像 發(fā)表于 12-29 10:12 ?5059次閱讀

    基于句子級上下文的神經(jīng)機(jī)器翻譯綜述

    基于句子級上下文的神經(jīng)機(jī)器翻譯綜述
    發(fā)表于 06-29 16:26 ?64次下載

    Google遵循AI原則減少機(jī)器翻譯的性別偏見

    得益于神經(jīng)機(jī)器翻譯 (NMT) 的進(jìn)步,譯文更加自然流暢,但與此同時(shí),這些譯文也反映出訓(xùn)練數(shù)據(jù)存在社會(huì)偏見和刻板印象。因此,Google 持續(xù)致力于遵循 AI 原則,開發(fā)創(chuàng)新技術(shù),減少機(jī)器翻譯
    的頭像 發(fā)表于 08-24 10:14 ?2882次閱讀

    PyTorch教程10.5機(jī)器翻譯和數(shù)據(jù)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程10.5機(jī)器翻譯和數(shù)據(jù).pdf》資料免費(fèi)下載
    發(fā)表于 06-05 15:14 ?0次下載
    <b class='flag-5'>PyTorch</b>教程<b class='flag-5'>10.5</b>之<b class='flag-5'>機(jī)器翻譯</b><b class='flag-5'>和數(shù)據(jù)</b><b class='flag-5'>集</b>

    PyTorch教程10.7之用于機(jī)器翻譯的編碼器-解碼器Seq2Seq

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程10.7之用于機(jī)器翻譯的編碼器-解碼器Seq2Seq.pdf》資料免費(fèi)下載
    發(fā)表于 06-05 18:14 ?0次下載
    <b class='flag-5'>PyTorch</b>教程10.7之用于<b class='flag-5'>機(jī)器翻譯</b>的編碼器-解碼器Seq2Seq

    PyTorch教程14.9之語義分割和數(shù)據(jù)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程14.9之語義分割和數(shù)據(jù).pdf》資料免費(fèi)下載
    發(fā)表于 06-05 11:10 ?0次下載
    <b class='flag-5'>PyTorch</b>教程14.9之語義分割<b class='flag-5'>和數(shù)據(jù)</b><b class='flag-5'>集</b>

    PyTorch教程16.1之情緒分析和數(shù)據(jù)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程16.1之情緒分析和數(shù)據(jù).pdf》資料免費(fèi)下載
    發(fā)表于 06-05 10:54 ?0次下載
    <b class='flag-5'>PyTorch</b>教程16.1之情緒分析<b class='flag-5'>和數(shù)據(jù)</b><b class='flag-5'>集</b>

    PyTorch教程16.4之自然語言推理和數(shù)據(jù)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程16.4之自然語言推理和數(shù)據(jù).pdf》資料免費(fèi)下載
    發(fā)表于 06-05 10:57 ?0次下載
    <b class='flag-5'>PyTorch</b>教程16.4之自然語言推理<b class='flag-5'>和數(shù)據(jù)</b><b class='flag-5'>集</b>

    機(jī)器翻譯研究進(jìn)展

    成為主流,如神經(jīng)網(wǎng)絡(luò)機(jī)器翻譯。神經(jīng)網(wǎng)絡(luò)機(jī)器翻譯機(jī)器從大量數(shù)據(jù)中自動(dòng)學(xué)習(xí)翻譯知識(shí),而不依靠人類專家撰寫規(guī)則,可以顯著提升
    的頭像 發(fā)表于 07-06 11:19 ?868次閱讀
    <b class='flag-5'>機(jī)器翻譯</b>研究進(jìn)展