旋轉(zhuǎn)式位置編碼(RoPE)最早是論文[1]提出的一種能夠?qū)⑾鄬?duì)位置信息依賴集成到 self-attention 中并提升 transformer 架構(gòu)性能的位置編碼方式。而目前很火的 LLaMA 模型也是采用該位置編碼方式。
接下來(lái)結(jié)合代碼和論文來(lái)解讀一下 RoPE。
基本概念
首先論文中定義一個(gè)長(zhǎng)度為 N 的輸入序列為:
其中 wi 表示輸入序列中第 i 個(gè) token,而輸入序列 SN 對(duì)應(yīng)的 embedding 表示為:
其中 xi 表示第 i 個(gè) token wi 對(duì)應(yīng)的 d 維詞嵌入向量。
接著在做 self-attention 之前,會(huì)用詞嵌入向量計(jì)算 q, k, v 向量同時(shí)加入位置信息,函數(shù)公式表達(dá)如下:
其中 qm 表示第 m 個(gè) token 對(duì)應(yīng)的詞向量 xm 集成位置信息 m 之后的 query 向量。而 kn 和 vn 則表示第 n 個(gè) token 對(duì)應(yīng)的詞向量 xn 集成位置信息 n 之后的 key 和 value 向量。
而基于 transformer 的位置編碼方法都是著重于構(gòu)造一個(gè)合適的 f{q,k,v} 函數(shù)形式。
而計(jì)算第 m 個(gè)詞嵌入向量 xm 對(duì)應(yīng)的 self-attention 輸出結(jié)果,就是 qm 和其他 kn 都計(jì)算一個(gè) attention score ,然后再將 attention score 乘以對(duì)應(yīng)的 vn 再求和得到輸出向量 om:
絕對(duì)位置編碼
對(duì)于位置編碼,常規(guī)的做法是在計(jì)算 query, key 和 value 向量之前,會(huì)計(jì)算一個(gè)位置編碼向量 pi 加到詞嵌入 xi 上,位置編碼向量 pi 同樣也是 d 維向量,然后再乘以對(duì)應(yīng)的變換矩陣 W{q,k,v}:
而經(jīng)典的位置編碼向量 pi 的計(jì)算方式是:
其中 p_{i,2t} 表示位置 d 維度向量 pi 中的第 2t 個(gè)元素也就是偶數(shù)索引位置的計(jì)算公式,而 p_{i,2t+1} 就對(duì)應(yīng)奇數(shù)索引位置的計(jì)算公式。
python 代碼如下:
#?position?就對(duì)應(yīng)?token?序列中的位置索引?i #?hidden_dim?就對(duì)應(yīng)詞嵌入維度大小?d #?seq_len?表示?token?序列長(zhǎng)度 def?get_position_angle_vec(position): ????return?[position?/?np.power(10000,?2?*?(hid_j?//?2)?/?hidden_dim)?for?hid_j?in?range(hidden_dim)] #?position_angle_vecs.shape?=?[seq_len,?hidden_dim] position_angle_vecs?=?np.array([get_position_angle_vec(pos_i)?for?pos_i?in?range(seq_len)]) #?分別計(jì)算奇偶索引位置對(duì)應(yīng)的?sin?和?cos?值 position_angle_vecs[:,?0::2]?=?np.sin(position_angle_vecs[:,?0::2])??#?dim?2t position_angle_vecs[:,?1::2]?=?np.cos(position_angle_vecs[:,?1::2])??#?dim?2t+1 #?positional_embeddings.shape?=?[1,?seq_len,?hidden_dim] positional_embeddings?=?torch.FloatTensor(position_angle_vecs).unsqueeze(0)
旋轉(zhuǎn)式位置編碼
接著論文中提出為了能利用上 token 之間的相對(duì)位置信息,假定 query 向量 qm 和 key 向量 kn 之間的內(nèi)積操作可以被一個(gè)函數(shù) g 表示,該函數(shù) g 的輸入是詞嵌入向量 xm , xn 和它們之間的相對(duì)位置 m - n:
接下來(lái)的目標(biāo)就是找到一個(gè)等價(jià)的位置編碼方式,從而使得上述關(guān)系成立。
假定現(xiàn)在詞嵌入向量的維度是兩維 d=2,這樣就可以利用上2維度平面上的向量的幾何性質(zhì),然后論文中提出了一個(gè)滿足上述關(guān)系的 f 和 g 的形式如下:
上面的公式一眼看過(guò)去感覺(jué)很復(fù)雜,怎么理解呢?
首先我們得先了解一下基本的復(fù)數(shù)相關(guān)知識(shí)。
首先看到上述 f 和 g ?公式中有個(gè)指數(shù)函數(shù)
這個(gè)其實(shí)是歐拉公式 [2],其中 x 表示任意實(shí)數(shù), e 是自然對(duì)數(shù)的底數(shù),i 是復(fù)數(shù)中的虛數(shù)單位,則根據(jù)歐拉公式有:
上述指數(shù)函數(shù)可以表示為實(shí)部為 cosx,虛部為 sinx 的一個(gè)復(fù)數(shù),歐拉公式 [2] 建立了指數(shù)函數(shù)、三角函數(shù)和復(fù)數(shù)之間的橋梁。
則上述 f 和 g ?公式中的
然后我們看回公式:
其中 Wq 是個(gè)二維矩陣,xm 是個(gè)二維向量,相乘的結(jié)果也是一個(gè)二維向量,這里用 qm 表示:
然后首先將 qm 表示成復(fù)數(shù)形式:
接著
其實(shí)就是兩個(gè)復(fù)數(shù)相乘:
我們首先來(lái)復(fù)習(xí)一下復(fù)數(shù)乘法的性質(zhì):
可以看到,復(fù)數(shù)乘法也是用的分配律,還有用到了復(fù)數(shù)的一個(gè)性質(zhì):
然后就有:
將結(jié)果重新表達(dá)成實(shí)數(shù)向量形式就是:
相信讀者看到這里會(huì)發(fā)現(xiàn)這不就是 query 向量乘以了一個(gè)旋轉(zhuǎn)矩陣[5]嗎?
這就是為什么叫做旋轉(zhuǎn)式位置編碼的原因。
同理可得 key 向量 kn :
最后還有個(gè)函數(shù) g:
其中 Re[x] 表示一個(gè)復(fù)數(shù) x 的實(shí)部部分,而
則表示復(fù)數(shù)
的共軛,復(fù)習(xí)一下共軛復(fù)數(shù)的定義:
所以可得:
繼續(xù)可得:
ok,接下來(lái)我們就要證明函數(shù) g 的計(jì)算公式是成立的。
首先回顧一下 attention 操作, 位置 m 的 query 和位置 n 的 key 會(huì)做一個(gè)內(nèi)積操作:
接著繼續(xù)之前先復(fù)習(xí)一下三角函數(shù)的一些性質(zhì)[3]:
好了回到上面那坨式子,我們整理一下:
這就證明上述關(guān)系是成立的,位置 m 的 query 和位置 n 的 key 的內(nèi)積就是函數(shù) g。
然后上面的講解是假定的詞嵌入維度是2維向量,而對(duì)于d >= 2 的通用情況,則是將詞嵌入向量元素按照兩兩一組分組,每組應(yīng)用同樣的旋轉(zhuǎn)操作且每組的旋轉(zhuǎn)角度計(jì)算方式如下:
所以簡(jiǎn)單來(lái)說(shuō) RoPE 的 self-attention 操作的流程是,對(duì)于 token 序列中的每個(gè)詞嵌入向量,首先計(jì)算其對(duì)應(yīng)的 query 和 key 向量,然后對(duì)每個(gè) token 位置都計(jì)算對(duì)應(yīng)的旋轉(zhuǎn)位置編碼,接著對(duì)每個(gè) token 位置的 query 和 key 向量的元素按照 兩兩一組 應(yīng)用旋轉(zhuǎn)變換,最后再計(jì)算 query 和 key 之間的內(nèi)積得到 self-attention 的計(jì)算結(jié)果。
論文中有個(gè)很直觀的圖片展示了旋轉(zhuǎn)變換的過(guò)程:
LLaMA 官方實(shí)現(xiàn)代碼 [4] 如下(經(jīng)過(guò)簡(jiǎn)化):
?
?
def?precompute_freqs_cis(dim:?int,?seq_len:?int,?theta:?float?=?10000.0): ????#?計(jì)算詞向量元素兩兩分組之后,每組元素對(duì)應(yīng)的旋轉(zhuǎn)角度 ????freqs?=?1.0?/?(theta?**?(torch.arange(0,?dim,?2)[:?(dim?//?2)].float()?/?dim)) ????#?生成?token?序列索引?t?=?[0,?1,...,?seq_len-1] ????t?=?torch.arange(seq_len,?device=freqs.device) ????#?freqs.shape?=?[seq_len,?dim?//?2]? ????freqs?=?torch.outer(t,?freqs).float() ????#?torch.polar?的文檔 ????#?https://pytorch.org/docs/stable/generated/torch.polar.html ????#?計(jì)算結(jié)果是個(gè)復(fù)數(shù)向量 ????#?假設(shè)?freqs?=?[x,?y] ????#?則?freqs_cis?=?[cos(x)?+?sin(x)i,?cos(y)?+?sin(y)i] ????freqs_cis?=?torch.polar(torch.ones_like(freqs),?freqs) ????return?freqs_cis def?apply_rotary_emb( ????xq:?torch.Tensor, ????xk:?torch.Tensor, ????freqs_cis:?torch.Tensor, )?->?Tuple[torch.Tensor,?torch.Tensor]: ????#?xq.shape?=?[batch_size,?seq_len,?dim] ????#?xq_.shape?=?[batch_size,?seq_len,?dim?//?2,?2] ????xq_?=?xq.float().reshape(*xq.shape[:-1],?-1,?2) ????xk_?=?xk.float().reshape(*xk.shape[:-1],?-1,?2) ???? ????#?轉(zhuǎn)為復(fù)數(shù)域 ????xq_?=?torch.view_as_complex(xq_) ????xk_?=?torch.view_as_complex(xk_) ???? ????#?應(yīng)用旋轉(zhuǎn)操作,然后將結(jié)果轉(zhuǎn)回實(shí)數(shù)域 ????#?xq_out.shape?=?[batch_size,?seq_len,?dim] ????xq_out?=?torch.view_as_real(xq_?*?freqs_cis).flatten(2) ????xk_out?=?torch.view_as_real(xk_?*?freqs_cis).flatten(2) ????return?xq_out.type_as(xq),?xk_out.type_as(xk) class?Attention(nn.Module): ????def?__init__(self,?args:?ModelArgs): ????????super().__init__() ????????self.wq?=?Linear(...) ????????self.wk?=?Linear(...) ????????self.wv?=?Linear(...) ???????? ????????self.freqs_cis?=?precompute_freqs_cis(dim,?max_seq_len?*?2) ????def?forward(self,?x:?torch.Tensor): ????????bsz,?seqlen,?_?=?x.shape ????????xq,?xk,?xv?=?self.wq(x),?self.wk(x),?self.wv(x) ????????xq?=?xq.view(batch_size,?seq_len,?dim) ????????xk?=?xk.view(batch_size,?seq_len,?dim) ????????xv?=?xv.view(batch_size,?seq_len,?dim) ????????#?attention?操作之前,應(yīng)用旋轉(zhuǎn)位置編碼 ????????xq,?xk?=?apply_rotary_emb(xq,?xk,?freqs_cis=freqs_cis) ???????? ????????#?scores.shape?=?(batch_size,?seq_len,?seqlen) ????????scores?=?torch.matmul(xq,?xk.transpose(1,?2))?/?math.sqrt(dim) ????????scores?=?F.softmax(scores.float(),?dim=-1) ????????output?=?torch.matmul(scores,?xv)??#?(batch_size,?seq_len,?dim) ??#?......
可以看到 LLaMA 的官方實(shí)現(xiàn)代碼和論文 [1] 中的描述是一致的。
編輯:黃飛
?
評(píng)論
查看更多