0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

PyTorch教程-10.6. 編碼器-解碼器架構(gòu)

jf_pJlTbmA9 ? 來源:PyTorch ? 作者:PyTorch ? 2023-06-05 15:44 ? 次閱讀

在一般的 seq2seq 問題中,如機(jī)器翻譯(第 10.5 節(jié)),輸入和輸出的長度不同且未對(duì)齊。處理這類數(shù)據(jù)的標(biāo)準(zhǔn)方法是設(shè)計(jì)一個(gè)編碼器-解碼器架構(gòu)(圖 10.6.1),它由兩個(gè)主要組件組成:一個(gè) 編碼器,它以可變長度序列作為輸入,以及一個(gè) 解碼器,作為一個(gè)條件語言模型,接收編碼輸入和目標(biāo)序列的向左上下文,并預(yù)測目標(biāo)序列中的后續(xù)標(biāo)記。

poYBAGR9N3SADJcWAACS_jUPoZI623.svg

圖 10.6.1編碼器-解碼器架構(gòu)。

讓我們以從英語到法語的機(jī)器翻譯為例。給定一個(gè)英文輸入序列:“They”、“are”、“watching”、“.”,這種編碼器-解碼器架構(gòu)首先將可變長度輸入編碼為一個(gè)狀態(tài),然后對(duì)該狀態(tài)進(jìn)行解碼以生成翻譯后的序列,token通過標(biāo)記,作為輸出:“Ils”、“regardent”、“.”。由于編碼器-解碼器架構(gòu)構(gòu)成了后續(xù)章節(jié)中不同 seq2seq 模型的基礎(chǔ),因此本節(jié)將此架構(gòu)轉(zhuǎn)換為稍后將實(shí)現(xiàn)的接口。

from torch import nn
from d2l import torch as d2l

from mxnet.gluon import nn
from d2l import mxnet as d2l

from flax import linen as nn
from d2l import jax as d2l

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

import tensorflow as tf
from d2l import tensorflow as d2l

10.6.1。編碼器

在編碼器接口中,我們只是指定編碼器將可變長度序列作為輸入X。實(shí)現(xiàn)將由繼承此基類的任何模型提供Encoder。

class Encoder(nn.Module): #@save
  """The base encoder interface for the encoder-decoder architecture."""
  def __init__(self):
    super().__init__()

  # Later there can be additional arguments (e.g., length excluding padding)
  def forward(self, X, *args):
    raise NotImplementedError

class Encoder(nn.Block): #@save
  """The base encoder interface for the encoder-decoder architecture."""
  def __init__(self):
    super().__init__()

  # Later there can be additional arguments (e.g., length excluding padding)
  def forward(self, X, *args):
    raise NotImplementedError

class Encoder(nn.Module): #@save
  """The base encoder interface for the encoder-decoder architecture."""
  def setup(self):
    raise NotImplementedError

  # Later there can be additional arguments (e.g., length excluding padding)
  def __call__(self, X, *args):
    raise NotImplementedError

class Encoder(tf.keras.layers.Layer): #@save
  """The base encoder interface for the encoder-decoder architecture."""
  def __init__(self):
    super().__init__()

  # Later there can be additional arguments (e.g., length excluding padding)
  def call(self, X, *args):
    raise NotImplementedError

10.6.2。解碼器

在下面的解碼器接口中,我們添加了一個(gè)額外的init_state 方法來將編碼器輸出 ( enc_all_outputs) 轉(zhuǎn)換為編碼狀態(tài)。請(qǐng)注意,此步驟可能需要額外的輸入,例如輸入的有效長度,這在 第 10.5 節(jié)中有解釋。為了逐個(gè)令牌生成可變長度序列令牌,每次解碼器都可以將輸入(例如,在先前時(shí)間步生成的令牌)和編碼狀態(tài)映射到當(dāng)前時(shí)間步的輸出令牌。

class Decoder(nn.Module): #@save
  """The base decoder interface for the encoder-decoder architecture."""
  def __init__(self):
    super().__init__()

  # Later there can be additional arguments (e.g., length excluding padding)
  def init_state(self, enc_all_outputs, *args):
    raise NotImplementedError

  def forward(self, X, state):
    raise NotImplementedError

class Decoder(nn.Block): #@save
  """The base decoder interface for the encoder-decoder architecture."""
  def __init__(self):
    super().__init__()

  # Later there can be additional arguments (e.g., length excluding padding)
  def init_state(self, enc_all_outputs, *args):
    raise NotImplementedError

  def forward(self, X, state):
    raise NotImplementedError

class Decoder(nn.Module): #@save
  """The base decoder interface for the encoder-decoder architecture."""
  def setup(self):
    raise NotImplementedError

  # Later there can be additional arguments (e.g., length excluding padding)
  def init_state(self, enc_all_outputs, *args):
    raise NotImplementedError

  def __call__(self, X, state):
    raise NotImplementedError

class Decoder(tf.keras.layers.Layer): #@save
  """The base decoder interface for the encoder-decoder architecture."""
  def __init__(self):
    super().__init__()

  # Later there can be additional arguments (e.g., length excluding padding)
  def init_state(self, enc_all_outputs, *args):
    raise NotImplementedError

  def call(self, X, state):
    raise NotImplementedError

10.6.3。將編碼器和解碼器放在一起

在前向傳播中,編碼器的輸出用于產(chǎn)生編碼狀態(tài),解碼器將進(jìn)一步使用該狀態(tài)作為其輸入之一。

class EncoderDecoder(d2l.Classifier): #@save
  """The base class for the encoder-decoder architecture."""
  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, enc_X, dec_X, *args):
    enc_all_outputs = self.encoder(enc_X, *args)
    dec_state = self.decoder.init_state(enc_all_outputs, *args)
    # Return decoder output only
    return self.decoder(dec_X, dec_state)[0]

class EncoderDecoder(d2l.Classifier): #@save
  """The base class for the encoder-decoder architecture."""
  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, enc_X, dec_X, *args):
    enc_all_outputs = self.encoder(enc_X, *args)
    dec_state = self.decoder.init_state(enc_all_outputs, *args)
    # Return decoder output only
    return self.decoder(dec_X, dec_state)[0]

class EncoderDecoder(d2l.Classifier): #@save
  """The base class for the encoder-decoder architecture."""
  encoder: nn.Module
  decoder: nn.Module
  training: bool

  def __call__(self, enc_X, dec_X, *args):
    enc_all_outputs = self.encoder(enc_X, *args, training=self.training)
    dec_state = self.decoder.init_state(enc_all_outputs, *args)
    # Return decoder output only
    return self.decoder(dec_X, dec_state, training=self.training)[0]

class EncoderDecoder(d2l.Classifier): #@save
  """The base class for the encoder-decoder architecture."""
  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def call(self, enc_X, dec_X, *args):
    enc_all_outputs = self.encoder(enc_X, *args, training=True)
    dec_state = self.decoder.init_state(enc_all_outputs, *args)
    # Return decoder output only
    return self.decoder(dec_X, dec_state, training=True)[0]

在下一節(jié)中,我們將看到如何應(yīng)用 RNN 來設(shè)計(jì)基于這種編碼器-解碼器架構(gòu)的 seq2seq 模型。

10.6.4。概括

編碼器-解碼器架構(gòu)可以處理由可變長度序列組成的輸入和輸出,因此適用于機(jī)器翻譯等 seq2seq 問題。編碼器將可變長度序列作為輸入,并將其轉(zhuǎn)換為具有固定形狀的狀態(tài)。解碼器將固定形狀的編碼狀態(tài)映射到可變長度序列。

10.6.5。練習(xí)

假設(shè)我們使用神經(jīng)網(wǎng)絡(luò)來實(shí)現(xiàn)編碼器-解碼器架構(gòu)。編碼器和解碼器必須是同一類型的神經(jīng)網(wǎng)絡(luò)嗎?

除了機(jī)器翻譯,你能想到另一個(gè)可以應(yīng)用編碼器-解碼器架構(gòu)的應(yīng)用程序嗎?

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • 解碼器
    +關(guān)注

    關(guān)注

    9

    文章

    1143

    瀏覽量

    40742
  • 編碼器
    +關(guān)注

    關(guān)注

    45

    文章

    3643

    瀏覽量

    134531
  • pytorch
    +關(guān)注

    關(guān)注

    2

    文章

    808

    瀏覽量

    13229
收藏 人收藏

    評(píng)論

    相關(guān)推薦

    怎么理解真正的編碼器解碼器?

      在進(jìn)入關(guān)于編碼器解碼器的現(xiàn)實(shí)之前,讓我們對(duì)復(fù)用進(jìn)行簡要的思考。通常我們會(huì)在需要將一些輸入信號(hào)一次一個(gè)地加載到一個(gè)單獨(dú)負(fù)載的應(yīng)用程序中。選擇輸入信號(hào)中的一個(gè)輸入信號(hào)的過程稱為多路復(fù)用。這種操作
    發(fā)表于 09-01 17:48

    編碼器解碼器的區(qū)別是什么,編碼器用軟件還是硬件好

    編碼器指的是對(duì)視頻信號(hào)進(jìn)行壓縮,解碼器主要是將壓縮的視頻信號(hào)進(jìn)行解壓縮。目前做直播的很多都是采用的編碼器,客戶端可以采用解碼器或軟件播放
    發(fā)表于 08-02 17:23 ?3.4w次閱讀

    詳解編碼器解碼器電路:定義/工作原理/應(yīng)用/真值表

    編碼器解碼器是組合邏輯電路,在其中,主要借助布爾代數(shù)實(shí)現(xiàn)組合邏輯。今天就大家了解一下編碼器解碼器電路,分別從定義,工作原理,應(yīng)用,真值表幾個(gè)方面講述一下。
    的頭像 發(fā)表于 11-03 09:22 ?7321次閱讀
    詳解<b class='flag-5'>編碼器</b>和<b class='flag-5'>解碼器</b>電路:定義/工作原理/應(yīng)用/真值表

    PyTorch教程10.6編碼器-解碼器架構(gòu)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程10.6編碼器-解碼器架構(gòu).pdf》資料免費(fèi)下載
    發(fā)表于 06-05 18:12 ?0次下載
    <b class='flag-5'>PyTorch</b>教程<b class='flag-5'>10.6</b>之<b class='flag-5'>編碼器</b>-<b class='flag-5'>解碼器</b><b class='flag-5'>架構(gòu)</b>

    PyTorch教程10.7之用于機(jī)器翻譯的編碼器-解碼器Seq2Seq

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程10.7之用于機(jī)器翻譯的編碼器-解碼器Seq2Seq.pdf》資料免費(fèi)下載
    發(fā)表于 06-05 18:14 ?0次下載
    <b class='flag-5'>PyTorch</b>教程10.7之用于機(jī)器翻譯的<b class='flag-5'>編碼器</b>-<b class='flag-5'>解碼器</b>Seq2Seq

    PyTorch教程-10.7. 用于機(jī)器翻譯的編碼器-解碼器 Seq2Seq

    序列組成,我們通常依賴編碼器-解碼器架構(gòu)(第10.6 節(jié))。在本節(jié)中,我們將演示編碼器-解碼器
    的頭像 發(fā)表于 06-05 15:44 ?784次閱讀
    <b class='flag-5'>PyTorch</b>教程-10.7. 用于機(jī)器翻譯的<b class='flag-5'>編碼器</b>-<b class='flag-5'>解碼器</b> Seq2Seq

    基于transformer的編碼器-解碼器模型的工作原理

    與基于 RNN 的編碼器-解碼器模型類似,基于 transformer 的編碼器-解碼器模型由一個(gè)編碼器和一個(gè)
    發(fā)表于 06-11 14:17 ?2259次閱讀
    基于transformer的<b class='flag-5'>編碼器</b>-<b class='flag-5'>解碼器</b>模型的工作原理

    基于 RNN 的解碼器架構(gòu)如何建模

    language processing,NLP) 領(lǐng)域編碼器-解碼器架構(gòu)的?事實(shí)標(biāo)準(zhǔn)?。 最近基于 transformer 的編碼器-解碼器
    的頭像 發(fā)表于 06-12 17:08 ?820次閱讀
    基于 RNN 的<b class='flag-5'>解碼器</b><b class='flag-5'>架構(gòu)</b>如何建模

    基于 Transformers 的編碼器-解碼器模型

    基于 transformer 的編碼器-解碼器模型是 表征學(xué)習(xí) 和 模型架構(gòu) 這兩個(gè)領(lǐng)域多年研究成果的結(jié)晶。本文簡要介紹了神經(jīng)編碼器-解碼器
    的頭像 發(fā)表于 06-16 16:53 ?893次閱讀
    基于 Transformers 的<b class='flag-5'>編碼器</b>-<b class='flag-5'>解碼器</b>模型

    神經(jīng)編碼器-解碼器模型的歷史

    基于 transformer 的編碼器-解碼器模型是 表征學(xué)習(xí) 和 模型架構(gòu) 這兩個(gè)領(lǐng)域多年研究成果的結(jié)晶。本文簡要介紹了神經(jīng)編碼器-解碼器
    的頭像 發(fā)表于 06-20 15:42 ?893次閱讀
    神經(jīng)<b class='flag-5'>編碼器</b>-<b class='flag-5'>解碼器</b>模型的歷史

    詳解編碼器解碼器電路

    編碼器解碼器是組合邏輯電路,在其中,主要借助布爾代數(shù)實(shí)現(xiàn)組合邏輯。今天就大家了解一下編碼器解碼器電路,分別從定義,工作原理,應(yīng)用,真值表幾個(gè)方面講述一下。
    的頭像 發(fā)表于 07-14 09:07 ?3256次閱讀
    詳解<b class='flag-5'>編碼器</b>和<b class='flag-5'>解碼器</b>電路

    視頻編碼器解碼器的應(yīng)用方案

    視頻解碼器和視頻編碼器在數(shù)字通訊、音視頻壓縮領(lǐng)域有著廣泛的應(yīng)用。視頻編碼器作為視頻源的發(fā)送端,若接收端如果是?PC?機(jī)或顯示設(shè)備就需要通過解碼器進(jìn)行
    的頭像 發(fā)表于 08-14 14:38 ?1353次閱讀
    視頻<b class='flag-5'>編碼器</b>與<b class='flag-5'>解碼器</b>的應(yīng)用方案

    YXC丨視頻編碼器解碼器的應(yīng)用方案

    視頻解碼器和視頻編碼器是數(shù)字信號(hào)處理中常用的設(shè)備,它們?cè)跀?shù)據(jù)的傳輸和轉(zhuǎn)換中發(fā)揮著重要作用。
    的頭像 發(fā)表于 08-23 09:40 ?675次閱讀
    YXC丨視頻<b class='flag-5'>編碼器</b>與<b class='flag-5'>解碼器</b>的應(yīng)用方案

    視頻編碼器解碼器的應(yīng)用方案

    視頻解碼器和視頻編碼器是數(shù)字信號(hào)處理中常用的設(shè)備,它們?cè)跀?shù)據(jù)的傳輸和轉(zhuǎn)換中發(fā)揮著重要作用。
    的頭像 發(fā)表于 08-28 11:31 ?595次閱讀
    視頻<b class='flag-5'>編碼器</b>與<b class='flag-5'>解碼器</b>的應(yīng)用方案

    信路達(dá) 解碼器/編碼器 XD74LS48數(shù)據(jù)手冊(cè)

    解碼器/編碼器?DIP164.75~5.25V封裝:DIP16_19.3X6.4MM
    發(fā)表于 08-19 15:57 ?2次下載