NAS最近也很火,正好看到了這篇論文,解讀一下,這篇論文是基于DAG(directed acyclic graph)的,DAG包含了上億的 sub-graphs, 為了防止全部遍歷這些模型,這篇論文設(shè)計(jì)了一種全新的采樣器,這種采樣器叫做Gradient-based search suing differential Architecture Sampler(GDAS),該采樣器可以自行學(xué)習(xí)和優(yōu)化,在這個(gè)的基礎(chǔ)上,在CIFAR-10上通過4 GPU hours就能找到一個(gè)最優(yōu)的網(wǎng)絡(luò)結(jié)構(gòu)。
目前主流的NAS一般是基于進(jìn)化算法(EA)和強(qiáng)化學(xué)習(xí)(RL)來做的。EA通過權(quán)衡validation accuracy來決定是否需要移除一個(gè)模型,RL則是validation accuracy作為獎勵來優(yōu)化模型生成。作者認(rèn)為這兩種方法都很消耗計(jì)算資源。作者這篇論文中設(shè)計(jì)的GDAS方法可以在一個(gè)單v100 GPU上,用四小時(shí)搜索到一個(gè)優(yōu)秀模型。
GDAS
這個(gè)采用了搜索robust neural cell來替代搜索整個(gè)網(wǎng)絡(luò)。如下圖,不同的操作(操作用箭頭表示)會計(jì)算出不同的中間結(jié)果(中間結(jié)果用cycle表示),前面的中間結(jié)果會加起來闖到后面。
在優(yōu)化速度上,傳統(tǒng)的DAG存在一些問題:基于RL和EA的方法,需要獲得反饋都需要很長一段時(shí)間。而這篇論文提出的GDAS方法能夠利用梯度下降去做優(yōu)化,具體怎么梯的下面會說到。此外,使用GDAS的方法可以sample出sub-graph,這意味著計(jì)算量要比DAG的方法小很多。
絕大多數(shù)的NAS方法可以歸為兩類:Macro search和micro search
Macro search
顧名思義,實(shí)際上算法的目的是想要發(fā)現(xiàn)一個(gè)完整的網(wǎng)絡(luò)結(jié)構(gòu)。因此多會采用強(qiáng)化學(xué)習(xí)的方式。現(xiàn)有的方法很多都是使用Q-learning的方法來學(xué)習(xí)的。那么會存在的問題是,需要搜索的網(wǎng)絡(luò)數(shù)量會呈指數(shù)級增長。最后導(dǎo)致的結(jié)果就是網(wǎng)絡(luò)會更淺。
Micro Search
這種不是搜索整個(gè)神經(jīng)網(wǎng)絡(luò),而是搜索neural cells的方式。找到指定的neural cells后,再去堆疊。這種設(shè)計(jì)方式雖然能夠設(shè)計(jì)更深的網(wǎng)絡(luò),但是依舊要消耗很長時(shí)間,比如100GPU days,超長。這篇文章就是在消耗上面做優(yōu)化。
算法原理
DAG的搜索空間
前面也說了DAG是通過搜索所謂的neural cell而不是搜索整個(gè)網(wǎng)絡(luò)。每個(gè)cell由多個(gè)節(jié)點(diǎn)和節(jié)點(diǎn)間的激活函數(shù)構(gòu)成。節(jié)點(diǎn)我們用來表示,節(jié)點(diǎn)的計(jì)算如下圖。每個(gè)節(jié)點(diǎn)有其余兩個(gè)節(jié)點(diǎn)(下面公式中的節(jié)點(diǎn)i和節(jié)點(diǎn)j)來生成,而中間會從一個(gè)函數(shù)集合中去sample函數(shù)出來, 這個(gè)F數(shù)據(jù)集的組成是1)恒等映射 2)歸零 3)3x3 depthwise分離卷積 4)3x3 dilated depthwise 分離卷積 5)5x5 depthwise分離卷積 6)5x5 dilated depthwise 分離卷積。7)3x3平均池化 8) 3 x 3 最大池化。
那么生成節(jié)點(diǎn)I后,再去生成對應(yīng)的cell。我們將cell的節(jié)點(diǎn)數(shù)記為B,以B=4為例,該cell實(shí)際上會包括7個(gè)節(jié)點(diǎn),是前面兩層的cell的輸出(實(shí)際上也就是上面公式中的k和j),而則是我們(1)中計(jì)算出來的結(jié)果。也就是該cell的output tensor實(shí)際上是四個(gè)節(jié)點(diǎn)的output的聯(lián)結(jié)。
將cell組裝為網(wǎng)絡(luò)
剛剛上面的這種叫做normal cell,作者還設(shè)計(jì)了一個(gè)reduction cell, 用于下采樣。這個(gè)reduction cell就是手動設(shè)計(jì)的了,沒有像normal cell那樣復(fù)雜。normal cell 的步長為1,reduction cell步長為2, 最后的網(wǎng)絡(luò)實(shí)際上就是由這些cell組裝起來的。如下圖:
搜索模型參數(shù)
搭建的工作如上面所示,好像也還好,就像搭積木,這篇論文我覺得創(chuàng)新的地方在于它的搜索方法,特別是通過梯度下降的方式來更新參數(shù),很棒。具體的搜索參數(shù)環(huán)節(jié),它是這么做的:
首先我們的優(yōu)化目標(biāo)和手工設(shè)計(jì)的網(wǎng)絡(luò)別無二致,都是最大釋然估計(jì):
而上式中的Pr,實(shí)際上可寫成:
這個(gè)實(shí)際上是node i和node j的函數(shù)分布,k則是F的基數(shù)。而Node可以表示為:
是從中sample出來的,而
這個(gè)實(shí)際上是node i和node j的函數(shù)分布,k則是F的基數(shù)。而Node可以表示為:
其中是從離散分布中間sample出來的函數(shù)。這里問題來了,如果直接去優(yōu)化Pr,這里由于I是來自于一個(gè)離散分布,沒法對離散分布使用梯度下降方法。這里,作者使用了Gumbel-Max trick來解決離散分布中采樣不可微的問題,具體可以看這個(gè)問題下的回答
如何理解Gumbel-Max trick?
TL;DR: Gumbel Trick 是一種從離散分布取樣的方法,它的形式可以允許我們定義一種可微分的,離散分布的近似取樣,這種取樣方式不像「干脆以各類概率值的概率向量替代取樣」這么粗糙,也不像直接取樣一樣不可導(dǎo)(因此沒辦法應(yīng)對可能的 bp )。
于是這里將這個(gè)離散分布不可微的問題做了轉(zhuǎn)移,同時(shí)對應(yīng)的優(yōu)化目標(biāo)變?yōu)椋?/p>
這里有個(gè)的參數(shù),可以控制的相似程度。注意在前向傳播中我們使用的是等式(5), 而在反向傳播中,使用的是等式(7)。結(jié)合以上內(nèi)容,我們模型的loss是:
我們將最后學(xué)習(xí)到的網(wǎng)絡(luò)結(jié)構(gòu)稱為A,每一個(gè)節(jié)點(diǎn)由前面T個(gè)節(jié)點(diǎn)連接而來,在CNN中,我們把T設(shè)為2, 在RNN中,T設(shè)為1
在參數(shù)上,作者使用了SGD,學(xué)習(xí)率從0.025逐漸降到1e-3,使用的是cosine schedule。具體的參數(shù)和function F 設(shè)計(jì)上,可以去看看原論文。
總的來說,我覺得這篇論文最大的創(chuàng)新點(diǎn)是使用Gumbel-Max trick來使得搜索過程可微分,當(dāng)然它中間也使用了一些手動設(shè)計(jì)的模塊(如reduction cell),所以速度會比其余的NAS更快,之前我也沒有接觸過NAS, 看完這篇論文后對現(xiàn)在的NAS常用的方法以及未來NAS發(fā)展的趨勢還是有了更深的理解,推薦看看原文。
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4771瀏覽量
100772 -
gpu
+關(guān)注
關(guān)注
28文章
4740瀏覽量
128951 -
強(qiáng)化學(xué)習(xí)
+關(guān)注
關(guān)注
4文章
266瀏覽量
11256
原文標(biāo)題:單v100 GPU,4小時(shí)搜索到一個(gè)魯棒的網(wǎng)絡(luò)結(jié)構(gòu)
文章出處:【微信號:rgznai100,微信公眾號:rgznai100】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論