transformer結(jié)構(gòu)是google在17年的Attention Is All You Need論文中提出,在NLP的多個(gè)任務(wù)上取得了非常好的效果,可以說(shuō)目前NLP發(fā)展都離不開(kāi)transformer。最大特點(diǎn)是拋棄了傳統(tǒng)的CNN和RNN,整個(gè)網(wǎng)絡(luò)結(jié)構(gòu)完全是由Attention機(jī)制組成。由于其出色性能以及對(duì)下游任務(wù)的友好性或者說(shuō)下游任務(wù)僅僅微調(diào)即可得到不錯(cuò)效果,在計(jì)算機(jī)視覺(jué)領(lǐng)域不斷有人嘗試將transformer引入,近期也出現(xiàn)了一些效果不錯(cuò)的嘗試,典型的如目標(biāo)檢測(cè)領(lǐng)域的detr和可變形detr,分類領(lǐng)域的vision transformer等等。本文從transformer結(jié)構(gòu)出發(fā),結(jié)合視覺(jué)中的transformer成果(具體是vision transformer和detr)進(jìn)行分析,希望能夠幫助cv領(lǐng)域想了解transformer的初學(xué)者快速入門(mén)。由于本人接觸transformer時(shí)間也不長(zhǎng),也算初學(xué)者,故如果有描述或者理解錯(cuò)誤的地方歡迎指正。
本文的大部分圖來(lái)自論文、國(guó)外博客和國(guó)內(nèi)翻譯博客,在此一并感謝前人工作,具體鏈接見(jiàn)參考資料。本文特別長(zhǎng),大概有3w字,請(qǐng)先點(diǎn)贊收藏然后慢慢看....
1 transformer介紹
一般講解transformer都會(huì)以機(jī)器翻譯任務(wù)為例子講解,機(jī)器翻譯任務(wù)是指將一種語(yǔ)言轉(zhuǎn)換得到另一種語(yǔ)言,例如英語(yǔ)翻譯為中文任務(wù)。從最上層來(lái)看,如下所示:
1.1 早期seq2seq
機(jī)器翻譯是一個(gè)歷史悠久的問(wèn)題,本質(zhì)可以理解為序列轉(zhuǎn)序列問(wèn)題,也就是我們常說(shuō)的seq2seq結(jié)構(gòu),也可以稱為encoder-decoder結(jié)構(gòu),如下所示:
encoder和decoder在早期一般是RNN模塊(因?yàn)槠淇梢圆东@時(shí)序信息),后來(lái)引入了LSTM或者GRU模塊,不管內(nèi)部組件是啥,其核心思想都是通過(guò)Encoder編碼成一個(gè)表示向量,即上下文編碼向量,然后交給Decoder來(lái)進(jìn)行解碼,翻譯成目標(biāo)語(yǔ)言。一個(gè)采用典型RNN進(jìn)行編碼碼翻譯的可視化圖如下:
可以看出,其解碼過(guò)程是順序進(jìn)行,每次僅解碼出一個(gè)單詞。對(duì)于CV領(lǐng)域初學(xué)者來(lái)說(shuō),RNN模塊構(gòu)建的seq2seq算法,理解到這個(gè)程度就可以了,不需要深入探討如何進(jìn)行訓(xùn)練。但是上述結(jié)構(gòu)其實(shí)有缺陷,具體來(lái)說(shuō)是:
不論輸入和輸出的語(yǔ)句長(zhǎng)度是什么,中間的上下文向量長(zhǎng)度都是固定的,一旦長(zhǎng)度過(guò)長(zhǎng),僅僅靠一個(gè)固定長(zhǎng)度的上下文向量明顯不合理
僅僅利用上下文向量解碼,會(huì)有信息瓶頸,長(zhǎng)度過(guò)長(zhǎng)時(shí)候信息可能會(huì)丟失
通俗理解是編碼器與解碼器的連接點(diǎn)僅僅是編碼單元輸出的隱含向量,其包含的信息有限,對(duì)于一些復(fù)雜任務(wù)可能信息不夠,如要翻譯的句子較長(zhǎng)時(shí),一個(gè)上下文向量可能存不下那么多信息,就會(huì)造成翻譯精度的下降。
1.2 基于attention的seq2seq
基于上述缺陷進(jìn)而提出帶有注意力機(jī)制Attention的seq2seq,同樣可以應(yīng)用于RNN、LSTM或者GRU模塊中。注意力機(jī)制Attention對(duì)人類來(lái)說(shuō)非常好理解,假設(shè)給定一張圖片,我們會(huì)自動(dòng)聚焦到一些關(guān)鍵信息位置,而不需要逐行掃描全圖。此處的attention也是同一個(gè)意思,其本質(zhì)是對(duì)輸入的自適應(yīng)加權(quán),結(jié)合cv領(lǐng)域的senet中的se模塊就能夠理解了。
se模塊最終是學(xué)習(xí)出一個(gè)1x1xc的向量,然后逐通道乘以原始輸入,從而對(duì)特征圖的每個(gè)通道進(jìn)行加權(quán)即通道注意力,對(duì)attention進(jìn)行抽象,不管啥領(lǐng)域其機(jī)制都可以歸納為下圖:
將Query(通常是向量)和4個(gè)Key(和Q長(zhǎng)度相同的向量)分別計(jì)算相似性,然后經(jīng)過(guò)softmax得到q和4個(gè)key相似性的概率權(quán)重分布,然后對(duì)應(yīng)權(quán)重乘以Value(和Q長(zhǎng)度相同的向量),最后相加即可得到包含注意力的attention值輸出,理解上應(yīng)該不難。舉個(gè)簡(jiǎn)單例子說(shuō)明:
假設(shè)世界上所有小吃都可以被標(biāo)簽化,例如微辣、特辣、變態(tài)辣、微甜、有嚼勁....,總共有1000個(gè)標(biāo)簽,現(xiàn)在我想要吃的小吃是[微辣、微甜、有嚼勁],這三個(gè)單詞就是我的Query
來(lái)到東門(mén)老街一共100家小吃點(diǎn),每個(gè)店鋪賣(mài)的東西不一樣,但是肯定可以被標(biāo)簽化,例如第一家小吃被標(biāo)簽化后是[微辣、微咸],第二家小吃被標(biāo)簽化后是[特辣、微臭、特咸],第二家小吃被標(biāo)簽化后是[特辣、微甜、特咸、有嚼勁],其余店鋪都可以被標(biāo)簽化,每個(gè)店鋪的標(biāo)簽就是Keys,但是每家店鋪由于賣(mài)的東西不一樣,單品種類也不一樣,所以被標(biāo)簽化后每一家的標(biāo)簽List不一樣長(zhǎng)
Values就是每家店鋪對(duì)應(yīng)的單品,例如第一家小吃的Values是[烤羊肉串、炒花生]
將Query和所有的Keys進(jìn)行一一比對(duì),相當(dāng)于計(jì)算相似性,此時(shí)就可以知道我想買(mǎi)的小吃和每一家店鋪的匹配情況,最后有了匹配列表,就可以去店鋪里面買(mǎi)東西了(Values和相似性加權(quán)求和)。最終的情況可能是,我在第一家店鋪買(mǎi)了烤羊肉串,然后在第10家店鋪買(mǎi)了個(gè)玉米,最后在第15家店鋪買(mǎi)了個(gè)烤面筋
以上就是完整的注意力機(jī)制,采用我心中的標(biāo)準(zhǔn)Query去和被標(biāo)簽化的所有店鋪Keys一一比對(duì),此時(shí)就可以得到我的Query在每個(gè)店鋪中的匹配情況,最終去不同店鋪買(mǎi)不同東西的過(guò)程就是權(quán)重和Values加權(quán)求和過(guò)程。簡(jiǎn)要代碼如下:
# 假設(shè)q是(1,N,512),N就是最大標(biāo)簽化后的list長(zhǎng)度,k是(1,M,512),M可以等于N,也可以不相等
# (1,N,512) x (1,512,M)-->(1,N,M)
attn = torch.matmul(q, k.transpose(2, 3))
# softmax轉(zhuǎn)化為概率,輸出(1,N,M),表示q中每個(gè)n和每個(gè)m的相關(guān)性
attn=F.softmax(attn, dim=-1)
# (1,N,M) x (1,M,512)-->(1,N,512),V和k的shape相同
output = torch.matmul(attn, v)
帶有attention的RNN模塊組成的ser2seq,解碼時(shí)候可視化如下:
在沒(méi)有attention時(shí)候,不同解碼階段都僅僅利用了同一個(gè)編碼層的最后一個(gè)隱含輸出,加入attention后可以通過(guò)在每個(gè)解碼時(shí)間步輸入的都是不同的上下文向量,以上圖為例,解碼階段會(huì)將第一個(gè)開(kāi)啟解碼標(biāo)志(也就是Q)與編碼器的每一個(gè)時(shí)間步的隱含狀態(tài)(一系列Key和Value)進(jìn)行點(diǎn)乘計(jì)算相似性得到每一時(shí)間步的相似性分?jǐn)?shù),然后通過(guò)softmax轉(zhuǎn)化為概率分布,然后將概率分布和對(duì)應(yīng)位置向量進(jìn)行加權(quán)求和得到新的上下文向量,最后輸入解碼器中進(jìn)行解碼輸出,其詳細(xì)解碼可視化如下:
通過(guò)上述簡(jiǎn)單的attention引入,可以將機(jī)器翻譯性能大幅提升,引入attention有以下幾個(gè)好處:
注意力顯著提高了機(jī)器翻譯性能
注意力允許解碼器以不同程度的權(quán)重利用到編碼器的所有信息,可以繞過(guò)瓶頸
通過(guò)檢查注意力分布,可以看到解碼器在關(guān)注什么,可解釋性強(qiáng)
1.3 基于transformer的seq2seq
基于attention的seq2seq的結(jié)構(gòu)雖然說(shuō)解決了很多問(wèn)題,但是其依然存在不足:
不管是采用RNN、LSTM還是GRU都不利于并行訓(xùn)練和推理,因?yàn)橄嚓P(guān)算法只能從左向右依次計(jì)算或者從右向左依次計(jì)算
長(zhǎng)依賴信息丟失問(wèn)題,順序計(jì)算過(guò)程中信息會(huì)丟失,雖然LSTM號(hào)稱有緩解,但是無(wú)法徹底解決
最大問(wèn)題應(yīng)該是無(wú)法并行訓(xùn)練,不利于大規(guī)??焖儆?xùn)練和部署,也不利于整個(gè)算法領(lǐng)域發(fā)展,故在Attention Is All You Need論文中拋棄了傳統(tǒng)的CNN和RNN,將attention機(jī)制發(fā)揮到底,整個(gè)網(wǎng)絡(luò)結(jié)構(gòu)完全是由Attention機(jī)制組成,這是一個(gè)比較大的進(jìn)步。
google所提基于transformer的seq2seq整體結(jié)構(gòu)如下所示:
其包括6個(gè)結(jié)構(gòu)完全相同的編碼器,和6個(gè)結(jié)構(gòu)完全相同的解碼器,其中每個(gè)編碼器和解碼器設(shè)計(jì)思想完全相同,只不過(guò)由于任務(wù)不同而有些許區(qū)別,整體詳細(xì)結(jié)構(gòu)如下所示:
第一眼看有點(diǎn)復(fù)雜,其中N=6,由于基于transformer的翻譯任務(wù)已經(jīng)轉(zhuǎn)化為分類任務(wù)(目標(biāo)翻譯句子有多長(zhǎng),那么就有多少個(gè)分類樣本),故在解碼器最后會(huì)引入fc+softmax層進(jìn)行概率輸出,訓(xùn)練也比較簡(jiǎn)單,直接采用ce loss即可,對(duì)于采用大量數(shù)據(jù)訓(xùn)練好的預(yù)訓(xùn)練模型,下游任務(wù)僅僅需要訓(xùn)練fc層即可。上述結(jié)構(gòu)看起來(lái)有點(diǎn)復(fù)雜,一個(gè)稍微抽象點(diǎn)的圖示如下:
看起來(lái)比基于RNN或者其余結(jié)構(gòu)構(gòu)建的seq2seq簡(jiǎn)單很多。下面結(jié)合代碼和原理進(jìn)行深入分析。
1.4 transformer深入分析
前面寫(xiě)了一大堆,沒(méi)有理解沒(méi)有關(guān)系,對(duì)于cv初學(xué)者來(lái)說(shuō)其實(shí)只需要理解QKV的含義和注意力機(jī)制的三個(gè)計(jì)算步驟:Q和所有K計(jì)算相似性;對(duì)相似性采用softmax轉(zhuǎn)化為概率分布;將概率分布和V進(jìn)行一一對(duì)應(yīng)相乘,最后相加得到新的和Q一樣長(zhǎng)的向量輸出即可,重點(diǎn)是下面要講的transformer結(jié)構(gòu)。
下面按照?編碼器輸入數(shù)據(jù)處理->編碼器運(yùn)行->解碼器輸入數(shù)據(jù)處理->解碼器運(yùn)行->分類head?的實(shí)際運(yùn)行流程進(jìn)行講解。
1.4.1 編碼器輸入數(shù)據(jù)處理
(1) 源單詞嵌入
以上面翻譯任務(wù)為例,原始待翻譯輸入是三個(gè)單詞:
輸入是三個(gè)單詞,為了能夠?qū)⑽谋緝?nèi)容輸入到網(wǎng)絡(luò)中肯定需要進(jìn)行向量化(不然單詞如何計(jì)算?),具體是采用nlp領(lǐng)域的embedding算法進(jìn)行詞嵌入,也就是常說(shuō)的Word2Vec。對(duì)于cv來(lái)說(shuō)知道是干嘛的就行,不必了解細(xì)節(jié)。假設(shè)每個(gè)單詞都可以嵌入成512個(gè)長(zhǎng)度的向量,故此時(shí)輸入即為3x512,注意Word2Vec操作只會(huì)輸入到第一個(gè)編碼器中,后面的編碼器接受的輸入是前一個(gè)編碼器輸出。
為了便于組成batch(不同訓(xùn)練句子單詞個(gè)數(shù)肯定不一樣)進(jìn)行訓(xùn)練,可以簡(jiǎn)單統(tǒng)計(jì)所有訓(xùn)練句子的單詞個(gè)數(shù),取最大即可,假設(shè)統(tǒng)計(jì)后發(fā)現(xiàn)待翻譯句子最長(zhǎng)是10個(gè)單詞,那么編碼器輸入是10x512,額外填充的512維向量可以采用固定的標(biāo)志編碼得到,例如$$。
(2) 位置編碼positional encoding
采用經(jīng)過(guò)單詞嵌入后的向量輸入到編碼器中還不夠,因?yàn)閠ransformer內(nèi)部沒(méi)有類似RNN的循環(huán)結(jié)構(gòu),沒(méi)有捕捉順序序列的能力,或者說(shuō)無(wú)論句子結(jié)構(gòu)怎么打亂,transformer都會(huì)得到類似的結(jié)果。為了解決這個(gè)問(wèn)題,在編碼詞向量時(shí)會(huì)額外引入了位置編碼position encoding向量表示兩個(gè)單詞i和j之間的距離,簡(jiǎn)單來(lái)說(shuō)就是在詞向量中加入了單詞的位置信息。
加入位置信息的方式非常多,最簡(jiǎn)單的可以是直接將絕對(duì)坐標(biāo)0,1,2編碼成512個(gè)長(zhǎng)度向量即可。作者實(shí)際上提出了兩種方式:
網(wǎng)絡(luò)自動(dòng)學(xué)習(xí)
自己定義規(guī)則
提前假設(shè)單詞嵌入并且組成batch后,shape為(b,N,512),N是序列最大長(zhǎng)度,512是每個(gè)單詞的嵌入向量長(zhǎng)度,b是batch
(a) 網(wǎng)絡(luò)自動(dòng)學(xué)習(xí)
self.pos_embedding = nn.Parameter(torch.randn(1, N, 512))
比較簡(jiǎn)單,因?yàn)槲恢镁幋a向量需要和輸入嵌入(b,N,512)相加,所以其shape為(1,N,512)表示N個(gè)位置,每個(gè)位置采用512長(zhǎng)度向量進(jìn)行編碼
(b) 自己定義規(guī)則
自定義規(guī)則做法非常多,論文中采用的是sin-cos規(guī)則,具體做法是:
將向量(N,512)采用如下函數(shù)進(jìn)行處理
pos即0~N,i是0-511
將向量的512維度切分為奇數(shù)行和偶數(shù)行
偶數(shù)行采用sin函數(shù)編碼,奇數(shù)行采用cos函數(shù)編碼
然后按照原始行號(hào)拼接
def get_position_angle_vec(position):
# d_hid是0-511,position表示單詞位置0~N-1
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
# 每個(gè)單詞位置0~N-1都可以編碼得到512長(zhǎng)度的向量
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
# 偶數(shù)列進(jìn)行sin
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
# 奇數(shù)列進(jìn)行cos
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
上面例子的可視化如下:
如此編碼的優(yōu)點(diǎn)是能夠擴(kuò)展到未知的序列長(zhǎng)度,例如前向時(shí)候有特別長(zhǎng)的句子,其可視化如下:
作者為啥要設(shè)計(jì)如此復(fù)雜的編碼規(guī)則?原因是sin和cos的如下特性:
可以將用進(jìn)行線性表出:
假設(shè)k=1,那么下一個(gè)位置的編碼向量可以由前面的編碼向量線性表示,等價(jià)于以一種非常容易學(xué)會(huì)的方式告訴了網(wǎng)絡(luò)單詞之間的絕對(duì)位置,讓模型能夠輕松學(xué)習(xí)到相對(duì)位置信息。注意編碼方式不是唯一的,將單詞嵌入向量和位置編碼向量相加就可以得到編碼器的真正輸入了,其輸出shape是(b,N,512)。
1.4.2 編碼器前向過(guò)程
編碼器由兩部分組成:自注意力層和前饋神經(jīng)網(wǎng)絡(luò)層。
其前向可視化如下:
注意上圖沒(méi)有繪制出單詞嵌入向量和位置編碼向量相加過(guò)程,但是是存在的。
(1) 自注意力層
通過(guò)前面分析我們知道自注意力層其實(shí)就是attention操作,并且由于其QKV來(lái)自同一個(gè)輸入,故稱為自注意力層。我想大家應(yīng)該能想到這里attention層作用,在參考資料1博客里面舉了個(gè)簡(jiǎn)單例子來(lái)說(shuō)明attention的作用:假設(shè)我們想要翻譯的輸入句子為T(mén)he animal didn't cross the street because it was too tired,這個(gè)“it”在這個(gè)句子是指什么呢?它指的是street還是這個(gè)animal呢?這對(duì)于人類來(lái)說(shuō)是一個(gè)簡(jiǎn)單的問(wèn)題,但是對(duì)于算法則不是。當(dāng)模型處理這個(gè)單詞“it”的時(shí)候,自注意力機(jī)制會(huì)允許“it”與“animal”建立聯(lián)系即隨著模型處理輸入序列的每個(gè)單詞,自注意力會(huì)關(guān)注整個(gè)輸入序列的所有單詞,幫助模型對(duì)本單詞更好地進(jìn)行編碼。實(shí)際上訓(xùn)練完成后確實(shí)如此,google提供了可視化工具,如下所示:
上述是從宏觀角度思考,如果從輸入輸出流角度思考,也比較容易:
假設(shè)我們現(xiàn)在要翻譯上述兩個(gè)單詞,首先將單詞進(jìn)行編碼,和位置編碼向量相加,得到自注意力層輸入X,其shape為(b,N,512);然后定義三個(gè)可學(xué)習(xí)矩陣??(通過(guò)nn.Linear實(shí)現(xiàn)),其shape為(512,M),一般M等于前面維度512,從而計(jì)算后維度不變;將X和矩陣?相乘,得到QKV輸出,shape為(b,N,M);然后將Q和K進(jìn)行點(diǎn)乘計(jì)算向量相似性;采用softmax轉(zhuǎn)換為概率分布;將概率分布和V進(jìn)行加權(quán)求和即可。其可視化如下:
上述繪制的不是矩陣形式,更好理解而已。對(duì)于第一個(gè)單詞的編碼過(guò)程是:將q1和所有的k進(jìn)行相似性計(jì)算,然后除以維度的平方根(論文中是64,本文可以認(rèn)為是512)使得梯度更加穩(wěn)定,然后通過(guò)softmax傳遞結(jié)果,這個(gè)softmax分?jǐn)?shù)決定了每個(gè)單詞對(duì)編碼當(dāng)下位置(“Thinking”)的貢獻(xiàn),最后對(duì)加權(quán)值向量求和得到z1。
這個(gè)計(jì)算很明顯就是前面說(shuō)的注意力機(jī)制計(jì)算過(guò)程,每個(gè)輸入單詞的編碼輸出都會(huì)通過(guò)注意力機(jī)制引入其余單詞的編碼信息。
上述為了方便理解才拆分這么細(xì)致,實(shí)際上代碼層面采用矩陣實(shí)現(xiàn)非常簡(jiǎn)單:
上面的操作很不錯(cuò),但是還有改進(jìn)空間,論文中又增加一種叫做“多頭”注意力(“multi-headed” attention)的機(jī)制進(jìn)一步完善了自注意力層,并在兩方面提高了注意力層的性能:
它擴(kuò)展了模型專注于不同位置的能力。在上面的例子中,雖然每個(gè)編碼都在z1中有或多或少的體現(xiàn),但是它可能被實(shí)際的單詞本身所支配。如果我們翻譯一個(gè)句子,比如“The animal didn’t cross the street because it was too tired”,我們會(huì)想知道“it”指的是哪個(gè)詞,這時(shí)模型的“多頭”注意機(jī)制會(huì)起到作用。
它給出了注意力層的多個(gè)“表示子空間",對(duì)于“多頭”注意機(jī)制,有多個(gè)查詢/鍵/值權(quán)重矩陣集(Transformer使用8個(gè)注意力頭,因此我們對(duì)于每個(gè)編碼器/解碼器有8個(gè)矩陣集合)。
簡(jiǎn)單來(lái)說(shuō)就是類似于分組操作,將輸入X分別輸入到8個(gè)attention層中,得到8個(gè)Z矩陣輸出,最后對(duì)結(jié)果concat即可。論文圖示如下:
先忽略Mask的作用,左邊是單頭attention操作,右邊是n個(gè)單頭attention構(gòu)成的多頭自注意力層。
代碼層面非常簡(jiǎn)單,單頭attention操作如下:
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
def forward(self, q, k, v, mask=None):
# self.temperature是論文中的d_k ** 0.5,防止梯度過(guò)大
# QxK/sqrt(dk)
attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
if mask is not None:
# 屏蔽不想要的輸出
attn = attn.masked_fill(mask == 0, -1e9)
# softmax+dropout
attn = self.dropout(F.softmax(attn, dim=-1))
# 概率分布xV
output = torch.matmul(attn, v)
return output, attn
再次復(fù)習(xí)下Multi-Head Attention層的圖示,可以發(fā)現(xiàn)在前面講的內(nèi)容基礎(chǔ)上還加入了殘差設(shè)計(jì)和層歸一化操作,目的是為了防止梯度消失,加快收斂。
Multi-Head Attention實(shí)現(xiàn)在ScaledDotProductAttention基礎(chǔ)上構(gòu)建:
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
# n_head頭的個(gè)數(shù),默認(rèn)是8
# d_model編碼向量長(zhǎng)度,例如本文說(shuō)的512
# d_k, d_v的值一般會(huì)設(shè)置為 n_head * d_k=d_model,
# 此時(shí)concat后正好和原始輸入一樣,當(dāng)然不相同也可以,因?yàn)楹竺嬗衒c層
# 相當(dāng)于將可學(xué)習(xí)矩陣分成獨(dú)立的n_head份
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super().__init__()
# 假設(shè)n_head=8,d_k=64
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
# d_model輸入向量,n_head * d_k輸出向量
# 可學(xué)習(xí)W^Q,W^K,W^V矩陣參數(shù)初始化
self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
# 最后的輸出維度變換操作
self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
# 單頭自注意力
self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
self.dropout = nn.Dropout(dropout)
# 層歸一化
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(self, q, k, v, mask=None):
# 假設(shè)qkv輸入是(b,100,512),100是訓(xùn)練每個(gè)樣本最大單詞個(gè)數(shù)
# 一般qkv相等,即自注意力
residual = q
# 將輸入x和可學(xué)習(xí)矩陣相乘,得到(b,100,512)輸出
# 其中512的含義其實(shí)是8x64,8個(gè)head,每個(gè)head的可學(xué)習(xí)矩陣為64維度
# q的輸出是(b,100,8,64),kv也是一樣
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
# 變成(b,8,100,64),方便后面計(jì)算,也就是8個(gè)頭單獨(dú)計(jì)算
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
if mask is not None:
mask = mask.unsqueeze(1) # For head axis broadcasting.
# 輸出q是(b,8,100,64),維持不變,內(nèi)部計(jì)算流程是:
# q*k轉(zhuǎn)置,除以d_k ** 0.5,輸出維度是b,8,100,100即單詞和單詞直接的相似性
# 對(duì)最后一個(gè)維度進(jìn)行softmax操作得到b,8,100,100
# 最后乘上V,得到b,8,100,64輸出
q, attn = self.attention(q, k, v, mask=mask)
# b,100,8,64-->b,100,512
q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
q = self.dropout(self.fc(q))
# 殘差計(jì)算
q += residual
# 層歸一化,在512維度計(jì)算均值和方差,進(jìn)行層歸一化
q = self.layer_norm(q)
return q, attn
現(xiàn)在pytorch新版本已經(jīng)把MultiHeadAttention當(dāng)做nn中的一個(gè)類了,可以直接調(diào)用。
(2) 前饋神經(jīng)網(wǎng)絡(luò)層
這個(gè)層就沒(méi)啥說(shuō)的了,非常簡(jiǎn)單:
class PositionwiseFeedForward(nn.Module):
''' A two-feed-forward-layer module '''
def __init__(self, d_in, d_hid, dropout=0.1):
super().__init__()
# 兩個(gè)fc層,對(duì)最后的512維度進(jìn)行變換
self.w_1 = nn.Linear(d_in, d_hid) # position-wise
self.w_2 = nn.Linear(d_hid, d_in) # position-wise
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
x = self.w_2(F.relu(self.w_1(x)))
x = self.dropout(x)
x += residual
x = self.layer_norm(x)
return x
(3) 編碼層操作整體流程
可視化如下所示:
單個(gè)編碼層代碼如下所示:
class EncoderLayer(nn.Module):
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(EncoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
def forward(self, enc_input, slf_attn_mask=None):
# Q K V是同一個(gè),自注意力
# enc_input來(lái)自源單詞嵌入向量或者前一個(gè)編碼器輸出
enc_output, enc_slf_attn = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask)
enc_output = self.pos_ffn(enc_output)
return enc_output, enc_slf_attn
將上述編碼過(guò)程重復(fù)n遍即可,除了第一個(gè)模塊輸入是單詞嵌入向量與位置編碼的和外,其余編碼層輸入是上一個(gè)編碼器輸出即后面的編碼器輸入不需要位置編碼向量。如果考慮n個(gè)編碼器的運(yùn)行過(guò)程,如下所示:
class Encoder(nn.Module):
def __init__(
self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
d_model, d_inner, pad_idx, dropout=0.1, n_position=200):
# nlp領(lǐng)域的詞嵌入向量生成過(guò)程(單詞在詞表里面的索引idx-->d_word_vec長(zhǎng)度的向量)
self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
# 位置編碼
self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
# n個(gè)編碼器層
self.layer_stack = nn.ModuleList([
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
# 層歸一化
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(self, src_seq, src_mask, return_attns=False):
# 對(duì)輸入序列進(jìn)行詞嵌入,加上位置編碼
enc_output = self.dropout(self.position_enc(self.src_word_emb(src_seq)))
enc_output = self.layer_norm(enc_output)
# 作為編碼器層輸入
for enc_layer in self.layer_stack:
enc_output, _ = enc_layer(enc_output, slf_attn_mask=src_mask)
return enc_output
到目前為止我們就講完了編碼部分的全部流程和代碼細(xì)節(jié)?,F(xiàn)在再來(lái)看整個(gè)transformer算法就會(huì)感覺(jué)親切很多了:
1.4.3 解碼器輸入數(shù)據(jù)處理
在分析解碼器結(jié)構(gòu)前先看下解碼器整體結(jié)構(gòu),方便理解:
其輸入數(shù)據(jù)處理也要區(qū)分第一個(gè)解碼器和后續(xù)解碼器,和編碼器類似,第一個(gè)解碼器輸入不僅包括最后一個(gè)編碼器輸出,還需要額外的輸出嵌入向量,而后續(xù)解碼器輸入是來(lái)自最后一個(gè)編碼器輸出和前面解碼器輸出。
(1) 目標(biāo)單詞嵌入
這個(gè)操作和源單詞嵌入過(guò)程完全相同,維度也是512,假設(shè)輸出是i am a student,那么需要對(duì)這4個(gè)單詞也利用word2vec算法轉(zhuǎn)化為4x512的矩陣,作為第一個(gè)解碼器的單詞嵌入輸入。
(2) 位置編碼
同樣的也需要對(duì)解碼器輸入引入位置編碼,做法和編碼器部分完全相同,且將目標(biāo)單詞嵌入向量和位置編碼向量相加即可作為第一個(gè)解碼器輸入。
和編碼器單詞嵌入不同的地方是在進(jìn)行目標(biāo)單詞嵌入前,還需要將目標(biāo)單詞即是i am a student右移動(dòng)一位,新增加的一個(gè)位置采用提前定義好的標(biāo)志位BOS_WORD代替,現(xiàn)在就變成[BOS_WORD,i,am,a,student],為啥要右移?因?yàn)榻獯a過(guò)程和seq2seq一樣是順序解碼的,需要提供一個(gè)開(kāi)始解碼標(biāo)志,。不然第一個(gè)時(shí)間步的解碼單詞i是如何輸出的呢?具體解碼過(guò)程其實(shí)是:輸入BOS_WORD,解碼器輸出i;輸入前面已經(jīng)解碼的BOS_WORD和i,解碼器輸出am...,輸入已經(jīng)解碼的BOS_WORD、i、am、a和student,解碼器輸出解碼結(jié)束標(biāo)志位EOS_WORD,每次解碼都會(huì)利用前面已經(jīng)解碼輸出的所有單詞嵌入信息
下面有個(gè)非常清晰的gif圖,一目了然:
上圖沒(méi)有繪制BOS_WORD嵌入向量輸入,然后解碼出i單詞的過(guò)程。
1.4.4 解碼器前向過(guò)程
仔細(xì)觀察解碼器結(jié)構(gòu),其包括:帶有mask的MultiHeadAttention、MultiHeadAttention和前饋神經(jīng)網(wǎng)絡(luò)層三個(gè)組件,帶有mask的MultiHeadAttention和MultiHeadAttention結(jié)構(gòu)和代碼寫(xiě)法是完全相同,唯一區(qū)別是是否輸入了mask。
為啥要mask?原因依然是順序解碼導(dǎo)致的。試想模型訓(xùn)練好了,開(kāi)始進(jìn)行翻譯(測(cè)試),其流程就是上面寫(xiě)的:輸入BOS_WORD,解碼器輸出i;輸入前面已經(jīng)解碼的BOS_WORD和i,解碼器輸出am...,輸入已經(jīng)解碼的BOS_WORD、i、am、a和student,解碼器輸出解碼結(jié)束標(biāo)志位EOS_WORD,每次解碼都會(huì)利用前面已經(jīng)解碼輸出的所有單詞嵌入信息,這個(gè)測(cè)試過(guò)程是沒(méi)有問(wèn)題,但是訓(xùn)練時(shí)候我肯定不想采用上述順序解碼類似rnn即一個(gè)一個(gè)目標(biāo)單詞嵌入向量順序輸入訓(xùn)練,肯定想采用類似編碼器中的矩陣并行算法,一步就把所有目標(biāo)單詞預(yù)測(cè)出來(lái)。要實(shí)現(xiàn)這個(gè)功能就可以參考編碼器的操作,把目標(biāo)單詞嵌入向量組成矩陣一次輸入即可,但是在解碼am時(shí)候,不能利用到后面單詞a和student的目標(biāo)單詞嵌入向量信息,否則這就是作弊(測(cè)試時(shí)候不可能能未卜先知)。為此引入mask,目的是構(gòu)成下三角矩陣,右上角全部設(shè)置為負(fù)無(wú)窮(相當(dāng)于忽略),從而實(shí)現(xiàn)當(dāng)解碼第一個(gè)字的時(shí)候,第一個(gè)字只能與第一個(gè)字計(jì)算相關(guān)性,當(dāng)解出第二個(gè)字的時(shí)候,只能計(jì)算出第二個(gè)字與第一個(gè)字和第二個(gè)字的相關(guān)性。具體是:在解碼器中,自注意力層只被允許處理輸出序列中更靠前的那些位置,在softmax步驟前,它會(huì)把后面的位置給隱去(把它們?cè)O(shè)為-inf)。
還有個(gè)非常重要點(diǎn)需要知道(看圖示可以發(fā)現(xiàn)):解碼器內(nèi)部的帶有mask的MultiHeadAttention的qkv向量輸入來(lái)自目標(biāo)單詞嵌入或者前一個(gè)解碼器輸出,三者是相同的,但是后面的MultiHeadAttention的qkv向量中的kv來(lái)自最后一層編碼器的輸入,而q來(lái)自帶有mask的MultiHeadAttention模塊的輸出。
關(guān)于帶mask的注意力層寫(xiě)法其實(shí)就是前面提到的代碼:
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
def forward(self, q, k, v, mask=None):
# 假設(shè)q是b,8,10,64(b是batch,8是head個(gè)數(shù),10是樣本最大單詞長(zhǎng)度,
# 64是每個(gè)單詞的編碼向量)
# attn輸出維度是b,8,10,10
attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
# 故mask維度也是b,8,10,10
# 忽略b,8,只關(guān)注10x10的矩陣,其是下三角矩陣,下三角位置全1,其余位置全0
if mask is not None:
# 提前算出mask,將為0的地方變成極小值-1e9,把這些位置的值設(shè)置為忽略
# 目的是避免解碼過(guò)程中利用到未來(lái)信息
attn = attn.masked_fill(mask == 0, -1e9)
# softmax+dropout
attn = self.dropout(F.softmax(attn, dim=-1))
output = torch.matmul(attn, v)
return output, attn
可視化如下:圖片來(lái)源https://zhuanlan.zhihu.com/p/44731789
整個(gè)解碼器代碼和編碼器非常類似:
class DecoderLayer(nn.Module):
''' Compose with three layers '''
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(DecoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
def forward(
self, dec_input, enc_output,
slf_attn_mask=None, dec_enc_attn_mask=None):
# 標(biāo)準(zhǔn)的自注意力,QKV=dec_input來(lái)自目標(biāo)單詞嵌入或者前一個(gè)解碼器輸出
dec_output, dec_slf_attn = self.slf_attn(
dec_input, dec_input, dec_input, mask=slf_attn_mask)
# KV來(lái)自最后一個(gè)編碼層輸出enc_output,Q來(lái)自帶有mask的self.slf_attn輸出
dec_output, dec_enc_attn = self.enc_attn(
dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
dec_output = self.pos_ffn(dec_output)
return dec_output, dec_slf_attn, dec_enc_attn
考慮n個(gè)解碼器模塊,其整體流程為:
class Decoder(nn.Module):
def __init__(
self, n_trg_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
d_model, d_inner, pad_idx, n_position=200, dropout=0.1):
# 目標(biāo)單詞嵌入
self.trg_word_emb = nn.Embedding(n_trg_vocab, d_word_vec, padding_idx=pad_idx)
# 位置嵌入向量
self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
# n個(gè)解碼器
self.layer_stack = nn.ModuleList([
DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
# 層歸一化
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(self, trg_seq, trg_mask, enc_output, src_mask, return_attns=False):
# 目標(biāo)單詞嵌入+位置編碼
dec_output = self.dropout(self.position_enc(self.trg_word_emb(trg_seq)))
dec_output = self.layer_norm(dec_output)
# 遍歷每個(gè)解碼器
for dec_layer in self.layer_stack:
# 需要輸入3個(gè)信息:目標(biāo)單詞嵌入+位置編碼、最后一個(gè)編碼器輸出enc_output
# 和dec_enc_attn_mask,解碼時(shí)候不能看到未來(lái)單詞信息
dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
dec_output, enc_output, slf_attn_mask=trg_mask, dec_enc_attn_mask=src_mask)
return dec_output
1.4.5 分類層
在進(jìn)行編碼器-解碼器后輸出依然是向量,需要在后面接fc+softmax層進(jìn)行分類訓(xùn)練。假設(shè)當(dāng)前訓(xùn)練過(guò)程是翻譯任務(wù)需要輸出i am a student EOS_WORD這5個(gè)單詞。假設(shè)我們的模型是從訓(xùn)練集中學(xué)習(xí)一萬(wàn)個(gè)不同的英語(yǔ)單詞(我們模型的“輸出詞表”)。因此softmax后輸出為一萬(wàn)個(gè)單元格長(zhǎng)度的向量,每個(gè)單元格對(duì)應(yīng)某一個(gè)單詞的分?jǐn)?shù),這其實(shí)就是普通多分類問(wèn)題,只不過(guò)維度比較大而已。
依然以前面例子為例,假設(shè)編碼器輸出shape是(b,100,512),經(jīng)過(guò)fc后變成(b,100,10000),然后對(duì)最后一個(gè)維度進(jìn)行softmax操作,得到bx100個(gè)單詞的概率分布,在訓(xùn)練過(guò)程中bx100個(gè)單詞是知道label的,故可以直接采用ce loss進(jìn)行訓(xùn)練。
self.trg_word_prj = nn.Linear(d_model, n_trg_vocab, bias=False)
dec_output, *_ = self.model.decoder(trg_seq, trg_mask, enc_output, src_mask)
return F.softmax(self.model.trg_word_prj(dec_output), dim=-1)
1.4.6 前向流程
以翻譯任務(wù)為例:
將源單詞進(jìn)行嵌入,組成矩陣(加上位置編碼矩陣)輸入到n個(gè)編碼器中,輸出編碼向量KV
第一個(gè)解碼器先輸入一個(gè)BOS_WORD單詞嵌入向量,后續(xù)解碼器接受該解碼器輸出,結(jié)合KV進(jìn)行第一次解碼
將第一次解碼單詞進(jìn)行嵌入,聯(lián)合BOS_WORD單詞嵌入向量構(gòu)成矩陣再次輸入到解碼器中進(jìn)行第二次解碼,得到解碼單詞
不斷循環(huán),每次的第一個(gè)解碼器輸入都不同,其包含了前面時(shí)間步長(zhǎng)解碼出的所有單詞
直到輸出EOS_WORD表示解碼結(jié)束或者強(qiáng)制設(shè)置最大時(shí)間步長(zhǎng)即可
這個(gè)解碼過(guò)程其實(shí)就是標(biāo)準(zhǔn)的seq2seq流程。到目前為止就描述完了整個(gè)標(biāo)準(zhǔn)transformer訓(xùn)練和測(cè)試流程。
2 視覺(jué)領(lǐng)域的transformer
在理解了標(biāo)準(zhǔn)的transformer后,再來(lái)看視覺(jué)領(lǐng)域transformer就會(huì)非常簡(jiǎn)單,因?yàn)樵赾v領(lǐng)域應(yīng)用transformer時(shí)候大家都有一個(gè)共識(shí):盡量不改動(dòng)transformer結(jié)構(gòu),這樣才能和NLP領(lǐng)域發(fā)展對(duì)齊,所以大家理解cv里面的transformer操作是非常簡(jiǎn)單的。
2.1 分類vision transformer
論文題目:An Image is Worth 16x16 Words:Transformers for Image Recognition at Scale
論文地址:https://arxiv.org/abs/2010.11929
github:?https://github.com/lucidrains/vit-pytorch
其做法超級(jí)簡(jiǎn)單,只含有編碼器模塊:
本文出發(fā)點(diǎn)是徹底拋棄CNN,以前的cv領(lǐng)域雖然引入transformer,但是或多或少都用到了cnn或者rnn,本文就比較純粹了,整個(gè)算法幾句話就說(shuō)清楚了,下面直接分析。
2.1.1 圖片分塊和降維
因?yàn)閠ransformer的輸入需要序列,所以最簡(jiǎn)單做法就是把圖片切分為patch,然后拉成序列即可。假設(shè)輸入圖片大小是256x256,打算分成64個(gè)patch,每個(gè)patch是32x32像素
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
這個(gè)寫(xiě)法是采用了愛(ài)因斯坦表達(dá)式,具體是采用了einops庫(kù)實(shí)現(xiàn),內(nèi)部集成了各種算子,rearrange就是其中一個(gè),非常高效。不懂這種語(yǔ)法的請(qǐng)自行百度。p就是patch大小,假設(shè)輸入是b,3,256,256,則rearrange操作是先變成(b,3,8x32,8x32),最后變成(b,8x8,32x32x3)即(b,64,3072),將每張圖片切分成64個(gè)小塊,每個(gè)小塊長(zhǎng)度是32x32x3=3072,也就是說(shuō)輸入長(zhǎng)度為64的圖像序列,每個(gè)元素采用3072長(zhǎng)度進(jìn)行編碼。
考慮到3072有點(diǎn)大,故作者先進(jìn)行降維:
# 將3072變成dim,假設(shè)是1024
self.patch_to_embedding = nn.Linear(patch_dim, dim)
x = self.patch_to_embedding(x)
仔細(xì)看論文上圖,可以發(fā)現(xiàn)假設(shè)切成9個(gè)塊,但是最終到transfomer輸入是10個(gè)向量,額外追加了一個(gè)0和_。為啥要追加?原因是我們現(xiàn)在沒(méi)有解碼器了,而是編碼后直接就進(jìn)行分類預(yù)測(cè),那么該解碼器就要負(fù)責(zé)一點(diǎn)點(diǎn)解碼器功能,那就是:需要一個(gè)類似開(kāi)啟解碼標(biāo)志,非常類似于標(biāo)準(zhǔn)transformer解碼器中輸入的目標(biāo)嵌入向量右移一位操作。試下如果沒(méi)有額外輸入,9個(gè)塊輸入9個(gè)編碼向量輸出,那么對(duì)于分類任務(wù)而言,我應(yīng)該取哪個(gè)輸出向量進(jìn)行后續(xù)分類呢?選擇任何一個(gè)都說(shuō)不通,所以作者追加了一個(gè)可學(xué)習(xí)嵌入向量輸入。那么額外的可學(xué)習(xí)嵌入向量為啥要設(shè)計(jì)為可學(xué)習(xí),而不是類似nlp中采用固定的token代替?個(gè)人不負(fù)責(zé)任的猜測(cè)這應(yīng)該就是圖片領(lǐng)域和nlp領(lǐng)域的差別,nlp里面每個(gè)詞其實(shí)都有具體含義,是離散的,但是圖像領(lǐng)域沒(méi)有這種真正意義上的離散token,有的只是一堆連續(xù)特征或者圖像像素,如果不設(shè)置為可學(xué)習(xí),那還真不知道應(yīng)該設(shè)置為啥內(nèi)容比較合適,全0和全1也說(shuō)不通。自此現(xiàn)在就是變成10個(gè)向量輸出,輸出也是10個(gè)編碼向量,然后取第0個(gè)編碼輸出進(jìn)行分類預(yù)測(cè)即可。從這個(gè)角度看可以認(rèn)為編碼器多了一點(diǎn)點(diǎn)解碼器功能。具體做法超級(jí)簡(jiǎn)單,0就是位置編碼向量,_是可學(xué)習(xí)的patch嵌入向量。
# dim=1024
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
# 變成(b,64,1024)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
# 額外追加token,變成b,65,1024
x = torch.cat((cls_tokens, x), dim=1)
2.1.2 位置編碼
位置編碼也是必不可少的,長(zhǎng)度應(yīng)該是1024,這里做的比較簡(jiǎn)單,沒(méi)有采用sincos編碼,而是直接設(shè)置為可學(xué)習(xí),效果差不多
# num_patches=64,dim=1024,+1是因?yàn)槎嗔艘粋€(gè)cls開(kāi)啟解碼標(biāo)志
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
對(duì)訓(xùn)練好的pos_embedding進(jìn)行可視化,如下所示:
相鄰位置有相近的位置編碼向量,整體呈現(xiàn)2d空間位置排布一樣。
將patch嵌入向量和位置編碼向量相加即可作為編碼器輸入
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
2.1.3 編碼器前向過(guò)程
作者采用的是沒(méi)有任何改動(dòng)的transformer,故沒(méi)有啥說(shuō)的。
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
假設(shè)輸入是(b,65,1024),那么transformer輸出也是(b,65,1024)
2.1.4 分類head
在編碼器后接fc分類器head即可
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, num_classes)
)
# 65個(gè)輸出里面只需要第0個(gè)輸出進(jìn)行后續(xù)分類即可
self.mlp_head(x[:, 0])
到目前為止就全部寫(xiě)完了,是不是非常簡(jiǎn)單,外層整體流程為:
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dropout=0.,emb_dropout=0.):
super().__init__()
# image_size輸入圖片大小 256
# patch_size 每個(gè)patch的大小 32
num_patches = (image_size // patch_size) ** 2 # 一共有多少個(gè)patch 8x8=64
patch_dim = channels * patch_size ** 2 # 3x32x32=3072
self.patch_size = patch_size # 32
# 1,64+1,1024,+1是因?yàn)閠oken,可學(xué)習(xí)變量,不是固定編碼
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
# 圖片維度太大了,需要先降維
self.patch_to_embedding = nn.Linear(patch_dim, dim)
# 分類輸出位置標(biāo)志,否則分類輸出不知道應(yīng)該取哪個(gè)位置
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
# 編碼器
self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout)
# 輸出頭
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, num_classes)
)
def forward(self, img, mask=None):
p = self.patch_size
# 先把圖片變成64個(gè)patch,輸出shape=b,64,3072
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
# 輸出 b,64,1024
x = self.patch_to_embedding(x)
b, n, _ = x.shape
# 輸出 b,1,1024
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
# 額外追加token,變成b,65,1024
x = torch.cat((cls_tokens, x), dim=1)
# 加上位置編碼1,64+1,1024
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
x = self.transformer(x, mask)
# 分類head,只需要x[0]即可
# x = self.to_cls_token(x[:, 0])
x = x[:, 0]
return self.mlp_head(x)
2.1.5 實(shí)驗(yàn)分析
作者得出的結(jié)論是:cv領(lǐng)域應(yīng)用transformer需要大量數(shù)據(jù)進(jìn)行預(yù)訓(xùn)練,在同等數(shù)據(jù)量的情況下性能不然cnn。一旦數(shù)據(jù)量上來(lái)了,對(duì)應(yīng)的訓(xùn)練時(shí)間也會(huì)加長(zhǎng)很多,那么就可以輕松超越cnn。
?
同時(shí)應(yīng)用transformer,一個(gè)突出優(yōu)點(diǎn)是可解釋性比較強(qiáng):
2.2 目標(biāo)檢測(cè)detr
論文名稱:End-to-End Object Detection with Transformers
論文地址:https://arxiv.org/abs/2005.12872
github:https://github.com/facebookresearch/detr
detr是facebook提出的引入transformer到目標(biāo)檢測(cè)領(lǐng)域的算法,效果很好,做法也很簡(jiǎn)單,符合其一貫的簡(jiǎn)潔優(yōu)雅設(shè)計(jì)做法。
對(duì)于目標(biāo)檢測(cè)任務(wù),其要求輸出給定圖片中所有前景物體的類別和bbox坐標(biāo),該任務(wù)實(shí)際上是無(wú)序集合預(yù)測(cè)問(wèn)題。針對(duì)該問(wèn)題,detr做法非常簡(jiǎn)單:給定一張圖片,經(jīng)過(guò)CNN進(jìn)行特征提取,然后變成特征序列輸入到transformer的編解碼器中,直接輸出指定長(zhǎng)度為N的無(wú)序集合,集合中每個(gè)元素包含物體類別和坐標(biāo)。其中N表示整個(gè)數(shù)據(jù)集中圖片上最多物體的數(shù)目,因?yàn)檎麄€(gè)訓(xùn)練和測(cè)試都Batch進(jìn)行,如果不設(shè)置最大輸出集合數(shù),無(wú)法進(jìn)行batch訓(xùn)練,如果圖片中物體不夠N個(gè),那么就采用no object填充,表示該元素是背景。
整個(gè)思想看起來(lái)非常簡(jiǎn)單,相比f(wàn)aster rcnn或者yolo算法那就簡(jiǎn)單太多了,因?yàn)槠洳恍枰O(shè)置先驗(yàn)anchor,超參幾乎沒(méi)有,也不需要nms(因?yàn)檩敵龅臒o(wú)序集合沒(méi)有重復(fù)情況),并且在代碼程度相比f(wàn)aster rcnn那就不知道簡(jiǎn)單多少倍了,通過(guò)簡(jiǎn)單修改就可以應(yīng)用于全景分割任務(wù)??梢酝茰y(cè),如果transformer真正大規(guī)模應(yīng)用于CV領(lǐng)域,那么對(duì)初學(xué)者來(lái)說(shuō)就是福音了,理解transformer就幾乎等于理解了整個(gè)cv領(lǐng)域了(當(dāng)然也可能是壞事)。
2.2.1 detr核心思想分析
相比f(wàn)aster rcnn等做法,detr最大特點(diǎn)是將目標(biāo)檢測(cè)問(wèn)題轉(zhuǎn)化為無(wú)序集合預(yù)測(cè)問(wèn)題。論文中特意指出faster rcnn這種設(shè)置一大堆a(bǔ)nchor,然后基于anchor進(jìn)行分類和回歸其實(shí)屬于代理做法即不是最直接做法,目標(biāo)檢測(cè)任務(wù)就是輸出無(wú)序集合,而faster rcnn等算法通過(guò)各種操作,并結(jié)合復(fù)雜后處理最終才得到無(wú)序集合屬于繞路了,而detr就比較純粹了。
盡管將transformer引入目標(biāo)檢測(cè)領(lǐng)域可以避免上述各種問(wèn)題,但是其依然存在兩個(gè)核心操作:
無(wú)序集合輸出的loss計(jì)算
針對(duì)目標(biāo)檢測(cè)的transformer改進(jìn)
2.2.2 detr算法實(shí)現(xiàn)細(xì)節(jié)
下面結(jié)合代碼和原理對(duì)其核心環(huán)節(jié)進(jìn)行深入分析。
2.2.2.1 無(wú)序集合輸出的loss計(jì)算
在分析loss計(jì)算前,需要先明確N個(gè)無(wú)序集合的target構(gòu)建方式。作者在coco數(shù)據(jù)集上統(tǒng)計(jì),一張圖片最多標(biāo)注了63個(gè)物體,所以N應(yīng)該要不小于63,作者設(shè)置的是100。為啥要設(shè)置為100?有人猜測(cè)是和coco評(píng)估指標(biāo)只取前100個(gè)預(yù)測(cè)結(jié)果算法指標(biāo)有關(guān)系。
detr輸出是包括batchx100個(gè)無(wú)序集合,每個(gè)集合包括類別和坐標(biāo)信息。對(duì)于coco數(shù)據(jù)而言,作者設(shè)置類別為91(coco類別標(biāo)注索引是1-91,但是實(shí)際就標(biāo)注了80個(gè)類別),加上背景一共92個(gè)類別,對(duì)于坐標(biāo)分支采用4個(gè)歸一化值表征即cxcywh中心點(diǎn)、wh坐標(biāo),然后除以圖片寬高進(jìn)行歸一化(沒(méi)有采用復(fù)雜變換策略),故每個(gè)集合是??,c是長(zhǎng)度為92的分類向量,b是長(zhǎng)度為4的bbox坐標(biāo)向量??傊甦etr輸出集合包括兩個(gè)分支:分類分支shape=(b,100,92),bbox坐標(biāo)分支shape=(b,100,4),對(duì)應(yīng)的target也是包括分類target和bbox坐標(biāo)target,如果不夠100,則采用背景填充,計(jì)算loss時(shí)候bbox分支僅僅計(jì)算有物體位置,背景集合忽略。
現(xiàn)在核心問(wèn)題來(lái)了:輸出的bx100個(gè)檢測(cè)結(jié)果是無(wú)序的,如何和gt bbox計(jì)算loss?這就需要用到經(jīng)典的雙邊匹配算法了,也就是常說(shuō)的匈牙利算法,該算法廣泛應(yīng)用于最優(yōu)分配問(wèn)題,在bottom-up人體姿態(tài)估計(jì)算法中進(jìn)行分組操作時(shí)候也經(jīng)常使用。detr中利用匈牙利算法先進(jìn)行最優(yōu)一對(duì)一匹配得到匹配索引,然后對(duì)bx100個(gè)結(jié)果進(jìn)行重排就和gt bbox對(duì)應(yīng)上了(對(duì)gt bbox進(jìn)行重排也可以,沒(méi)啥區(qū)別),就可以算loss了。
匈牙利算法是一個(gè)標(biāo)準(zhǔn)優(yōu)化算法,具體是組合優(yōu)化算法,在scipy.optimize.linear_sum_assignmen函數(shù)中有實(shí)現(xiàn),一行代碼就可以得到最優(yōu)匹配,網(wǎng)上解讀也非常多,這里就不寫(xiě)細(xì)節(jié)了,該函數(shù)核心是需要輸入A集合和B集合兩兩元素之間的連接權(quán)重,基于該重要性進(jìn)行內(nèi)部最優(yōu)匹配,連接權(quán)重大的優(yōu)先匹配。
上述描述優(yōu)化過(guò)程可以采用如下公式表達(dá):
優(yōu)化對(duì)象是??,其是長(zhǎng)度為N的list,??,???表示無(wú)序gt bbox集合的哪個(gè)元素和輸出預(yù)測(cè)集合中的第i個(gè)匹配。其實(shí)簡(jiǎn)單來(lái)說(shuō)就是找到最優(yōu)匹配,因?yàn)樵谧罴哑ヅ淝闆r下l_match和最小即loss最小。
前面說(shuō)過(guò)匈牙利算法核心是需要提供輸入A集合和B集合兩兩元素之間的連接權(quán)重,這里就是要輸入N個(gè)輸出集合和M個(gè)gt bbox之間的關(guān)聯(lián)程度,如下所示
而Lbox具體是:
Hungarian意思就是匈牙利,也就是前面的L_match,上述意思是需要計(jì)算M個(gè)gt bbox和N個(gè)輸出集合兩兩之間的廣義距離,距離越近表示越可能是最優(yōu)匹配關(guān)系,也就是兩者最密切。廣義距離的計(jì)算考慮了分類分支和bbox分支,下面結(jié)合代碼直接說(shuō)明,比較簡(jiǎn)單。
# detr分類輸出,num_queries=100,shape是(b,100,92)
bs, num_queries = outputs["pred_logits"].shape[:2]
# 得到概率輸出(bx100,92)
out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)
# 得到bbox分支輸出(bx100,4)
out_bbox = outputs["pred_boxes"].flatten(0, 1)
# 準(zhǔn)備分類target shape=(m,)里面存儲(chǔ)的是類別索引,m包括了整個(gè)batch內(nèi)部的所有g(shù)t bbox
tgt_ids = torch.cat([v["labels"] for v in targets])
# 準(zhǔn)備bbox target shape=(m,4),已經(jīng)歸一化了
tgt_bbox = torch.cat([v["boxes"] for v in targets])
#核心
#bx100,92->bx100,m,對(duì)于每個(gè)預(yù)測(cè)結(jié)果,把目前gt里面有的所有類別值提取出來(lái),其余值不需要參與匹配
#對(duì)應(yīng)上述公式,類似于nll loss,但是更加簡(jiǎn)單
cost_class = -out_prob[:, tgt_ids]
#計(jì)算out_bbox和tgt_bbox兩兩之間的l1距離 bx100,m
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
#額外多計(jì)算一個(gè)giou loss bx100,m
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
#得到最終的廣義距離bx100,m,距離越小越可能是最優(yōu)匹配
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
# bx100,m--> batch,100,m
C = C.view(bs, num_queries, -1).cpu()
#計(jì)算每個(gè)batch內(nèi)部有多少物體,后續(xù)計(jì)算時(shí)候按照單張圖片進(jìn)行匹配,沒(méi)必要batch級(jí)別匹配,徒增計(jì)算
sizes = [len(v["boxes"]) for v in targets]
#匈牙利最優(yōu)匹配,返回匹配索引
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
在得到匹配關(guān)系后算loss就水到渠成了。分類分支計(jì)算ce loss,bbox分支計(jì)算l1 loss+giou loss
def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
#shape是(b,100,92)
src_logits = outputs['pred_logits']
#得到匹配后索引,作用在label上
idx = self._get_src_permutation_idx(indices)
#得到匹配后的分類target
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
#加入背景(self.num_classes),補(bǔ)齊bx100個(gè)
target_classes = torch.full(src_logits.shape[:2], self.num_classes,
dtype=torch.int64, device=src_logits.device)
#shape是(b,100,),存儲(chǔ)的是索引,不是one-hot
target_classes[idx] = target_classes_o
#計(jì)算ce loss,self.empty_weight前景和背景權(quán)重是1和0.1,克服類別不平衡
loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
losses = {'loss_ce': loss_ce}
return losses
def loss_boxes(self, outputs, targets, indices, num_boxes):
idx = self._get_src_permutation_idx(indices)
src_boxes = outputs['pred_boxes'][idx]
target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
#l1 loss
loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
losses = {}
losses['loss_bbox'] = loss_bbox.sum() / num_boxes
#giou loss
loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
box_ops.box_cxcywh_to_xyxy(src_boxes),
box_ops.box_cxcywh_to_xyxy(target_boxes)))
losses['loss_giou'] = loss_giou.sum() / num_boxes
return losses
2.2.2.2 針對(duì)目標(biāo)檢測(cè)的transformer改進(jìn)
分析完訓(xùn)練最關(guān)鍵的:雙邊匹配+loss計(jì)算部分,現(xiàn)在需要考慮在目標(biāo)檢測(cè)算法中transformer如何設(shè)計(jì)?下面按照算法的4個(gè)步驟講解。
transformer細(xì)節(jié)如下:
(1) cnn骨架特征提取
骨架網(wǎng)絡(luò)可以是任何一種,作者選擇resnet50,將最后一個(gè)stage即stride=32的特征圖作為編碼器輸入。由于resnet僅僅作為一個(gè)小部分且已經(jīng)經(jīng)過(guò)了imagenet預(yù)訓(xùn)練,故和常規(guī)操作一樣,會(huì)進(jìn)行如下操作:
resnet中所有BN都固定,即采用全局均值和方差
resnet的stem和第一個(gè)stage不進(jìn)行參數(shù)更新,即parameter.requires_grad_(False)
backbone的學(xué)習(xí)率小于transformer,lr_backbone=1e-05,其余為0.0001
假設(shè)輸入是(b,c,h,w),則resnet50輸出是(b,1024,h//32,w//32),1024比較大,為了節(jié)省計(jì)算量,先采用1x1卷積降維為256,最后轉(zhuǎn)化為序列格式輸入到transformer中,輸入shape=(h'xw',b,256),h'=h//32
self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
# 輸出是(b,256,h//32,w//32)
src=self.input_proj(src)
# 變成序列模式,(h'xw',b,256),256是每個(gè)詞的編碼長(zhǎng)度
src = src.flatten(2).permute(2, 0, 1)
(2) 編碼器設(shè)計(jì)和輸入
編碼器結(jié)構(gòu)設(shè)計(jì)沒(méi)有任何改變,但是輸入改變了。
a) 位置編碼需要考慮2d空間
由于圖像特征是2d特征,故位置嵌入向量也需要考慮xy方向。前面說(shuō)過(guò)編碼方式可以采用sincos,也可以設(shè)置為可學(xué)習(xí),本文采用的依然是sincos模式,和前面說(shuō)的一樣,但是需要考慮xy兩個(gè)方向(前面說(shuō)的序列只有x方向)。
#輸入是b,c,h,w
#tensor_list的類型是NestedTensor,內(nèi)部自動(dòng)附加了mask,
#用于表示動(dòng)態(tài)shape,是pytorch中tensor新特性https://github.com/pytorch/nestedtensor
x = tensor_list.tensors # 原始tensor數(shù)據(jù)
# 附加的mask,shape是b,h,w 全是false
mask = tensor_list.mask
not_mask = ~mask
# 因?yàn)閳D像是2d的,所以位置編碼也分為x,y方向
# 1 1 1 1 .. 2 2 2 2... 3 3 3...
y_embed = not_mask.cumsum(1, dtype=torch.float32)
# 1 2 3 4 ... 1 2 3 4...
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
# 0~127 self.num_pos_feats=128,因?yàn)榍懊孑斎胂蛄渴?56,編碼是一半sin,一半cos
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
# 歸一化
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
# 輸出shape=b,h,w,128
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
# 每個(gè)特征圖的xy位置都編碼成256的向量,其中前128是y方向編碼,而128是x方向編碼
return pos # b,n=256,h,w
可以看出對(duì)于h//32,w//32的2d圖像特征,不是類似vision transoformer做法簡(jiǎn)單的將其拉伸為h//32 x w//32,然后從0-n進(jìn)行長(zhǎng)度為256的位置編碼,而是考慮了xy方向同時(shí)編碼,每個(gè)方向各編碼128維向量,這種編碼方式更符合圖像特定。
還有一個(gè)細(xì)節(jié)需要注意:原始transformer的n個(gè)編碼器輸入中,只有第一個(gè)編碼器需要輸入位置編碼向量,但是detr里面對(duì)每個(gè)編碼器都輸入了同一個(gè)位置編碼向量,論文中沒(méi)有寫(xiě)為啥要如此修改。
b) QKV處理邏輯不同
作者設(shè)置編碼器一共6個(gè),并且位置編碼向量?jī)H僅加到QK中,V中沒(méi)有加入位置信息,這個(gè)和原始做法不一樣,原始做法是QKV都加上了位置編碼,論文中也沒(méi)有寫(xiě)為啥要如此修改。
其余地方就完全相同了,故代碼就沒(méi)必要貼了??偨Y(jié)下和原始transformer編碼器不同的地方:
輸入編碼器的位置編碼需要考慮2d空間位置
位置編碼向量需要加入到每個(gè)編碼器中
在編碼器內(nèi)部位置編碼僅僅和QK相加,V不做任何處理
經(jīng)過(guò)6個(gè)編碼器forward后,輸出shape為(h//32xw//32,b,256)。
c) 編碼器部分整體運(yùn)行流程
6個(gè)編碼器整體forward流程如下:
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
# 編碼器copy6份
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self, src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
# 內(nèi)部包括6個(gè)編碼器,順序運(yùn)行
# src是圖像特征輸入,shape=hxw,b,256
output = src
for layer in self.layers:
# 每個(gè)編碼器都需要加入pos位置編碼
# 第一個(gè)編碼器輸入來(lái)自圖像特征,后面的編碼器輸入來(lái)自前一個(gè)編碼器輸出
output = layer(output, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, pos=pos)
return output
每個(gè)編碼器內(nèi)部運(yùn)行流程如下:
def forward_post(self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None):
# 和標(biāo)準(zhǔn)做法有點(diǎn)不一樣,src加上位置編碼得到q和k,但是v依然還是src,
# 也就是v和qk不一樣
q = k = src+pos
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
(3) 解碼器設(shè)計(jì)和輸入
解碼器結(jié)構(gòu)設(shè)計(jì)沒(méi)有任何改變,但是輸入也改變了。
a) 新引入Object queries
object queries(shape是(100,256))可以簡(jiǎn)單認(rèn)為是輸出位置編碼,其作用主要是在學(xué)習(xí)過(guò)程中提供目標(biāo)對(duì)象和全局圖像之間的關(guān)系,相當(dāng)于全局注意力,必不可少非常關(guān)鍵。代碼形式上是可學(xué)習(xí)位置編碼矩陣。和編碼器一樣,該可學(xué)習(xí)位置編碼向量也會(huì)輸入到每一個(gè)解碼器中。我們可以嘗試通俗理解:object queries矩陣內(nèi)部通過(guò)學(xué)習(xí)建模了100個(gè)物體之間的全局關(guān)系,例如房間里面的桌子旁邊(A類)一般是放椅子(B類),而不會(huì)是放一頭大象(C類),那么在推理時(shí)候就可以利用該全局注意力更好的進(jìn)行解碼預(yù)測(cè)輸出。
# num_queries=100,hidden_dim=256
self.query_embed = nn.Embedding(num_queries, hidden_dim)
論文中指出object queries作用非常類似faster rcnn中的anchor,只不過(guò)這里是可學(xué)習(xí)的,不是提前設(shè)置好的。
b) 位置編碼也需要
編碼器環(huán)節(jié)采用的sincos位置編碼向量也可以考慮引入,且該位置編碼向量輸入到每個(gè)解碼器的第二個(gè)Multi-Head Attention中,后面有是否需要該位置編碼的對(duì)比實(shí)驗(yàn)。
c) QKV處理邏輯不同
解碼器一共包括6個(gè),和編碼器中QKV一樣,V不會(huì)加入位置編碼。上述說(shuō)的三個(gè)操作,只要看下網(wǎng)絡(luò)結(jié)構(gòu)圖就一目了然了。
d) 一次解碼輸出全部無(wú)序集合
和原始transformer順序解碼操作不同的是,detr一次就把N個(gè)無(wú)序框并行輸出了(因?yàn)槿蝿?wù)是無(wú)序集合,做成順序推理有序輸出沒(méi)有很大必要)。為了說(shuō)明如何實(shí)現(xiàn)該功能,我們需要先回憶下原始transformer的順序解碼過(guò)程:輸入BOS_WORD,解碼器輸出i;輸入前面已經(jīng)解碼的BOS_WORD和i,解碼器輸出am...,輸入已經(jīng)解碼的BOS_WORD、i、am、a和student,解碼器輸出解碼結(jié)束標(biāo)志位EOS_WORD,每次解碼都會(huì)利用前面已經(jīng)解碼輸出的所有單詞嵌入信息?,F(xiàn)在就是一次解碼,故只需要初始化時(shí)候輸入一個(gè)全0的查詢向量A,類似于BOS_WORD作用,然后第一個(gè)解碼器接受該輸入A,解碼輸出向量作為下一個(gè)解碼器輸入,不斷推理即可,最后一層解碼輸出即為我們需要的輸出,不需要在第二個(gè)解碼器輸入時(shí)候考慮BOS_WORD和第一個(gè)解碼器輸出。
總結(jié)下和原始transformer解碼器不同的地方:
額外引入可學(xué)習(xí)的Object queries,相當(dāng)于可學(xué)習(xí)anchor,提供全局注意力
編碼器采用的sincos位置編碼向量也需要輸入解碼器中,并且每個(gè)解碼器都輸入
QKV處理邏輯不同
不需要順序解碼,一次即可輸出N個(gè)無(wú)序集合
e) 解碼器整體運(yùn)行流程
n個(gè)解碼器整體流程如下:
class TransformerDecoder(nn.Module):
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
# 首先query_pos是query_embed,可學(xué)習(xí)輸出位置向量shape=100,b,256
# tgt = torch.zeros_like(query_embed),用于進(jìn)行一次性解碼輸出
output = tgt
# 存儲(chǔ)每個(gè)解碼器輸出,后面中繼監(jiān)督需要
intermediate = []
# 編碼每個(gè)解碼器
for layer in self.layers:
# 每個(gè)解碼器都需要輸入query_pos和pos
# memory是最后一個(gè)編碼器輸出
# 每個(gè)解碼器都接受output作為輸入,然后輸出新的output
output = layer(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos, query_pos=query_pos)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.return_intermediate:
return torch.stack(intermediate) # 6個(gè)輸出都返回
return output.unsqueeze(0)
內(nèi)部每個(gè)解碼器運(yùn)行流程為:
def forward_post(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
# query_pos首先是可學(xué)習(xí)的,其作用主要是在學(xué)習(xí)過(guò)程中提供目標(biāo)對(duì)象和全局圖像之間的關(guān)系
# 這個(gè)相當(dāng)于全局注意力輸入,是非常關(guān)鍵的
# query_pos是解碼器特有
q = k = tgt+query_pos
# 第一個(gè)自注意力模塊
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
# memory是最后一個(gè)編碼器輸出,pos是和編碼器輸入中完全相同的sincos位置嵌入向量
# 輸入?yún)?shù)是最核心細(xì)節(jié),query是tgt+query_pos,而key是memory+pos
# v直接用memory
tgt2 = self.multihead_attn(query=tgt+query_pos,
key=memory+pos,
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
解碼器最終輸出shape是(6,b,100,256),6是指6個(gè)解碼器的輸出。
(4) 分類和回歸head
在解碼器輸出基礎(chǔ)上構(gòu)建分類和bbox回歸head即可輸出檢測(cè)結(jié)果,比較簡(jiǎn)單:
self.class_embed = nn.Linear(256, 92)
self.bbox_embed = MLP(256, 256, 4, 3)
# hs是(6,b,100,256),outputs_class輸出(6,b,100,92),表示6個(gè)分類分支
outputs_class = self.class_embed(hs)
# 輸出(6,b,100,4),表示6個(gè)bbox坐標(biāo)回歸分支
outputs_coord = self.bbox_embed(hs).sigmoid()
# 取最后一個(gè)解碼器輸出即可,分類輸出(b,100,92),bbox回歸輸出(b,100,4)
out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
if self.aux_loss:
# 除了最后一個(gè)輸出外,其余編碼器輸出都算輔助loss
out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
作者實(shí)驗(yàn)發(fā)現(xiàn),如果對(duì)解碼器的每個(gè)輸出都加入輔助的分類和回歸loss,可以提升性能,故作者除了對(duì)最后一個(gè)編碼層的輸出進(jìn)行Loss監(jiān)督外,還對(duì)其余5個(gè)編碼器采用了同樣的loss監(jiān)督,只不過(guò)權(quán)重設(shè)置低一點(diǎn)而已。
(5) 整體推理流程
基于transformer的detr算法,作者特意強(qiáng)調(diào)其突出優(yōu)點(diǎn)是部署代碼不超過(guò)50行,簡(jiǎn)單至極。
當(dāng)然上面是簡(jiǎn)化代碼,和實(shí)際代碼不一樣。具體流程是:
將(b,3,800,1200)圖片輸入到resnet50中進(jìn)行特征提取,輸出shape=(b,1024,25,38)
通過(guò)1x1卷積降維,變成(b,256,25,38)
利用sincos函數(shù)計(jì)算位置編碼
將圖像特征和位置編碼向量相加,作為編碼器輸入,輸出編碼后的向量,shape不變
初始化全0的(100,b,256)的輸出嵌入向量,結(jié)合位置編碼向量和query_embed,進(jìn)行解碼輸出,解碼器輸出shape為(6,b,100,256)
將最后一個(gè)解碼器輸出輸入到分類和回歸head中,得到100個(gè)無(wú)序集合
對(duì)100個(gè)無(wú)序集合進(jìn)行后處理,主要是提取前景類別和對(duì)應(yīng)的bbox坐標(biāo),乘上(800,1200)即可得到最終坐標(biāo),后處理代碼如下:
prob = F.softmax(out_logits, -1)
scores, labels = prob[..., :-1].max(-1)
# convert to [x0, y0, x1, y1] format
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
# and from relative [0, 1] to absolute [0, height] coordinates
img_h, img_w = target_sizes.unbind(1)
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
boxes = boxes * scale_fct[:, None, :]
results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
既然訓(xùn)練時(shí)候?qū)?個(gè)解碼器輸出都進(jìn)行了loss監(jiān)督,那么在測(cè)試時(shí)候也可以考慮將6個(gè)解碼器的分類和回歸分支輸出結(jié)果進(jìn)行nms合并,稍微有點(diǎn)性能提升。
2.2.3 實(shí)驗(yàn)分析
(1) 性能對(duì)比
Faster RCNN-DC5是指的resnet的最后一個(gè)stage采用空洞率=stride設(shè)置代替stride,目的是在不進(jìn)行下采樣基礎(chǔ)上擴(kuò)大感受野,輸出特征圖分辨率保持不變。+號(hào)代表采用了額外的技巧提升性能例如giou、多尺度訓(xùn)練和9xepoch訓(xùn)練策略??梢园l(fā)現(xiàn)detr效果稍微好于faster rcnn各種版本,證明了視覺(jué)transformer的潛力。但是可以發(fā)現(xiàn)其小物體檢測(cè)能力遠(yuǎn)遠(yuǎn)低于faster rcnn,這是一個(gè)比較大的弊端。
(2) 各個(gè)模塊分析
編碼器數(shù)目越多效果越好,但是計(jì)算量也會(huì)增加很多,作者最終選擇的是6。
可以發(fā)現(xiàn)解碼器也是越多越好,還可以觀察到第一個(gè)解碼器輸出預(yù)測(cè)效果比較差,增加第二個(gè)解碼器后性能提升非常多。上圖中的NMS操作是指既然我們每個(gè)解碼層都可以輸入無(wú)序集合,那么將所有解碼器無(wú)序集合全部保留,然后進(jìn)行nms得到最終輸出,可以發(fā)現(xiàn)性能稍微有提升,特別是AP50。
作者對(duì)比了不同類型的位置編碼效果,因?yàn)閝uery_embed(output pos)是必不可少的,所以該列沒(méi)有進(jìn)行對(duì)比實(shí)驗(yàn),始終都有,最后一行效果最好,所以作者采用的就是該方案,sine at attn表示每個(gè)注意力層都加入了sine位置編碼,相比僅僅在input增加位置編碼效果更好。
(3) 注意力可視化
前面說(shuō)過(guò)transformer具有很好的可解釋性,故在訓(xùn)練完成后最終提出了幾種可視化形式
a) bbox輸出可視化
這個(gè)就比較簡(jiǎn)單了,直接對(duì)預(yù)測(cè)進(jìn)行后處理即可
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
# 只保留概率大于0.9的bbox
keep = probas.max(-1).values > 0.9
# 還原到原圖,然后繪制即可
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
plot_results(im, probas[keep], bboxes_scaled)
b) 解碼器自注意力層權(quán)重可視化
這里指的是最后一個(gè)解碼器內(nèi)部的第一個(gè)MultiheadAttention的自注意力權(quán)重,其實(shí)就是QK相似性計(jì)算后然后softmax后的輸出可視化,具體是:
# multihead_attn注冊(cè)前向hook,output[1]指的就是softmax后輸出
model.transformer.decoder.layers[-1].multihead_attn.register_forward_hook(
lambda self, input, output: dec_attn_weights.append(output[1])
)
# 假設(shè)輸入是(1,3,800,1066)
outputs = model(img)
# 那么dec_attn_weights是(1,100,850=800//32x1066//32)
# 這個(gè)就是QK相似性計(jì)算后然后softmax后的輸出,即自注意力權(quán)重
dec_attn_weights = dec_attn_weights[0]
# 如果想看哪個(gè)bbox的權(quán)重,則輸入idx即可
dec_attn_weights[0, idx].view(800//32, 1066//32)
c) 編碼器自注意力層權(quán)重可視化
這個(gè)和解碼器操作完全相同。
model.transformer.encoder.layers[-1].self_attn.register_forward_hook(
lambda self, input, output: enc_attn_weights.append(output[1])
)
outputs = model(img)
# 最后一個(gè)編碼器中的自注意力模塊權(quán)重輸出(b,h//32xw//32,h//32xw//32),其實(shí)就是qk計(jì)算然后softmax后的值即(1,25x34=850,850)
enc_attn_weights = enc_attn_weights[0]
# 變成(25, 34, 25, 34)
sattn = enc_attn_weights[0].reshape(shape + shape)
# 想看哪個(gè)特征點(diǎn)位置的注意力
idxs = [(200, 200), (280, 400), (200, 600), (440, 800), ]
for idx_o, ax in zip(idxs, axs):
# 轉(zhuǎn)化到特征圖尺度
idx = (idx_o[0] // fact, idx_o[1] // fact)
# 直接sattn[..., idx[0], idx[1]]即可
ax.imshow(sattn[..., idx[0], idx[1]], cmap='cividis', interpolation='nearest')
2.2.4 小結(jié)
detr整體做法非常簡(jiǎn)單,基本上沒(méi)有改動(dòng)原始transformer結(jié)構(gòu),其顯著優(yōu)點(diǎn)是:不需要設(shè)置啥先驗(yàn),超參也比較少,訓(xùn)練和部署代碼相比f(wàn)aster rcnn算法簡(jiǎn)單很多,理解上也比較簡(jiǎn)單。但是其缺點(diǎn)是:改了編解碼器的輸入,在論文中也沒(méi)有解釋為啥要如此設(shè)計(jì),而且很多操作都是實(shí)驗(yàn)對(duì)比才確定的,比較迷。算法層面訓(xùn)練epoch次數(shù)遠(yuǎn)遠(yuǎn)大于faster rcnn(300epoch),在同等epoch下明顯性能不如faster rcnn,而且訓(xùn)練占用內(nèi)存也大于faster rcnn。
整體而言,雖然效果不錯(cuò),但是整個(gè)做法還是顯得比較原始,很多地方感覺(jué)是嘗試后得到的做法,沒(méi)有很好的解釋性,而且最大問(wèn)題是訓(xùn)練epoch非常大和內(nèi)存占用比較多,對(duì)應(yīng)的就是收斂慢,期待后續(xù)作品。
3 總結(jié)
本文從transformer發(fā)展歷程入手,并且深入介紹了transformer思想和實(shí)現(xiàn)細(xì)節(jié);最后結(jié)合計(jì)算機(jī)視覺(jué)領(lǐng)域的幾篇有典型代表文章進(jìn)行深入分析,希望能夠給cv領(lǐng)域想快速理解transformer的初學(xué)者一點(diǎn)點(diǎn)幫助。
4 參考資料
1 http://jalammar.github.io/illustrated-transformer/
2 https://zhuanlan.zhihu.com/p/54356280
3 https://zhuanlan.zhihu.com/p/44731789
4 https://looperxx.github.io/CS224n-2019-08-Machine%20Translation,%20Sequence-to-sequence%20and%20Attention/
5 https://github.com/lucidrains/vit-pytorch
6 https://github.com/jadore801120/attention-is-all-you-need-pytorch
7 https://github.com/facebookresearch/detr
編輯:黃飛
?
評(píng)論
查看更多