編者按:反向傳播是一種訓(xùn)練人工神經(jīng)網(wǎng)絡(luò)的常見(jiàn)方法,它能簡(jiǎn)化深度模型在計(jì)算上的處理方式,是初學(xué)者必須熟練掌握的一種關(guān)鍵算法。對(duì)于現(xiàn)代神經(jīng)網(wǎng)絡(luò),通過(guò)反向傳播,我們能配合梯度下降大幅提高模型的訓(xùn)練速度,在一周時(shí)間內(nèi)就完成以往研究人員可能要耗費(fèi)兩萬(wàn)年才能完成的模型。
除了深度學(xué)習(xí),反向傳播算法在許多其他領(lǐng)域也是一個(gè)強(qiáng)大的計(jì)算工具,從天氣預(yù)報(bào)到分析數(shù)值穩(wěn)定性——區(qū)別只在于名稱差異。事實(shí)上,這種算法在幾十個(gè)不同的領(lǐng)域都有成熟應(yīng)用,無(wú)數(shù)研究人員都為這種“反向模式求導(dǎo)”的形式著迷。
從根本上說(shuō),無(wú)論是深度學(xué)習(xí)還是其他數(shù)值計(jì)算環(huán)境,這是一種方便快速計(jì)算的方法,也是一個(gè)必不可少的計(jì)算竅門(mén)。
計(jì)算圖
談及計(jì)算,有人可能又要為煩人的計(jì)算公式頭疼了,所以本文用了一種思考數(shù)學(xué)表達(dá)式的輕松方法——計(jì)算圖。以非常簡(jiǎn)單的e=(a+b)×(b+1)為例,從計(jì)算角度看它一共有3步操作:兩次求和和一次乘積。為了讓大家對(duì)計(jì)算圖有更清晰的理解,這里我們把它分開(kāi)計(jì)算,并繪制圖像。
我們可以把這個(gè)等式分成3個(gè)函數(shù):
在計(jì)算圖中,我們把每個(gè)函數(shù)連同輸入變量一起放進(jìn)節(jié)點(diǎn)中。如果當(dāng)前節(jié)點(diǎn)是另一個(gè)節(jié)點(diǎn)的輸入,用帶剪頭的線表示數(shù)據(jù)流向:
這其實(shí)是計(jì)算機(jī)科學(xué)中的一種常見(jiàn)描述方法,尤其是在討論涉及函數(shù)的程序時(shí),它非常有用。此外,現(xiàn)在流行的大多數(shù)深度學(xué)習(xí)開(kāi)源框架,比如TensorFlow、Caffe、CNTK、Theano等,都采用了計(jì)算圖。
仍以之前的例子為例,在計(jì)算圖中,我們可以通過(guò)設(shè)置輸入變量為特定值來(lái)計(jì)算表達(dá)式。如,我們?cè)O(shè)a=2,b=1:
可以得到e=(a+b)×(b+1)=6。
計(jì)算圖上的導(dǎo)數(shù)
如果要理解計(jì)算圖上的導(dǎo)數(shù),一個(gè)關(guān)鍵在于我們?nèi)绾卫斫饷恳粭l帶箭頭的線(下稱“邊”)上的導(dǎo)數(shù)。以之前的連接a節(jié)點(diǎn)和c=a+b節(jié)點(diǎn)的邊為例,如果a對(duì)c有影響,那這是個(gè)怎么樣的影響?如果a變化了,c會(huì)怎么變化?我們稱這為c關(guān)于a的偏導(dǎo)數(shù)。
為了計(jì)算圖中的偏導(dǎo)數(shù),我們先來(lái)復(fù)習(xí)這兩個(gè)求和規(guī)則和乘積規(guī)則:
已知a=2,b=1,那么相應(yīng)的計(jì)算圖就是:
現(xiàn)在我們計(jì)算出了相鄰兩個(gè)節(jié)點(diǎn)的偏導(dǎo)數(shù),如果我想知道不直接相連的節(jié)點(diǎn)是如何相互影響的,你會(huì)怎么辦?如果我們以速率為1的速度變化輸入a,那么根據(jù)偏導(dǎo)數(shù)可知,函數(shù)c的變化速率也是1,已知e相對(duì)于c的偏導(dǎo)數(shù)是2,那么同樣的,e相對(duì)a的變化速率也是2。
計(jì)算不直接相連節(jié)點(diǎn)之間偏導(dǎo)數(shù)的一般規(guī)則是計(jì)算各路徑偏導(dǎo)數(shù)的和,而同一路徑偏導(dǎo)數(shù)則是各邊偏導(dǎo)數(shù)的乘積,例如,e關(guān)于b的偏導(dǎo)數(shù)就等于:
上式表示了b是如何通過(guò)影響函數(shù)c和d來(lái)影響函數(shù)e的。
像這種一般的“路徑求和”規(guī)則只是對(duì)多元鏈?zhǔn)揭?guī)則的不同思考方式。
路徑分解
“路徑求和”的問(wèn)題在于,如果我們只是簡(jiǎn)單粗暴地計(jì)算每條可能路徑的偏導(dǎo)數(shù),我們很可能會(huì)最后得到一個(gè)“爆炸”的和。
如上圖所示,X到Y(jié)有3條路徑,Y到Z也有3條路徑,如果要計(jì)算?Z/?X,我們要計(jì)算的是3×3=9條路徑的偏導(dǎo)數(shù)的和:
這還只是9條,隨著模型變得越來(lái)越復(fù)雜,相應(yīng)的計(jì)算復(fù)雜度也會(huì)呈指數(shù)級(jí)上升。因此比起傻乎乎地一個(gè)個(gè)求和,我們最好能記起一些小學(xué)數(shù)學(xué)知識(shí),然后把上式轉(zhuǎn)為:
是不是很眼熟?這就是前向傳播算法和反向傳播算法中最基礎(chǔ)的一個(gè)偏導(dǎo)數(shù)等式。通過(guò)分解路徑,這個(gè)式子能更高效地計(jì)算總和,雖然長(zhǎng)得和求和等式有一定差異,但對(duì)于每條邊它確實(shí)只計(jì)算了一次。
前向模式求導(dǎo)從計(jì)算圖的輸入開(kāi)始,到最后結(jié)束。在每個(gè)節(jié)點(diǎn)上,它匯總了所有輸入的路徑,每條路徑代表輸入影響該節(jié)點(diǎn)的一種方式。相加后,我們就能得到輸入對(duì)最終結(jié)果的總的影響,也就是偏導(dǎo)數(shù)。
雖然你以前可能沒(méi)想過(guò)從計(jì)算圖的角度來(lái)進(jìn)行理解,但這樣一看,其實(shí)前向模式求導(dǎo)和我們剛開(kāi)始學(xué)微積分時(shí)接觸的內(nèi)容差不多。
另一方面,反向模式求導(dǎo)則是從計(jì)算圖的最后開(kāi)始,到輸入結(jié)束。對(duì)于每個(gè)節(jié)點(diǎn),它做的是合并所有源自該節(jié)點(diǎn)的路徑。
前向模式求導(dǎo)關(guān)注的是一個(gè)輸入如何影響每個(gè)節(jié)點(diǎn),反向模式求導(dǎo)關(guān)注的是每個(gè)節(jié)點(diǎn)如何影響最后那一個(gè)輸出。換句話說(shuō),就是前向模式求導(dǎo)是在把?/?X塞進(jìn)每個(gè)節(jié)點(diǎn),反向模式求導(dǎo)是在把?Z/?塞進(jìn)每個(gè)節(jié)點(diǎn)。
大功告成
說(shuō)到現(xiàn)在,你可能會(huì)想知道反向模式求導(dǎo)究竟有什么意義。它看起來(lái)就是前向模式求導(dǎo)的一個(gè)奇怪翻版,其中會(huì)有什么優(yōu)勢(shì)嗎?
讓我們從之前的那張計(jì)算圖開(kāi)始:
我們先用前向模式求導(dǎo)計(jì)算輸入b對(duì)各個(gè)節(jié)點(diǎn)的影響:
?e/?b=5。我們把這個(gè)放一邊,再來(lái)看看反向模式求導(dǎo)的情況:
之前我們說(shuō)反向模式求導(dǎo)關(guān)注的是每個(gè)節(jié)點(diǎn)如何影響最后那個(gè)輸出,根據(jù)上圖可以發(fā)現(xiàn),圖中偏導(dǎo)數(shù)既有?e/?b的,也有?e/?a的。這是因?yàn)檫@個(gè)模型有兩個(gè)輸入,而它們都對(duì)輸出e產(chǎn)生了影響。也就是說(shuō),反向模式求導(dǎo)更能反映全局輸入情況。
如果說(shuō)這是一個(gè)只有兩個(gè)輸入的簡(jiǎn)單例子,兩種方法都無(wú)所謂,那么請(qǐng)想象一個(gè)有一百萬(wàn)個(gè)輸入、只有一個(gè)輸出的模型。像這樣的模型,我們用前向模式求導(dǎo)要算一百萬(wàn)次,用反向模式求導(dǎo)只要算1次,這就高下立判了!
在訓(xùn)練神經(jīng)網(wǎng)絡(luò)時(shí),我們把cost(描述網(wǎng)絡(luò)表現(xiàn)好壞的值)視作一個(gè)包含各類參數(shù)(描述網(wǎng)絡(luò)行為方式的數(shù)字)的函數(shù)。為了提升模型性能,我們要不斷改變參數(shù)對(duì)cost函數(shù)求導(dǎo),以此進(jìn)行梯度下降。模型的參數(shù)千千萬(wàn),但它的輸出只有一個(gè),因此機(jī)器學(xué)習(xí)對(duì)于反向模式求導(dǎo),也就是反向傳播算法來(lái)說(shuō)是個(gè)再適合不過(guò)的應(yīng)用領(lǐng)域。
那有沒(méi)有一種情況下,前向模式求導(dǎo)能比反向模式求導(dǎo)更好?有的!我們到現(xiàn)在談的都是多輸入單輸出的情形,這時(shí)反向更好;如果是一輸入多輸出、多輸入多輸出,前向模式求導(dǎo)速度更快!
這不是太普通了嗎?
當(dāng)我第一次真正理解反向傳播算法時(shí),我的反應(yīng)是:哦,就是最簡(jiǎn)單的鏈?zhǔn)椒▌t!我怎么花了這么久才明白?事實(shí)上我也不是唯一出現(xiàn)這種反應(yīng)的人,的確,如果問(wèn)題是你能從前向模式求導(dǎo)中推出那種更聰明的計(jì)算方法,這就沒(méi)那么麻煩了。
但我認(rèn)為這比看起來(lái)要困難得多。在反向傳播算法剛發(fā)明的時(shí)候,人們其實(shí)并沒(méi)有十分關(guān)注前饋神經(jīng)網(wǎng)絡(luò)的研究。所以也沒(méi)人發(fā)現(xiàn)它的衍生品有利于快速計(jì)算。但當(dāng)大家都知道這種衍生品的好處后,他們又開(kāi)始反應(yīng)過(guò)來(lái):原來(lái)它們有這樣的關(guān)系!這之中有一個(gè)惡性循環(huán)。
更糟糕的是,在腦子里推一推算法的衍生工具是很普遍的,一旦涉及用它們訓(xùn)練神經(jīng)網(wǎng)絡(luò),這幾乎就等同于洪水猛獸。你肯定會(huì)陷入局部最小值!你可能會(huì)浪費(fèi)巨大的計(jì)算成本!人們只有在確認(rèn)這種方法有效后,才會(huì)乖乖閉嘴去實(shí)踐。
小結(jié)
衍生工具比你想象中的更易于挖掘,也更好用,我希望這是本文為你帶來(lái)的主要經(jīng)驗(yàn)。雖然事實(shí)上這個(gè)挖掘過(guò)程并不容易,但在深度學(xué)習(xí)中領(lǐng)會(huì)這一點(diǎn)很重要,換一個(gè)角度,我們就能發(fā)現(xiàn)不同的風(fēng)景。同樣的話也適用于其他領(lǐng)域。
還有其他經(jīng)驗(yàn)嗎?我認(rèn)為有。
反向傳播算法也是了解數(shù)據(jù)流經(jīng)模型過(guò)程的有利“鏡頭”,我們能用它知道為什么有些模型會(huì)難以優(yōu)化,如經(jīng)典的遞歸神經(jīng)網(wǎng)絡(luò)中梯度消失的問(wèn)題。
最后,讀者可以嘗試同時(shí)結(jié)合前向傳播和反向傳播兩種算法來(lái)進(jìn)行更有效的計(jì)算。如果你真的理解了這兩種算法的技巧,你會(huì)發(fā)現(xiàn)其中會(huì)有不少有趣的衍生表達(dá)式。
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4777瀏覽量
100973 -
計(jì)算圖
+關(guān)注
關(guān)注
0文章
9瀏覽量
6933 -
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5511瀏覽量
121355
原文標(biāo)題:計(jì)算圖演算:反向傳播
文章出處:【微信號(hào):jqr_AI,微信公眾號(hào):論智】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論