在深度學(xué)習(xí)中,我們經(jīng)常使用 CNN 或 RNN 對(duì)序列進(jìn)行編碼。現(xiàn)在考慮到注意力機(jī)制,想象一下將一系列標(biāo)記輸入注意力機(jī)制,這樣在每個(gè)步驟中,每個(gè)標(biāo)記都有自己的查詢、鍵和值。在這里,當(dāng)在下一層計(jì)算令牌表示的值時(shí),令牌可以(通過其查詢向量)參與每個(gè)其他令牌(基于它們的鍵向量進(jìn)行匹配)。使用完整的查詢鍵兼容性分?jǐn)?shù)集,我們可以通過在其他標(biāo)記上構(gòu)建適當(dāng)?shù)募訖?quán)和來為每個(gè)標(biāo)記計(jì)算表示。因?yàn)槊總€(gè)標(biāo)記都關(guān)注另一個(gè)標(biāo)記(不同于解碼器步驟關(guān)注編碼器步驟的情況),這種架構(gòu)通常被描述為自注意力模型 (Lin等。, 2017 年, Vaswani等人。, 2017 ),以及其他地方描述的內(nèi)部注意力模型 ( Cheng et al. , 2016 , Parikh et al. , 2016 , Paulus et al. , 2017 )。在本節(jié)中,我們將討論使用自注意力的序列編碼,包括使用序列順序的附加信息。
11.6.1。自注意力
給定一系列輸入標(biāo)記 x1,…,xn任何地方 xi∈Rd(1≤i≤n), 它的self-attention輸出一個(gè)相同長(zhǎng)度的序列 y1,…,yn, 在哪里
根據(jù) (11.1.1)中attention pooling的定義。使用多頭注意力,以下代碼片段計(jì)算具有形狀(批量大小、時(shí)間步數(shù)或標(biāo)記中的序列長(zhǎng)度, d). 輸出張量具有相同的形狀。
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, valid_lens = 2, 4, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, jnp.array([3, 2])
X = jnp.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, X, X, valid_lens,
training=False)[0][0],
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens, training=False),
(batch_size, num_queries, num_hiddens))
11.6.2。比較 CNN、RNN 和自注意力
讓我們比較一下映射一系列的架構(gòu)n標(biāo)記到另一個(gè)等長(zhǎng)序列,其中每個(gè)輸入或輸出標(biāo)記由一個(gè)d維向量。具體來說,我們將考慮 CNN、RNN 和自注意力。我們將比較它們的計(jì)算復(fù)雜度、順序操作和最大路徑長(zhǎng)度。請(qǐng)注意,順序操作會(huì)阻止并行計(jì)算,而序列位置的任意組合之間的較短路徑可以更容易地學(xué)習(xí)序列內(nèi)的遠(yuǎn)程依賴關(guān)系 (Hochreiter等人,2001 年)。
考慮一個(gè)卷積層,其內(nèi)核大小為k. 我們將在后面的章節(jié)中提供有關(guān)使用 CNN 進(jìn)行序列處理的更多詳細(xì)信息。現(xiàn)在,我們只需要知道,因?yàn)樾蛄虚L(zhǎng)度是n,輸入和輸出通道的數(shù)量都是 d, 卷積層的計(jì)算復(fù)雜度為 O(knd2). 如圖11.6.1 所示,CNN 是分層的,因此有O(1) 順序操作和最大路徑長(zhǎng)度是
評(píng)論
查看更多