0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫(xiě)文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

深入理解BigBird的塊稀疏高效實(shí)現(xiàn)方案

OSC開(kāi)源社區(qū) ? 來(lái)源:OSC開(kāi)源社區(qū) ? 2023-11-29 11:02 ? 次閱讀

引言

基于 transformer 的模型已被證明對(duì)很多 NLP 任務(wù)都非常有用。然而, 的時(shí)間和內(nèi)存復(fù)雜度 (其中 是序列長(zhǎng)度) 使得在長(zhǎng)序列 () 上應(yīng)用它們變得非常昂貴,因而大大限制了其應(yīng)用。最近的幾篇論文,如 Longformer 、Performer 、Reformer 、簇狀注意力 都試圖通過(guò)對(duì)完整注意力矩陣進(jìn)行近似來(lái)解決這個(gè)問(wèn)題。如果你不熟悉這些模型,可以查看 之前的 博文。

BigBird (由 該論文 引入) 是解決這個(gè)問(wèn)題的最新模型之一。 BigBird 依賴于 塊稀疏注意力 而不是普通注意力 ( 即 BERT 的注意力),與 BERT 相比,這一新算法能以低得多的計(jì)算成本處理長(zhǎng)達(dá) 4096 的序列。在涉及很長(zhǎng)序列的各種任務(wù)上,該模型都實(shí)現(xiàn)了 SOTA,例如長(zhǎng)文檔摘要、長(zhǎng)上下文問(wèn)答。

RoBERTa 架構(gòu)的 BigBird 模型現(xiàn)已集成入 transformers 中。本文的目的是讓讀者 深入 了解 BigBird 的實(shí)現(xiàn),并讓讀者能在 transformers 中輕松使用 BigBird。但是,在更深入之前,一定記住 BigBird 注意力只是 BERT 完全注意力的一個(gè)近似,因此我們并不糾結(jié)于讓它比 BERT 完全注意力 更好,而是致力于讓它更有效率。有了它,transformer 模型就可以作用于更長(zhǎng)的序列,因?yàn)?BERT 的二次方內(nèi)存需求很快會(huì)變得難以為繼。簡(jiǎn)而言之,如果我們有 計(jì)算和 時(shí)間,那么用 BERT 注意力就好了,完全沒(méi)必要用本文討論的塊稀疏注意力。

如果你想知道為什么在處理較長(zhǎng)序列時(shí)需要更多計(jì)算,那么本文正合你意!

在使用標(biāo)準(zhǔn)的 BERT 類注意力時(shí)可能會(huì)遇到以下幾個(gè)主要問(wèn)題:

每個(gè)詞元真的都必須關(guān)注所有其他詞元嗎?

為什么不只計(jì)算重要詞元的注意力?

如何決定哪些詞元重要?

如何以高效的方式處理少量詞元?

本文,我們將嘗試回答這些問(wèn)題。

應(yīng)該關(guān)注哪些詞元?

下面,我們將以句子 BigBird is now available in HuggingFace for extractive Question Answering 為例來(lái)說(shuō)明注意力是如何工作的。在 BERT 這類的注意力機(jī)制中,每個(gè)詞元都簡(jiǎn)單粗暴地關(guān)注所有其他詞元。從數(shù)學(xué)上來(lái)講,這意味著每個(gè)查詢的詞元,將關(guān)注每個(gè)鍵詞元。

我們考慮一下 每個(gè)查詢?cè)~元應(yīng)如何明智地選擇它實(shí)際上應(yīng)該關(guān)注的鍵詞元 這個(gè)問(wèn)題,下面我們通過(guò)編寫(xiě)偽代碼的方式來(lái)整理思考過(guò)程。

假設(shè) available 是當(dāng)前查詢?cè)~元,我們來(lái)構(gòu)建一個(gè)合理的、需要關(guān)注的鍵詞元列表。

#以下面的句子為例
example=['BigBird','is','now','available','in','HuggingFace','for','extractive','question','answering']

#假設(shè)當(dāng)前需要計(jì)算'available'這個(gè)詞的表征
query_token='available'

#初始化一個(gè)空集合,用于放'available'這個(gè)詞的鍵詞元
key_tokens=[]#=>目前,'available'詞元不關(guān)注任何詞元

鄰近詞元當(dāng)然很重要,因?yàn)樵谝粋€(gè)句子 (單詞序列) 中,當(dāng)前詞高度依賴于前后的鄰近詞?;瑒?dòng)注意力 即基于該直覺(jué)。

#考慮滑動(dòng)窗大小為3,即將'available'的左邊一個(gè)詞和右邊一個(gè)詞納入考量
#左詞:'now';右詞:'in'
sliding_tokens=["now","available","in"]

#用以上詞元更新集合
key_tokens.append(sliding_tokens)

長(zhǎng)程依賴關(guān)系: 對(duì)某些任務(wù)而言,捕獲詞元間的長(zhǎng)程關(guān)系至關(guān)重要。 例如 ,在問(wèn)答類任務(wù)中,模型需要將上下文的每個(gè)詞元與整個(gè)問(wèn)題進(jìn)行比較,以便能夠找出上下文的哪一部分對(duì)正確答案有用。如果大多數(shù)上下文詞元僅關(guān)注其他上下文詞元,而不關(guān)注問(wèn)題,那么模型從不太重要的上下文詞元中過(guò)濾重要的上下文詞元就會(huì)變得更加困難。

BigBird 提出了兩種允許長(zhǎng)程注意力依賴的方法,這兩種方法都能保證計(jì)算效率。

全局詞元: 引入一些詞元,這些詞元將關(guān)注每個(gè)詞元并且被每個(gè)詞元關(guān)注。例如,對(duì) “HuggingFace is building nice libraries for easy NLP” ,現(xiàn)在假設(shè) 'building' 被定義為全局詞元,而對(duì)某些任務(wù)而言,模型需要知道 'NLP' 和 'HuggingFace' 之間的關(guān)系 (注意: 這 2 個(gè)詞元位于句子的兩端); 現(xiàn)在讓 'building' 在全局范圍內(nèi)關(guān)注所有其他詞元,會(huì)對(duì)模型將 'NLP' 與 'HuggingFace' 關(guān)聯(lián)起來(lái)有幫助。

#我們假設(shè)第一個(gè)和最后一個(gè)詞元是全局的,則有:
global_tokens=["BigBird","answering"]

#將全局詞元加入到集合中
key_tokens.append(global_tokens)

隨機(jī)詞元: 隨機(jī)選擇一些詞元,這些詞元將通過(guò)關(guān)注其他詞元來(lái)傳輸信息,而那些詞元又可以傳輸信息到其他詞元。這可以降低直接從一個(gè)詞元到另一個(gè)詞元的信息傳輸成本。

#現(xiàn)在,我們可以從句子中隨機(jī)選擇`r`個(gè)詞元。這里,假設(shè)`r`為1,選擇了`is`這個(gè)詞元
>>>random_tokens=["is"]#注意:這個(gè)是完全隨機(jī)選擇的,因此可以是任意詞元。

#將隨機(jī)詞元加入到集合中
key_tokens.append(random_tokens)

#現(xiàn)在看下`key_tokens`集合中有哪些詞元
key_tokens
{'now','is','in','answering','available','BigBird'}

#至此,查詢?cè)~'available'僅關(guān)注集合中的這些詞元,而不用關(guān)心全部

這樣,查詢?cè)~元僅關(guān)注所有詞元的一個(gè)子集,該子集能夠產(chǎn)生完全注意力值的一個(gè)不錯(cuò)的近似。相同的方法將用于所有其他查詢?cè)~元。但請(qǐng)記住,這里的重點(diǎn)是盡可能有效地接近 BERT 的完全注意力。BERT 那種簡(jiǎn)單地讓每個(gè)查詢?cè)~元關(guān)注所有鍵詞元的做法可以建模為一系列矩陣乘法,從而在現(xiàn)代硬件 (如 GPU) 上進(jìn)行高效計(jì)算。然而,滑動(dòng)、全局和隨機(jī)注意力的組合似乎意味著稀疏矩陣乘法,這在現(xiàn)代硬件上很難高效實(shí)現(xiàn)。BigBird 的主要貢獻(xiàn)之一是提出了 塊稀疏 注意力機(jī)制,該機(jī)制可以高效計(jì)算滑動(dòng)、全局和隨機(jī)注意力。我們來(lái)看看吧!

圖解全局、滑動(dòng)、隨機(jī)注意力的概念

首先,我們借助圖來(lái)幫助理解“全局”、“滑動(dòng)”和“隨機(jī)”注意力,并嘗試?yán)斫膺@三種注意力機(jī)制的組合是如何較好地近似標(biāo)準(zhǔn) BERT 類注意力的。

038896b2-8ddf-11ee-939d-92fbcf53809c.png

0393025a-8ddf-11ee-939d-92fbcf53809c.png

03a54b72-8ddf-11ee-939d-92fbcf53809c.png

上圖分別把“全局”(左) 、“滑動(dòng)”(中) 和“隨機(jī)”(右) 連接建模成一個(gè)圖。每個(gè)節(jié)點(diǎn)對(duì)應(yīng)一個(gè)詞元,每條邊代表一個(gè)注意力分?jǐn)?shù)。如果 2 個(gè)詞元之間沒(méi)有邊連接,則其注意力分?jǐn)?shù)為 0。



03aee9ca-8ddf-11ee-939d-92fbcf53809c.gif

03b6d608-8ddf-11ee-939d-92fbcf53809c.png

BigBird 塊稀疏注意力 是滑動(dòng)連接、全局連接和隨機(jī)連接 (總共 10 個(gè)連接) 的組合,如上圖左側(cè)動(dòng)圖所示。而 完全注意力 圖 (右側(cè)) 則是有全部 15 個(gè)連接 (注意: 總共有 6 個(gè)節(jié)點(diǎn))。你可以簡(jiǎn)單地將完全注意力視為所有詞元都是全局詞元 。

完全注意力: 模型可以直接在單個(gè)層中將信息從一個(gè)詞元傳輸?shù)搅硪粋€(gè)詞元,因?yàn)槊總€(gè)詞元都會(huì)對(duì)每個(gè)其他詞元進(jìn)行查詢,并且受到其他每個(gè)詞元的關(guān)注。我們考慮一個(gè)與上圖類似的例子,如果模型需要將 'going' 與 'now' 關(guān)聯(lián)起來(lái),它可以簡(jiǎn)單地在單層中執(zhí)行此操作,因?yàn)樗鼈儍蓚€(gè)是有直接連接的。

塊稀疏注意力: 如果模型需要在兩個(gè)節(jié)點(diǎn) (或詞元) 之間共享信息,則對(duì)于某些詞元,信息將必須經(jīng)過(guò)路徑中的各個(gè)其他節(jié)點(diǎn); 因?yàn)椴皇撬泄?jié)點(diǎn)都有直接連接的。例如 ,假設(shè)模型需要將 going 與 now 關(guān)聯(lián)起來(lái),那么如果僅存在滑動(dòng)注意力,則這兩個(gè)詞元之間的信息流由路徑 going -> am -> i -> now 來(lái)定義,也就是說(shuō)它必須經(jīng)過(guò) 2 個(gè)其他詞元。因此,我們可能需要多個(gè)層來(lái)捕獲序列的全部信息,而正常的注意力可以在單層中捕捉到這一點(diǎn)。在極端情況下,這可能意味著需要與輸入詞元一樣多的層。然而,如果我們引入一些全局詞元,信息可以通過(guò)以下路徑傳播 going -> i -> now ,這可以幫助縮短路徑。如果我們?cè)倭硗庖腚S機(jī)連接,它就可以通過(guò) going -> am -> now 傳播。借助隨機(jī)連接和全局連接,信息可以非??焖俚?(只需幾層) 從一個(gè)詞元傳輸?shù)较乱粋€(gè)詞元。

如果我們有很多全局詞元,那么我們可能不需要隨機(jī)連接,因?yàn)樾畔⒖梢酝ㄟ^(guò)多個(gè)短路徑傳播。這就是在使用 BigBird 的變體 (稱為 ETC) 時(shí)設(shè)置 num_random_tokens = 0 的動(dòng)機(jī) (稍后部分將會(huì)詳細(xì)介紹)。

在這些圖中,我們假設(shè)注意力矩陣是對(duì)稱的 因?yàn)樵趫D中如果某個(gè)詞元 A 關(guān)注 B,那么 B 也會(huì)關(guān)注 A。從下一節(jié)所示的注意力矩陣圖中可以看出,這個(gè)假設(shè)對(duì)于 BigBird 中的大多數(shù)詞元都成立。

注意力類型 全局次元 滑動(dòng)詞元 隨機(jī)詞元
原始完全注意力 n 0 0
塊稀疏注意力 2 x block_size 3 x block_size num_random_blocks x block_size

原始完全注意力即 BERT 的注意力,而塊稀疏注意力則是 BigBird 的注意力。想知道 block_size 是什么?請(qǐng)繼續(xù)閱讀下文。_現(xiàn)在,為簡(jiǎn)單起見(jiàn),將其視為 1。_

BigBird 塊稀疏注意力

BigBird 塊稀疏注意力是我們上文討論的內(nèi)容的高效實(shí)現(xiàn)。每個(gè)詞元都關(guān)注某些 全局詞元 、 滑動(dòng)詞元隨機(jī)詞元,而不管其他 所有 詞元。作者分別實(shí)現(xiàn)了每類查詢注意力矩陣,并使用了一個(gè)很酷的技巧來(lái)加速 GPU 和 TPU 上的訓(xùn)練/推理。

03ba941e-8ddf-11ee-939d-92fbcf53809c.png

BigBird 塊稀疏注意力

注意: 在上圖的頂部有 2 個(gè)額外的句子。正如你所注意到的,兩個(gè)句子中的每個(gè)詞元都只是交換了一個(gè)位置。這就是滑動(dòng)注意力的實(shí)現(xiàn)方式。當(dāng) q[i] 與 k[i,0:3] 相乘時(shí),我們會(huì)得到 q[i] 的滑動(dòng)注意力分?jǐn)?shù) (其中i 是序列中元素的索引)。

你可以在 這兒 找到 block_sparse 注意力的具體實(shí)現(xiàn)?,F(xiàn)在看起來(lái)可能非??膳?,但這篇文章肯定會(huì)讓你輕松理解它。

全局注意力

對(duì)于全局注意力而言,每個(gè)查詢?cè)~元關(guān)注序列中的所有其他詞元,并且被其他每個(gè)詞元關(guān)注。我們假設(shè) Vasudev (第一個(gè)詞元) 和 them (最后一個(gè)詞元) 是全局的 (如上圖所示)。你可以看到這些詞元直接連接到所有其他詞元 (藍(lán)色框)。

#偽代碼

Q->Querymartix(seq_length,head_dim)
K->Keymatrix(seq_length,head_dim)

#第一個(gè)和最后一個(gè)詞元關(guān)注所有其他詞元
Q[0]x[K[0],K[1],K[2],......,K[n-1]]
Q[n-1]x[K[0],K[1],K[2],......,K[n-1]]

#第一個(gè)和最后一個(gè)詞元也被其他所有詞元關(guān)注
K[0]x[Q[0],Q[1],Q[2],......,Q[n-1]]
K[n-1]x[Q[0],Q[1],Q[2],......,Q[n-1]]

滑動(dòng)注意力

鍵詞元序列被復(fù)制兩次,其中一份每個(gè)詞元向右移動(dòng)一步,另一份每個(gè)詞元向左移動(dòng)一步。現(xiàn)在,如果我們將查詢序列向量乘以這 3 個(gè)序列向量,我們將覆蓋所有滑動(dòng)詞元。計(jì)算復(fù)雜度就是 O(3n) = O(n) 。參考上圖,橙色框代表滑動(dòng)注意力。你可以在圖的頂部看到 3 個(gè)序列,其中 2 個(gè)序列各移動(dòng)了一個(gè)詞元 (1 個(gè)向左,1 個(gè)向右)。

#我們想做的
Q[i]x[K[i-1],K[i],K[i+1]]fori=1:-1

#高效的代碼實(shí)現(xiàn)(乘法為點(diǎn)乘)
[Q[0],Q[1],Q[2],......,Q[n-2],Q[n-1]]x[K[1],K[2],K[3],......,K[n-1],K[0]]
[Q[0],Q[1],Q[2],......,Q[n-1]]x[K[n-1],K[0],K[1],......,K[n-2]]
[Q[0],Q[1],Q[2],......,Q[n-1]]x[K[0],K[1],K[2],......,K[n-1]]

#每個(gè)序列被乘3詞,即`window_size=3`。為示意,僅列出主要計(jì)算,省略了一些計(jì)算。

隨機(jī)注意力

隨機(jī)注意力確保每個(gè)查詢?cè)~元也會(huì)關(guān)注一些隨機(jī)詞元。對(duì)實(shí)現(xiàn)而言,這意味著模型隨機(jī)選取一些詞元并計(jì)算它們的注意力分?jǐn)?shù)。

#r1,r2,r為隨機(jī)索引;注意r1,r2,r每行取值不同
Q[1]x[Q[r1],Q[r2],......,Q[r]]
.
.
.
Q[n-2]x[Q[r1],Q[r2],......,Q[r]]

#不用管第0個(gè)和第n-1個(gè)詞元,因?yàn)樗鼈円呀?jīng)是全局詞元了。

注意: 當(dāng)前的實(shí)現(xiàn)進(jìn)一步將序列劃分為塊,并且每個(gè)符號(hào)都依塊而定義而非依詞元而定義。我們?cè)谙乱还?jié)中會(huì)更詳細(xì)地討論這個(gè)問(wèn)題。

實(shí)現(xiàn)

回顧: 在常規(guī) BERT 注意力中,一系列詞元,即 通過(guò)線性層投影到 ,并基于它們計(jì)算注意力分?jǐn)?shù) ,公式為 。使用 BigBird 塊稀疏注意力時(shí),我們使用相同的算法,但僅針對(duì)一些選定的查詢和鍵向量進(jìn)行計(jì)算。

我們來(lái)看看 BigBird 塊稀疏注意力是如何實(shí)現(xiàn)的。首先,我們用 分別代表 block_size 、num_random_blocks 、num_sliding_blocks 、num_global_blocks 。我們以 為例來(lái)說(shuō)明 BigBird 塊稀疏注意力的機(jī)制部分,如下所示:

03d27980-8ddf-11ee-939d-92fbcf53809c.png

的注意力分?jǐn)?shù)分別計(jì)算如下:

的注意力分?jǐn)?shù)由 表示,其中 ,即為第一塊中的所有詞元與序列中的所有其他詞元之間的注意力分?jǐn)?shù)。

03e72f06-8ddf-11ee-939d-92fbcf53809c.png

BigBird 塊稀疏注意力

表示第 1 塊, 表示第 塊。我們僅在 和 (即所有鍵) 之間執(zhí)行正常的注意力操作。

為了計(jì)算第二塊中詞元的注意力分?jǐn)?shù),我們收集前三塊、最后一塊和第五塊。然后我們可以計(jì)算。

03fe1d56-8ddf-11ee-939d-92fbcf53809c.png

BigBird 塊稀疏注意力

這里,我用 表示詞元只是為了明確地表示它們的性質(zhì) (即是全局、隨機(jī)還是滑動(dòng)詞元),只用 無(wú)法表示他們各自的性質(zhì)。

為了計(jì)算 的注意力分?jǐn)?shù),我們先收集相應(yīng)的全局、滑動(dòng)、隨機(jī)鍵向量,并基于它們正常計(jì)算 上的注意力。請(qǐng)注意,正如前面滑動(dòng)注意力部分所討論的,滑動(dòng)鍵是使用特殊的移位技巧來(lái)收集的。

041c25bc-8ddf-11ee-939d-92fbcf53809c.png

BigBird 塊稀疏注意力

為了計(jì)算倒數(shù)第二塊 (即 ) 中詞元的注意力分?jǐn)?shù),我們收集第一塊、最后三塊和第三塊的鍵向量。然后我們用公式進(jìn)行計(jì)算。這和計(jì)算 非常相似。

04341046-8ddf-11ee-939d-92fbcf53809c.png

BigBird 塊稀疏注意力

最后一塊 的注意力分?jǐn)?shù)由 表示,其中,只不過(guò)是最后一塊中的所有詞元與序列中的所有其他詞元之間的注意力分?jǐn)?shù)。這與我們對(duì) 所做的非常相似。

043f0c6c-8ddf-11ee-939d-92fbcf53809c.png

BigBird 塊稀疏注意力

我們將上面的矩陣組合起來(lái)得到最終的注意力矩陣。該注意力矩陣可用于獲取所有詞元的表征。

045211c2-8ddf-11ee-939d-92fbcf53809c.gif

BigBird 塊稀疏注意力

上圖中 藍(lán)色 -> 全局塊 、紅色 -> 隨機(jī)塊 、橙色 -> 滑動(dòng)塊 。在前向傳播過(guò)程中,我們不存儲(chǔ)“白色”塊,而是直接為每個(gè)單獨(dú)的部分計(jì)算加權(quán)值矩陣 (即每個(gè)詞元的表示),如上所述。

現(xiàn)在,我們已經(jīng)介紹了塊稀疏注意力最難的部分,即它的實(shí)現(xiàn)。希望對(duì)你更好地理解實(shí)際代碼有幫助?,F(xiàn)在你可以深入研究代碼了,在此過(guò)程中你可以將代碼的每個(gè)部分與上面的某個(gè)部分聯(lián)系起來(lái)以助于理解。

時(shí)間和內(nèi)存復(fù)雜度

注意力類型 序列長(zhǎng)度 時(shí)間和內(nèi)存復(fù)雜度
原始完全注意力 512 T
1024 4 x T
4096 64 x T
塊稀疏注意力 1024 2 x T
4096 8 x T

BERT 注意力和 BigBird 塊稀疏注意力的時(shí)間和空間復(fù)雜度之比較。

展開(kāi)以了解復(fù)雜度的計(jì)算過(guò)程。

BigBird時(shí)間復(fù)雜度=O(wxn+rxn+gxn)
BERT時(shí)間復(fù)雜度=O(n^2)

假設(shè):
w=3x64
r=3x64
g=2x64

當(dāng)序列長(zhǎng)度為512時(shí)
=>**BERT時(shí)間復(fù)雜度=512^2**

當(dāng)序列長(zhǎng)度為1024時(shí)
=>BERT時(shí)間復(fù)雜度=(2x512)^2
=>**BERT時(shí)間復(fù)雜度=4x512^2**

=>BigBird時(shí)間復(fù)雜度=(8x64)x(2x512)
=>**BigBird時(shí)間復(fù)雜度=2x512^2**

當(dāng)序列長(zhǎng)度為4096時(shí)
=>BERT時(shí)間復(fù)雜度=(8x512)^2
=>**BERT時(shí)間復(fù)雜度=64x512^2**

=>BigBird時(shí)間復(fù)雜度=(8x64)x(8x512)
=>BigBird時(shí)間復(fù)雜度=8x(512x512)
=>**BigBird時(shí)間復(fù)雜度=8x512^2**

ITC 與 ETC

BigBird 模型可以使用 2 種不同的策略進(jìn)行訓(xùn)練: ITCETC。 ITC (internal transformer construction,內(nèi)部 transformer 構(gòu)建) 就是我們上面討論的。在 ETC (extended transformer construction,擴(kuò)展 transformer 構(gòu)建) 中,會(huì)有更多的全局詞元,以便它們關(guān)注所有詞元或者被所有詞元關(guān)注。

ITC 需要的計(jì)算量較小,因?yàn)楹苌儆性~元是全局的,同時(shí)模型可以捕獲足夠的全局信息 (也可以借助隨機(jī)注意力)。而 ETC 對(duì)于需要大量全局詞元的任務(wù)非常有幫助,例如對(duì) 問(wèn)答 類任務(wù)而言,整個(gè)問(wèn)題應(yīng)該被所有上下文關(guān)注,以便能夠?qū)⑸舷挛恼_地與問(wèn)題相關(guān)聯(lián)。

注意: BigBird 論文顯示,在很多 ETC 實(shí)驗(yàn)中,隨機(jī)塊的數(shù)量設(shè)置為 0??紤]到我們上文圖解部分的討論,這是合理的。

下表總結(jié)了 ITC 和 ETC:

ITC ETC
全局注意力的注意力矩陣

046aea8a-8ddf-11ee-939d-92fbcf53809c.png

047749ec-8ddf-11ee-939d-92fbcf53809c.png

全局詞元 2 x block_size extra_tokens + 2 x block_size
隨機(jī)詞元 num_random_blocks x block_size num_random_blocks x block_size
滑動(dòng)詞元 3 x block_size 3 x block_size

在 Transformers 中使用 BigBird

你可以像使用任何其他 模型一樣使用 BigBirdModel 。我們看一下代碼:

fromtransformersimportBigBirdModel

#從預(yù)訓(xùn)練checkpoint中加載bigbird模型
model=BigBirdModel.from_pretrained("google/bigbird-roberta-base")
#使用默認(rèn)配置初始化模型,如attention_type="block_sparse",num_random_blocks=3,block_size=64
#你也可以按照自己的需要改變這些參數(shù)。這3個(gè)參數(shù)只改變每個(gè)查詢?cè)~元關(guān)注的詞元數(shù)。
model=BigBirdModel.from_pretrained("google/bigbird-roberta-base",num_random_blocks=2,block_size=16)

#通過(guò)把a(bǔ)ttention_type設(shè)成`original_full`,BigBird就會(huì)用復(fù)雜度為n^2的完全注意力。此時(shí),BigBird與BERT相似度為99.9%。
model=BigBirdModel.from_pretrained("google/bigbird-roberta-base",attention_type="original_full")

截至現(xiàn)在, Hub 中總共有 3 個(gè) BigBird checkpoint: bigbird-roberta-base,bigbird-roberta-large 以及 bigbird-base-trivia-itc。前兩個(gè)檢查點(diǎn)是使用 masked_lm 損失 預(yù)訓(xùn)練 BigBirdForPretraining 而得; 而最后一個(gè)是在 trivia-qa 數(shù)據(jù)集上微調(diào) BigBirdForQuestionAnswering 而得。

讓我們看一下如果用你自己喜歡的 PyTorch 訓(xùn)練器,最少需要多少代碼就可以使用 的 BigBird 模型來(lái)微調(diào)你自己的任務(wù)。

#以問(wèn)答任務(wù)為例
fromtransformersimportBigBirdForQuestionAnswering,BigBirdTokenizer
importtorch

device=torch.device("cpu")
iftorch.cuda.is_available():
device=torch.device("cuda")

#我們用預(yù)訓(xùn)練權(quán)重初始化bigbird模型,并隨機(jī)初始化其頭分類器
model=BigBirdForQuestionAnswering.from_pretrained("google/bigbird-roberta-base",block_size=64,num_random_blocks=3)
tokenizer=BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
model.to(device)

dataset="torch.utils.data.DataLoaderobject"
optimizer="torch.optimobject"
epochs=...

#最簡(jiǎn)訓(xùn)練循環(huán)
foreinrange(epochs):
forbatchindataset:
model.train()
batch={k:batch[k].to(device)forkinbatch}

#前向
output=model(**batch)

#后向
output["loss"].backward()
optimizer.step()
optimizer.zero_grad()

#將最終權(quán)重存至本地目錄
model.save_pretrained("")

#將權(quán)重推到Hub中
fromhuggingface_hubimportModelHubMixin
ModelHubMixin.push_to_hub("",model_id="")

#使用微調(diào)后的模型,以用于推理
question=["Howareyoudoing?","Howislifegoing?"]
context=["",""]
batch=tokenizer(question,context,return_tensors="pt")
batch={k:batch[k].to(device)forkinbatch}

model=BigBirdForQuestionAnswering.from_pretrained("")
model.to(device)
withtorch.no_grad():
start_logits,end_logits=model(**batch).to_tuple()
#這里,你可以使用自己的策略對(duì)start_logits,end_logits進(jìn)行解碼

#注意:
#該代碼段僅用于展示即使你想用自己的PyTorch訓(xùn)練器微調(diào)BigBrid,這也是相當(dāng)容易的。
#我會(huì)建議使用Trainer,它更簡(jiǎn)單,功能也更多。

使用 BigBird 時(shí),需要記住以下幾點(diǎn):

序列長(zhǎng)度必須是塊大小的倍數(shù),即 seqlen % block_size = 0 。你不必?fù)?dān)心,因?yàn)槿绻?batch 的序列長(zhǎng)度不是 block_size 的倍數(shù), transformers 會(huì)自動(dòng)填充至最近的整數(shù)倍。

目前,Hugging Face 的實(shí)現(xiàn) 尚不支持 ETC,因此只有第一個(gè)和最后一個(gè)塊是全局的。

當(dāng)前實(shí)現(xiàn)不支持 num_random_blocks = 0 。

論文作者建議當(dāng)序列長(zhǎng)度 < 1024 時(shí)設(shè)置 attention_type = "original_full" 。

必須滿足: seq_length > global_token + random_tokens + moving_tokens + buffer_tokens ,其中 global_tokens = 2 x block_size 、 sliding_tokens = 3 x block_size 、 random_tokens = num_random_blocks x block_size 且 buffer_tokens = num_random_blocks x block_size 。如果你不能滿足這一點(diǎn), transformers 會(huì)自動(dòng)將 attention_type 切換為 original_full 并告警。

當(dāng)使用 BigBird 作為解碼器 (或使用 BigBirdForCasualLM ) 時(shí), attention_type 應(yīng)該是 original_full 。但你不用擔(dān)心, transformers 會(huì)自動(dòng)將 attention_type 切換為 original_full ,以防你忘記這樣做。

下一步

@patrickvonplaten 建了一個(gè)非常酷的 筆記本,以展示如何在 trivia-qa 數(shù)據(jù)集上評(píng)估 BigBirdForQuestionAnswering 。你可以隨意用這個(gè)筆記本來(lái)玩玩 BigBird。

審核編輯:黃飛

聲明:本文內(nèi)容及配圖由入駐作者撰寫(xiě)或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問(wèn)題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • gpu
    gpu
    +關(guān)注

    關(guān)注

    28

    文章

    4740

    瀏覽量

    128945
  • 算法
    +關(guān)注

    關(guān)注

    23

    文章

    4612

    瀏覽量

    92890
  • Transformer
    +關(guān)注

    關(guān)注

    0

    文章

    143

    瀏覽量

    6006
  • nlp
    nlp
    +關(guān)注

    關(guān)注

    1

    文章

    488

    瀏覽量

    22037

原文標(biāo)題:深入理解BigBird的塊稀疏注意力

文章出處:【微信號(hào):OSC開(kāi)源社區(qū),微信公眾號(hào):OSC開(kāi)源社區(qū)】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    深入理解SD卡原理和其內(nèi)部結(jié)構(gòu)總結(jié)

    深入理解SD卡原理和其內(nèi)部結(jié)構(gòu)總結(jié)
    發(fā)表于 08-18 11:11

    深入理解Android

    深入理解Android
    發(fā)表于 08-20 15:30

    深入理解實(shí)現(xiàn)RTOS_連載

    和trcohili的帖子。深入理解實(shí)現(xiàn)RTOS_連載1_RTOS的前生今世今天發(fā)布的是第一篇,"RTOS的前生今世"。通過(guò)軟件系統(tǒng)結(jié)構(gòu)的比對(duì)簡(jiǎn)要的介紹rtos為何而生。如果讀者對(duì)RTOS
    發(fā)表于 05-29 11:20

    深入理解實(shí)現(xiàn)RTOS_連載

    和trcohili的帖子。trochili rtos完全是作者興趣所在,且行且堅(jiān)持,比沒(méi)有duo。深入理解實(shí)現(xiàn)RTOS_連載1_RTOS的前生今世今天發(fā)布的是第一篇,"RTOS的前生今世"
    發(fā)表于 05-30 01:02

    深入理解lte-a

    深入理解LTE-A
    發(fā)表于 02-26 10:21

    如何深入理解ES6之函數(shù)

    深入理解ES6之函數(shù)
    發(fā)表于 05-22 07:40

    深入理解STM32

    時(shí)鐘系統(tǒng)是處理器的核心,所以在學(xué)習(xí)STM32所有外設(shè)之前,認(rèn)真學(xué)習(xí)時(shí)鐘系統(tǒng)是必要的,有助于深入理解STM32。下面是從網(wǎng)上找的一個(gè)STM32時(shí)鐘框圖,比《STM32中文參考手冊(cè)》里面的是中途看起來(lái)清晰一些:重要的時(shí)鐘:PLLCLK,SYSCLK,HCKL,PCLK1,...
    發(fā)表于 08-12 07:46

    對(duì)棧的深入理解

    為什么要深入理解棧?做C語(yǔ)言開(kāi)發(fā)如果棧設(shè)置不合理或者使用不對(duì),棧就會(huì)溢出,溢出就會(huì)遇到無(wú)法預(yù)測(cè)亂飛現(xiàn)象。所以對(duì)棧的深入理解是非常重要的。注:動(dòng)畫(huà)如果看不清楚可以電腦看更清晰啥是棧先來(lái)看一段動(dòng)畫(huà):沒(méi)有
    發(fā)表于 02-15 07:01

    為什么要深入理解

    [導(dǎo)讀] 從這篇文章開(kāi)始,將會(huì)不定期更新關(guān)于嵌入式C語(yǔ)言編程相關(guān)的個(gè)人認(rèn)為比較重要的知識(shí)點(diǎn),或者踩過(guò)的坑。為什么要深入理解棧?做C語(yǔ)言開(kāi)發(fā)如果棧設(shè)置不合理或者使用不對(duì),棧就會(huì)溢出,溢出就會(huì)遇到無(wú)法
    發(fā)表于 02-15 06:09

    深入理解Android之資源文件

    深入理解Android之資源文件
    發(fā)表于 01-22 21:11 ?22次下載

    深入理解Android》文前

    深入理解Android》文前
    發(fā)表于 03-19 11:23 ?0次下載

    深入理解Android:卷I》

    深入理解Android:卷I》
    發(fā)表于 03-19 11:23 ?0次下載

    深入理解Android網(wǎng)絡(luò)編程

    深入理解Android網(wǎng)絡(luò)編程
    發(fā)表于 03-19 11:26 ?1次下載

    深入理解MOS管電子版資源下載

    深入理解MOS管電子版資源下載
    發(fā)表于 07-09 09:43 ?0次下載

    華為開(kāi)發(fā)者大會(huì)2021:深入理解用戶意圖

     如何深入理解用戶意圖,實(shí)現(xiàn)服務(wù)精準(zhǔn)分發(fā)。
    的頭像 發(fā)表于 10-22 15:41 ?1862次閱讀
    華為開(kāi)發(fā)者大會(huì)2021:<b class='flag-5'>深入理解</b>用戶意圖