今天跟大家分享一篇來(lái)自CMU等機(jī)構(gòu)的論文《Sliced Recursive Transformer》,該論文已被 ECCV 2022 接收。
目前 vision transformer 在不同視覺(jué)任務(wù)上如分類(lèi)、檢測(cè)等都展示出了強(qiáng)大的性能,但是其巨大的參數(shù)量和計(jì)算量阻礙了該模型進(jìn)一步在實(shí)際場(chǎng)景中的應(yīng)用?;谶@個(gè)考慮,本文重點(diǎn)研究了如何在不增加額外參數(shù)量的前提下把模型的表達(dá)能力挖掘到極致,同時(shí)還要保證模型計(jì)算量在合理范圍內(nèi),從而可以在一些存儲(chǔ)容量小,計(jì)算能力弱的嵌入式設(shè)備上部署。
基于這個(gè)動(dòng)機(jī),Zhiqiang Shen、邢波等研究者提出了一個(gè) SReT 模型,通過(guò)循環(huán)遞歸結(jié)構(gòu)來(lái)強(qiáng)化每個(gè) block 的特征表達(dá)能力,同時(shí)又提出使用多個(gè)局部 group self-attention 來(lái)近似 vanilla global self-attention,在顯著降低計(jì)算量 FLOPs 的同時(shí),模型沒(méi)有精度的損失。
論文地址:https://arxiv.org/abs/2111.05297
代碼和模型:https://github.com/szq0214/SReT
總結(jié)而言,本文主要有以下兩個(gè)創(chuàng)新點(diǎn):
使用類(lèi)似 RNN 里面的遞歸結(jié)構(gòu)(recursive block)來(lái)構(gòu)建 ViT 主體,參數(shù)量不漲的前提下提升模型表達(dá)能力;
使用 CNN 中 group-conv 類(lèi)似的 group self-attention 來(lái)降低 FLOPs 的同時(shí)保持模型的高精度;
此外,本文還有其他一些小的改動(dòng):
網(wǎng)絡(luò)最前面使用三層連續(xù)卷積,卷積核為 3x3,結(jié)構(gòu)直接使用了研究者之前 DSOD 里面的 stem 結(jié)構(gòu);
Knowledge distillation 只使用了單獨(dú)的 soft label,而不是 DeiT 里面 hard 形式的 label 加 one-hot ground-truth,因?yàn)檠芯空哒J(rèn)為 soft label 包含的信息更多,更有利于知識(shí)蒸餾;
使用可學(xué)習(xí)的 residual connection 來(lái)提升模型表達(dá)能力;
如下圖所示,本文所提出的模型在參數(shù)量(Params)和計(jì)算量(FLOPs)方面相比其他模型都有明顯的優(yōu)勢(shì):
下面我們來(lái)解讀這篇文章: 1.ViT 中的遞歸模塊 遞歸操作的基本組成模塊如下圖:
該模塊非常簡(jiǎn)單明了,類(lèi)似于 RNN 結(jié)構(gòu),將模塊當(dāng)前 step 的輸出作為下個(gè) step 的輸入重新輸進(jìn)該模塊,從而增強(qiáng)模型特征表達(dá)能力。 研究者展示了將該設(shè)計(jì)直接應(yīng)用在 DeiT 上的結(jié)果,如下所示:
可以看到在加入額外一次簡(jiǎn)單遞歸操作之后就可以得到將近 2% 的精度提升。 當(dāng)然具體到全局網(wǎng)絡(luò)結(jié)構(gòu)層面還有不同的遞歸構(gòu)建方法,如下圖:
其中 NLL 層(Non-linear Projection Layer)是用來(lái)保證每個(gè)遞歸模塊輸入輸出不完全一致。論文提出使用這個(gè)模塊的主要原因是發(fā)現(xiàn)在上述 Table 1 里面更多次數(shù)的遞歸操作并沒(méi)有進(jìn)一步提升性能,說(shuō)明網(wǎng)絡(luò)可能學(xué)到了一個(gè)比較簡(jiǎn)單的狀態(tài),而 NLL 層可以強(qiáng)制模型輸入輸出不一致從而緩解這種情況。同時(shí),研究者從實(shí)驗(yàn)結(jié)果發(fā)現(xiàn)上圖 (1) internal loop 相比 external loop 設(shè)計(jì)擁有更好的 accuracy-FLOPs 結(jié)果。 2. 分組的 Group Self-attention 模塊 如下圖所示,研究者提出了一種分組的 group self-attention 策略來(lái)降低模型的 FLOPs,同時(shí)保證 self-attention 的全局注意力,從而使得模型沒(méi)有明顯精度損失:
Group Self-attention 模塊具體形式如下:
Group self-attention 的缺點(diǎn)是只有局部區(qū)域會(huì)相互作用,研究者提出通過(guò)使用 Permutation 操作來(lái)近似全局 self-attention 的機(jī)制,同時(shí)通過(guò) Inverse Permutation 來(lái)復(fù)原和保留 tokens 的次序信息,針對(duì)這個(gè)部分的消融實(shí)驗(yàn)如下所示:
其中 P 表示加入 Permutation,I 表示加入 Inverse Permutation,-L 表示如果 group 數(shù)為 1,就不使用 P 和 I(比如模型最后一個(gè) stage)。根據(jù)上述表格的結(jié)果,研究者最后采用了 [8, 2][4,1][1,1] 這種分組設(shè)計(jì)。 3. 其他設(shè)計(jì) 可學(xué)習(xí)的殘差結(jié)構(gòu) (LRC):
研究者嘗試了上圖三種結(jié)構(gòu),圖(3)結(jié)果最佳。具體而言,研究者在每個(gè)模塊里面添加了 6 個(gè)額外參數(shù)(4+2,2 個(gè)在 NLL 層),這些參數(shù)會(huì)跟模型其他參數(shù)一起學(xué)習(xí),從而使網(wǎng)絡(luò)擁有更強(qiáng)的表達(dá)能力,參數(shù)初始化都為 1,在訓(xùn)練過(guò)程 6 個(gè)參數(shù)的數(shù)值變化情況如下所示:
Stem 結(jié)構(gòu)組成:
如上表所示,Stem 由三個(gè) 3x3 的連續(xù)卷積組成,每個(gè)卷積 stride 為 2。 整體網(wǎng)絡(luò)結(jié)構(gòu): 研究者進(jìn)一步去掉了 class token 和 distillation token,并且發(fā)現(xiàn)精度有少量提升。
消融實(shí)驗(yàn):
模型混合深度訓(xùn)練: 研究者進(jìn)一步發(fā)現(xiàn)分組遞歸設(shè)計(jì)還有一個(gè)好處就是:可以支持模型混合深度訓(xùn)練,這種訓(xùn)練方式可以大大降低深度網(wǎng)絡(luò)結(jié)構(gòu)優(yōu)化復(fù)雜度,研究者展示了 108 層不同模型結(jié)構(gòu)優(yōu)化過(guò)程的 landscape 可視化,如下圖所示,可以很明顯的看到混合深度結(jié)構(gòu)優(yōu)化過(guò)程困難程度顯著低于另外兩種結(jié)構(gòu)。
最后,分組 group self-attention 算法 PyTorch 偽代碼如下:
審核編輯 :李倩
-
模型
+關(guān)注
關(guān)注
1文章
3254瀏覽量
48880 -
遞歸
+關(guān)注
關(guān)注
0文章
28瀏覽量
9032 -
cnn
+關(guān)注
關(guān)注
3文章
352瀏覽量
22237
原文標(biāo)題:ECCV 2022 | 視覺(jué)Transformer上進(jìn)行遞歸!SReT:不增參數(shù),計(jì)算量還少!
文章出處:【微信號(hào):CVer,微信公眾號(hào):CVer】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論