在引起人們對現(xiàn)代 RNN 廣泛興趣的重大突破中,有一項(xiàng)是統(tǒng)計(jì)機(jī)器翻譯應(yīng)用領(lǐng)域的重大進(jìn)展 。在這里,模型以一種語言的句子呈現(xiàn),并且必須預(yù)測另一種語言的相應(yīng)句子。請注意,由于兩種語言的語法結(jié)構(gòu)不同,這里的句子可能有不同的長度,并且兩個(gè)句子中相應(yīng)的詞可能不會(huì)以相同的順序出現(xiàn)。
許多問題都具有這種在兩個(gè)這樣的“未對齊”序列之間進(jìn)行映射的風(fēng)格。示例包括從對話提示到回復(fù)或從問題到答案的映射。廣義上,此類問題稱為 序列到序列(seq2seq) 問題,它們是本章剩余部分和 第 11 節(jié)大部分內(nèi)容的重點(diǎn)。
在本節(jié)中,我們將介紹機(jī)器翻譯問題和我們將在后續(xù)示例中使用的示例數(shù)據(jù)集。幾十年來,語言間翻譯的統(tǒng)計(jì)公式一直很流行 (Brown等人,1990 年,Brown等人,1988 年),甚至在研究人員使神經(jīng)網(wǎng)絡(luò)方法起作用之前(這些方法通常被統(tǒng)稱為神經(jīng)機(jī)器翻譯)。
首先,我們需要一些新代碼來處理我們的數(shù)據(jù)。與我們在9.3 節(jié)中看到的語言建模不同,這里的每個(gè)示例都包含兩個(gè)單獨(dú)的文本序列,一個(gè)是源語言,另一個(gè)(翻譯)是目標(biāo)語言。以下代碼片段將展示如何將預(yù)處理后的數(shù)據(jù)加載到小批量中進(jìn)行訓(xùn)練。
import os import torch from d2l import torch as d2l
import os from mxnet import np, npx from d2l import mxnet as d2l npx.set_np()
import os from jax import numpy as jnp from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import os import tensorflow as tf from d2l import tensorflow as d2l
10.5.1。下載和預(yù)處理數(shù)據(jù)集
首先,我們 從 Tatoeba Project 下載由雙語句子對組成的英法數(shù)據(jù)集。數(shù)據(jù)集中的每一行都是一個(gè)制表符分隔的對,由一個(gè)英文文本序列和翻譯后的法文文本序列組成。請注意,每個(gè)文本序列可以只是一個(gè)句子,也可以是一段多句。在這個(gè)英語翻譯成法語的機(jī)器翻譯問題中,英語被稱為源語言,法語被稱為目標(biāo)語言。
class MTFraEng(d2l.DataModule): #@save """The English-French dataset.""" def _download(self): d2l.extract(d2l.download( d2l.DATA_URL+'fra-eng.zip', self.root, '94646ad1522d915e7b0f9296181140edcf86a4f5')) with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f: return f.read() data = MTFraEng() raw_text = data._download() print(raw_text[:75])
Downloading ../data/fra-eng.zip from http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip... Go. Va ! Hi. Salut ! Run! Cours?! Run! Courez?! Who? Qui ? Wow! ?a alors?!
class MTFraEng(d2l.DataModule): #@save """The English-French dataset.""" def _download(self): d2l.extract(d2l.download( d2l.DATA_URL+'fra-eng.zip', self.root, '94646ad1522d915e7b0f9296181140edcf86a4f5')) with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f: return f.read() data = MTFraEng() raw_text = data._download() print(raw_text[:75])
Go. Va ! Hi. Salut ! Run! Cours?! Run! Courez?! Who? Qui ? Wow! ?a alors?!
class MTFraEng(d2l.DataModule): #@save """The English-French dataset.""" def _download(self): d2l.extract(d2l.download( d2l.DATA_URL+'fra-eng.zip', self.root, '94646ad1522d915e7b0f9296181140edcf86a4f5')) with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f: return f.read() data = MTFraEng() raw_text = data._download() print(raw_text[:75])
Go. Va ! Hi. Salut ! Run! Cours?! Run! Courez?! Who? Qui ? Wow! ?a alors?!
class MTFraEng(d2l.DataModule): #@save """The English-French dataset.""" def _download(self): d2l.extract(d2l.download( d2l.DATA_URL+'fra-eng.zip', self.root, '94646ad1522d915e7b0f9296181140edcf86a4f5')) with open(self.root + '/fra-eng/fra.txt', encoding='utf-8') as f: return f.read() data = MTFraEng() raw_text = data._download() print(raw_text[:75])
Go. Va ! Hi. Salut ! Run! Cours?! Run! Courez?! Who? Qui ? Wow! ?a alors?!
下載數(shù)據(jù)集后,我們對原始文本數(shù)據(jù)進(jìn)行幾個(gè)預(yù)處理步驟。例如,我們將不間斷空格替換為空格,將大寫字母轉(zhuǎn)換為小寫字母,在單詞和標(biāo)點(diǎn)符號(hào)之間插入空格。
@d2l.add_to_class(MTFraEng) #@save def _preprocess(self, text): # Replace non-breaking space with space text = text.replace('u202f', ' ').replace('xa0', ' ') # Insert space between words and punctuation marks no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' ' out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char for i, char in enumerate(text.lower())] return ''.join(out) text = data._preprocess(raw_text) print(text[:80])
go . va ! hi . salut ! run ! cours ! run ! courez ! who ? qui ? wow ! ?a alors !
@d2l.add_to_class(MTFraEng) #@save def _preprocess(self, text): # Replace non-breaking space with space text = text.replace('u202f', ' ').replace('xa0', ' ') # Insert space between words and punctuation marks no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' ' out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char for i, char in enumerate(text.lower())] return ''.join(out) text = data._preprocess(raw_text) print(text[:80])
go . va ! hi . salut ! run ! cours ! run ! courez ! who ? qui ? wow ! ?a alors !
@d2l.add_to_class(MTFraEng) #@save def _preprocess(self, text): # Replace non-breaking space with space text = text.replace('u202f', ' ').replace('xa0', ' ') # Insert space between words and punctuation marks no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' ' out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char for i, char in enumerate(text.lower())] return ''.join(out) text = data._preprocess(raw_text) print(text[:80])
go . va ! hi . salut ! run ! cours ! run ! courez ! who ? qui ? wow ! ?a alors !
@d2l.add_to_class(MTFraEng) #@save def _preprocess(self, text): # Replace non-breaking space with space text = text.replace('u202f', ' ').replace('xa0', ' ') # Insert space between words and punctuation marks no_space = lambda char, prev_char: char in ',.!?' and prev_char != ' ' out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char for i, char in enumerate(text.lower())] return ''.join(out) text = data._preprocess(raw_text) print(text[:80])
go . va ! hi . salut ! run ! cours ! run ! courez ! who ? qui ? wow ! ?a alors !
10.5.2。代幣化
與第 9.3 節(jié)中的字符級標(biāo)記化不同 ,對于機(jī)器翻譯,我們在這里更喜歡單詞級標(biāo)記化(當(dāng)今最先進(jìn)的模型使用更復(fù)雜的標(biāo)記化技術(shù))。以下_tokenize方法對第一個(gè)max_examples文本序列對進(jìn)行分詞,其中每個(gè)分詞要么是一個(gè)單詞,要么是一個(gè)標(biāo)點(diǎn)符號(hào)。我們將特殊的“”標(biāo)記附加到每個(gè)序列的末尾,以指示序列的結(jié)束。當(dāng)模型通過生成一個(gè)接一個(gè)標(biāo)記的序列標(biāo)記進(jìn)行預(yù)測時(shí),“”標(biāo)記的生成表明輸出序列是完整的。最后,下面的方法返回兩個(gè)令牌列表列表:src和tgt。具體來說,src[i]是來自ith源語言(此處為英語)的文本序列和tgt[i]目標(biāo)語言(此處為法語)的文本序列。
@d2l.add_to_class(MTFraEng) #@save def _tokenize(self, text, max_examples=None): src, tgt = [], [] for i, line in enumerate(text.split('n')): if max_examples and i > max_examples: break parts = line.split('t') if len(parts) == 2: # Skip empty tokens src.append([t for t in f'{parts[0]} '.split(' ') if t]) tgt.append([t for t in f'{parts[1]} '.split(' ') if t]) return src, tgt src, tgt = data._tokenize(text) src[:6], tgt[:6]
([['go', '.', ''], ['hi', '.', ''], ['run', '!', ''], ['run', '!', ''], ['who', '?', ''], ['wow', '!', '']], [['va', '!', ''], ['salut', '!', ''], ['cours', '!', ''], ['courez', '!', ''], ['qui', '?', ''], ['?a', 'alors', '!', '']])
@d2l.add_to_class(MTFraEng) #@save def _tokenize(self, text, max_examples=None): src, tgt = [], [] for i, line in enumerate(text.split('n')): if max_examples and i > max_examples: break parts = line.split('t') if len(parts) == 2: # Skip empty tokens src.append([t for t in f'{parts[0]} '.split(' ') if t]) tgt.append([t for t in f'{parts[1]} '.split(' ') if t]) return src, tgt src, tgt = data._tokenize(text) src[:6], tgt[:6]
([['go', '.', ''], ['hi', '.', ''], ['run', '!', ''], ['run', '!', ''], ['who', '?', ''], ['wow', '!', '']], [['va', '!', ''], ['salut', '!', ''], ['cours', '!', ''], ['courez', '!', ''], ['qui', '?', ''], ['?a', 'alors', '!', '']])
@d2l.add_to_class(MTFraEng) #@save def _tokenize(self, text, max_examples=None): src, tgt = [], [] for i, line in enumerate(text.split('n')): if max_examples and i > max_examples: break parts = line.split('t') if len(parts) == 2: # Skip empty tokens src.append([t for t in f'{parts[0]} '.split(' ') if t]) tgt.append([t for t in f'{parts[1]} '.split(' ') if t]) return src, tgt src, tgt = data._tokenize(text) src[:6], tgt[:6]
([['go', '.', ''], ['hi', '.', ''], ['run', '!', ''], ['run', '!', ''], ['who', '?', ''], ['wow', '!', '']], [['va', '!', ''], ['salut', '!', ''], ['cours', '!', ''], ['courez', '!', ''], ['qui', '?', ''], ['?a', 'alors', '!', '']])
@d2l.add_to_class(MTFraEng) #@save def _tokenize(self, text, max_examples=None): src, tgt = [], [] for i, line in enumerate(text.split('n')): if max_examples and i > max_examples: break parts = line.split('t') if len(parts) == 2: # Skip empty tokens src.append([t for t in f'{parts[0]} '.split(' ') if t]) tgt.append([t for t in f'{parts[1]} '.split(' ') if t]) return src, tgt src, tgt = data._tokenize(text) src[:6], tgt[:6]
([['go', '.', ''], ['hi', '.', ''], ['run', '!', ''], ['run', '!', ''], ['who', '?', ''], ['wow', '!', '']], [['va', '!', ''], ['salut', '!', ''], ['cours', '!', ''], ['courez', '!', ''], ['qui', '?', ''], ['?a', 'alors', '!', '']])
讓我們繪制每個(gè)文本序列的標(biāo)記數(shù)量的直方圖。在這個(gè)簡單的英法數(shù)據(jù)集中,大多數(shù)文本序列的標(biāo)記少于 20 個(gè)。
#@save def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist): """Plot the histogram for list length pairs.""" d2l.set_figsize() _, _, patches = d2l.plt.hist( [[len(l) for l in xlist], [len(l) for l in ylist]]) d2l.plt.xlabel(xlabel) d2l.plt.ylabel(ylabel) for patch in patches[1].patches: patch.set_hatch('/') d2l.plt.legend(legend) show_list_len_pair_hist(['source', 'target'], '# tokens per sequence', 'count', src, tgt);
#@save def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist): """Plot the histogram for list length pairs.""" d2l.set_figsize() _, _, patches = d2l.plt.hist( [[len(l) for l in xlist], [len(l) for l in ylist]]) d2l.plt.xlabel(xlabel) d2l.plt.ylabel(ylabel) for patch in patches[1].patches: patch.set_hatch('/') d2l.plt.legend(legend) show_list_len_pair_hist(['source', 'target'], '# tokens per sequence', 'count', src, tgt);
#@save def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist): """Plot the histogram for list length pairs.""" d2l.set_figsize() _, _, patches = d2l.plt.hist( [[len(l) for l in xlist], [len(l) for l in ylist]]) d2l.plt.xlabel(xlabel) d2l.plt.ylabel(ylabel) for patch in patches[1].patches: patch.set_hatch('/') d2l.plt.legend(legend) show_list_len_pair_hist(['source', 'target'], '# tokens per sequence', 'count', src, tgt);
#@save def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist): """Plot the histogram for list length pairs.""" d2l.set_figsize() _, _, patches = d2l.plt.hist( [[len(l) for l in xlist], [len(l) for l in ylist]]) d2l.plt.xlabel(xlabel) d2l.plt.ylabel(ylabel) for patch in patches[1].patches: patch.set_hatch('/') d2l.plt.legend(legend) show_list_len_pair_hist(['source', 'target'], '# tokens per sequence', 'count', src, tgt);
10.5.3。固定長度的加載序列
回想一下,在語言建模中,每個(gè)示例序列(一個(gè)句子的一部分或多個(gè)句子的跨度)都有固定的長度。這是由第 9.3 節(jié)num_steps中的(時(shí)間步數(shù)或標(biāo)記數(shù))參數(shù)指定的。在機(jī)器翻譯中,每個(gè)示例都是一對源文本序列和目標(biāo)文本序列,其中這兩個(gè)文本序列可能具有不同的長度。
為了計(jì)算效率,我們?nèi)匀豢梢酝ㄟ^截?cái)嗪吞畛湟淮翁幚硪恍∨谋拘蛄小<僭O(shè)同一個(gè)小批量中的每個(gè)序列都應(yīng)該具有相同的長度 num_steps。如果文本序列少于num_steps標(biāo)記,我們將繼續(xù)在其末尾附加特殊的“”標(biāo)記,直到其長度達(dá)到num_steps. num_steps否則,我們將通過僅獲取其第一個(gè)標(biāo)記并丟棄其余標(biāo)記來截?cái)辔谋拘蛄?。這樣,每個(gè)文本序列將具有相同的長度,以相同形狀的小批量加載。此外,我們還記錄了不包括填充標(biāo)記的源序列長度。我們稍后將介紹的某些模型將需要此信息。
由于機(jī)器翻譯數(shù)據(jù)集由多對語言組成,我們可以分別為源語言和目標(biāo)語言構(gòu)建兩個(gè)詞匯表。使用詞級標(biāo)記化,詞匯量將明顯大于使用字符級標(biāo)記化的詞匯量。為了減輕這一點(diǎn),在這里,我們將看上去不到2倍相同的未知數(shù)(“”)令牌視為罕見的令牌。正如我們稍后將解釋的(圖 10.7.1),當(dāng)使用目標(biāo)序列進(jìn)行訓(xùn)練時(shí),解碼器輸出(標(biāo)簽標(biāo)記)可以是相同的解碼器輸入(目標(biāo)標(biāo)記),移動(dòng)一個(gè)標(biāo)記;特殊的序列開頭“”標(biāo)記將用作預(yù)測目標(biāo)序列的第一個(gè)輸入標(biāo)記(圖 10.7.3)。
@d2l.add_to_class(MTFraEng) #@save def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128): super(MTFraEng, self).__init__() self.save_hyperparameters() self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays( self._download()) @d2l.add_to_class(MTFraEng) #@save def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None): def _build_array(sentences, vocab, is_tgt=False): pad_or_trim = lambda seq, t: ( seq[:t] if len(seq) > t else seq + [''] * (t - len(seq))) sentences = [pad_or_trim(s, self.num_steps) for s in sentences] if is_tgt: sentences = [[''] + s for s in sentences] if vocab is None: vocab = d2l.Vocab(sentences, min_freq=2) array = torch.tensor([vocab[s] for s in sentences]) valid_len = (array != vocab['']).type(torch.int32).sum(1) return array, vocab, valid_len src, tgt = self._tokenize(self._preprocess(raw_text), self.num_train + self.num_val) src_array, src_vocab, src_valid_len = _build_array(src, src_vocab) tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True) return ((src_array, tgt_array[:,:-1], src_valid_len, tgt_array[:,1:]), src_vocab, tgt_vocab)
@d2l.add_to_class(MTFraEng) #@save def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128): super(MTFraEng, self).__init__() self.save_hyperparameters() self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays( self._download()) @d2l.add_to_class(MTFraEng) #@save def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None): def _build_array(sentences, vocab, is_tgt=False): pad_or_trim = lambda seq, t: ( seq[:t] if len(seq) > t else seq + [''] * (t - len(seq))) sentences = [pad_or_trim(s, self.num_steps) for s in sentences] if is_tgt: sentences = [[''] + s for s in sentences] if vocab is None: vocab = d2l.Vocab(sentences, min_freq=2) array = np.array([vocab[s] for s in sentences]) valid_len = (array != vocab['']).astype(np.int32).sum(1) return array, vocab, valid_len src, tgt = self._tokenize(self._preprocess(raw_text), self.num_train + self.num_val) src_array, src_vocab, src_valid_len = _build_array(src, src_vocab) tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True) return ((src_array, tgt_array[:,:-1], src_valid_len, tgt_array[:,1:]), src_vocab, tgt_vocab)
@d2l.add_to_class(MTFraEng) #@save def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128): super(MTFraEng, self).__init__() self.save_hyperparameters() self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays( self._download()) @d2l.add_to_class(MTFraEng) #@save def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None): def _build_array(sentences, vocab, is_tgt=False): pad_or_trim = lambda seq, t: ( seq[:t] if len(seq) > t else seq + [''] * (t - len(seq))) sentences = [pad_or_trim(s, self.num_steps) for s in sentences] if is_tgt: sentences = [[''] + s for s in sentences] if vocab is None: vocab = d2l.Vocab(sentences, min_freq=2) array = jnp.array([vocab[s] for s in sentences]) valid_len = (array != vocab['']).astype(jnp.int32).sum(1) return array, vocab, valid_len src, tgt = self._tokenize(self._preprocess(raw_text), self.num_train + self.num_val) src_array, src_vocab, src_valid_len = _build_array(src, src_vocab) tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True) return ((src_array, tgt_array[:,:-1], src_valid_len, tgt_array[:,1:]), src_vocab, tgt_vocab)
@d2l.add_to_class(MTFraEng) #@save def __init__(self, batch_size, num_steps=9, num_train=512, num_val=128): super(MTFraEng, self).__init__() self.save_hyperparameters() self.arrays, self.src_vocab, self.tgt_vocab = self._build_arrays( self._download()) @d2l.add_to_class(MTFraEng) #@save def _build_arrays(self, raw_text, src_vocab=None, tgt_vocab=None): def _build_array(sentences, vocab, is_tgt=False): pad_or_trim = lambda seq, t: ( seq[:t] if len(seq) > t else seq + [''] * (t - len(seq))) sentences = [pad_or_trim(s, self.num_steps) for s in sentences] if is_tgt: sentences = [[''] + s for s in sentences] if vocab is None: vocab = d2l.Vocab(sentences, min_freq=2) array = tf.constant([vocab[s] for s in sentences]) valid_len = tf.reduce_sum( tf.cast(array != vocab[''], tf.int32), 1) return array, vocab, valid_len src, tgt = self._tokenize(self._preprocess(raw_text), self.num_train + self.num_val) src_array, src_vocab, src_valid_len = _build_array(src, src_vocab) tgt_array, tgt_vocab, _ = _build_array(tgt, tgt_vocab, True) return ((src_array, tgt_array[:,:-1], src_valid_len, tgt_array[:,1:]), src_vocab, tgt_vocab)
10.5.4。讀取數(shù)據(jù)集
最后,我們定義get_dataloader返回?cái)?shù)據(jù)迭代器的方法。
@d2l.add_to_class(MTFraEng) #@save def get_dataloader(self, train): idx = slice(0, self.num_train) if train else slice(self.num_train, None) return self.get_tensorloader(self.arrays, train, idx)
讓我們從英法數(shù)據(jù)集中讀取第一個(gè)小批量。
data = MTFraEng(batch_size=3) src, tgt, src_valid_len, label = next(iter(data.train_dataloader())) print('source:', src.type(torch.int32)) print('decoder input:', tgt.type(torch.int32)) print('source len excluding pad:', src_valid_len.type(torch.int32)) print('label:', label.type(torch.int32))
source: tensor([[ 83, 174, 2, 3, 4, 4, 4, 4, 4], [ 84, 32, 91, 2, 3, 4, 4, 4, 4], [144, 174, 0, 3, 4, 4, 4, 4, 4]], dtype=torch.int32) decoder input: tensor([[ 3, 6, 0, 4, 5, 5, 5, 5, 5], [ 3, 108, 112, 84, 2, 4, 5, 5, 5], [ 3, 87, 0, 4, 5, 5, 5, 5, 5]], dtype=torch.int32) source len excluding pad: tensor([4, 5, 4], dtype=torch.int32) label: tensor([[ 6, 0, 4, 5, 5, 5, 5, 5, 5], [108, 112, 84, 2, 4, 5, 5, 5, 5], [ 87, 0, 4, 5, 5, 5, 5, 5, 5]], dtype=torch.int32)
data = MTFraEng(batch_size=3) src, tgt, src_valid_len, label = next(iter(data.train_dataloader())) print('source:', src.astype(np.int32)) print('decoder input:', tgt.astype(np.int32)) print('source len excluding pad:', src_valid_len.astype(np.int32)) print('label:', label.astype(np.int32))
source: [[84 5 2 3 4 4 4 4 4] [84 5 2 3 4 4 4 4 4] [59 9 2 3 4 4 4 4 4]] decoder input: [[ 3 108 6 2 4 5 5 5 5] [ 3 105 6 2 4 5 5 5 5] [ 3 203 0 4 5 5 5 5 5]] source len excluding pad: [4 4 4] label: [[108 6 2 4 5 5 5 5 5] [105 6 2 4 5 5 5 5 5] [203 0 4 5 5 5 5 5 5]]
data = MTFraEng(batch_size=3) src, tgt, src_valid_len, label = next(iter(data.train_dataloader())) print('source:', src.astype(jnp.int32)) print('decoder input:', tgt.astype(jnp.int32)) print('source len excluding pad:', src_valid_len.astype(jnp.int32)) print('label:', label.astype(jnp.int32))
source: [[ 86 43 2 3 4 4 4 4 4] [176 165 2 3 4 4 4 4 4] [143 111 2 3 4 4 4 4 4]] decoder input: [[ 3 108 183 98 2 4 5 5 5] [ 3 6 42 0 4 5 5 5 5] [ 3 6 0 4 5 5 5 5 5]] source len excluding pad: [4 4 4] label: [[108 183 98 2 4 5 5 5 5] [ 6 42 0 4 5 5 5 5 5] [ 6 0 4 5 5 5 5 5 5]]
data = MTFraEng(batch_size=3) src, tgt, src_valid_len, label = next(iter(data.train_dataloader())) print('source:', tf.cast(src, tf.int32)) print('decoder input:', tf.cast(tgt, tf.int32)) print('source len excluding pad:', tf.cast(src_valid_len, tf.int32)) print('label:', tf.cast(label, tf.int32))
source: tf.Tensor( [[ 92 115 2 3 4 4 4 4 4] [155 168 2 3 4 4 4 4 4] [ 86 121 2 3 4 4 4 4 4]], shape=(3, 9), dtype=int32) decoder input: tf.Tensor( [[ 3 37 6 2 4 5 5 5 5] [ 3 6 192 2 4 5 5 5 5] [ 3 108 202 30 2 4 5 5 5]], shape=(3, 9), dtype=int32) source len excluding pad: tf.Tensor([4 4 4], shape=(3,), dtype=int32) label: tf.Tensor( [[ 37 6 2 4 5 5 5 5 5] [ 6 192 2 4 5 5 5 5 5] [108 202 30 2 4 5 5 5 5]], shape=(3, 9), dtype=int32)
下面我們展示了一對由上述方法處理的源序列和目標(biāo)序列_build_arrays(以字符串格式)。
@d2l.add_to_class(MTFraEng) #@save def build(self, src_sentences, tgt_sentences): raw_text = 'n'.join([src + 't' + tgt for src, tgt in zip( src_sentences, tgt_sentences)]) arrays, _, _ = self._build_arrays( raw_text, self.src_vocab, self.tgt_vocab) return arrays src, tgt, _, _ = data.build(['hi .'], ['salut .']) print('source:', data.src_vocab.to_tokens(src[0].type(torch.int32))) print('target:', data.tgt_vocab.to_tokens(tgt[0].type(torch.int32)))
source: ['hi', '.', '', '', '', '', '', '', ''] target: ['', 'salut', '.', '', '', '', '', '', '']
@d2l.add_to_class(MTFraEng) #@save def build(self, src_sentences, tgt_sentences): raw_text = 'n'.join([src + 't' + tgt for src, tgt in zip( src_sentences, tgt_sentences)]) arrays, _, _ = self._build_arrays( raw_text, self.src_vocab, self.tgt_vocab) return arrays src, tgt, _, _ = data.build(['hi .'], ['salut .']) print('source:', data.src_vocab.to_tokens(src[0].astype(np.int32))) print('target:', data.tgt_vocab.to_tokens(tgt[0].astype(np.int32)))
source: ['hi', '.', '', '', '', '', '', '', ''] target: ['', 'salut', '.', '', '', '', '', '', '']
@d2l.add_to_class(MTFraEng) #@save def build(self, src_sentences, tgt_sentences): raw_text = 'n'.join([src + 't' + tgt for src, tgt in zip( src_sentences, tgt_sentences)]) arrays, _, _ = self._build_arrays( raw_text, self.src_vocab, self.tgt_vocab) return arrays src, tgt, _, _ = data.build(['hi .'], ['salut .']) print('source:', data.src_vocab.to_tokens(src[0].astype(jnp.int32))) print('target:', data.tgt_vocab.to_tokens(tgt[0].astype(jnp.int32)))
source: ['hi', '.', '', '', '', '', '', '', ''] target: ['', 'salut', '.', '', '', '', '', '', '']
@d2l.add_to_class(MTFraEng) #@save def build(self, src_sentences, tgt_sentences): raw_text = 'n'.join([src + 't' + tgt for src, tgt in zip( src_sentences, tgt_sentences)]) arrays, _, _ = self._build_arrays( raw_text, self.src_vocab, self.tgt_vocab) return arrays src, tgt, _, _ = data.build(['hi .'], ['salut .']) print('source:', data.src_vocab.to_tokens(tf.cast(src[0], tf.int32))) print('target:', data.tgt_vocab.to_tokens(tf.cast(tgt[0], tf.int32)))
source: ['hi', '.', '', '', '', '', '', '', ''] target: ['', 'salut', '.', '', '', '', '', '', '']
10.5.5。概括
在自然語言處理中,機(jī)器翻譯是指將代表源語言文本字符串的序列自動(dòng)映射到代表目標(biāo)語言中合理翻譯的字符串的任務(wù)。使用詞級標(biāo)記化,詞匯量將明顯大于使用字符級標(biāo)記化,但序列長度會(huì)短得多。為了減輕大詞匯量,我們可以將不常見的標(biāo)記視為一些“未知”標(biāo)記。我們可以截?cái)嗪吞畛湮谋拘蛄?,以便它們都具有相同的長度以加載到小批量中。現(xiàn)代實(shí)現(xiàn)通常將具有相似長度的序列存儲(chǔ)起來,以避免在填充上浪費(fèi)過多的計(jì)算。
10.5.6。練習(xí)
max_examples在 方法中嘗試不同的參數(shù)值_tokenize。這如何影響源語言和目標(biāo)語言的詞匯量?
某些語言(例如中文和日文)中的文本沒有字界指示符(例如,空格)。對于這種情況,詞級標(biāo)記化仍然是一個(gè)好主意嗎?為什么或者為什么不?
-
數(shù)據(jù)集
+關(guān)注
關(guān)注
4文章
1209瀏覽量
24772 -
pytorch
+關(guān)注
關(guān)注
2文章
808瀏覽量
13314
發(fā)布評論請先 登錄
相關(guān)推薦
評論