引言
學(xué)過化學(xué)的都知道蒸餾這個概念,就是利用不同組分的沸點不同,將不同組分從混合液中分離出來。知識蒸餾用于網(wǎng)絡(luò)壓縮,也具有類似的性質(zhì)。具體的講,有一個大的神經(jīng)網(wǎng)絡(luò)充當(dāng)了“老師”的角色,她將書本上的知識先經(jīng)過自己的轉(zhuǎn)化和吸收,然后再傳授給“學(xué)生”網(wǎng)絡(luò)。學(xué)生網(wǎng)絡(luò)模型相對較小,但是經(jīng)過老師將知識提取教授,也可以實現(xiàn)大網(wǎng)絡(luò)的功能。
知識蒸餾的方法是大名鼎鼎的Hinton提出的,這種方法實現(xiàn)了大網(wǎng)絡(luò)向小網(wǎng)絡(luò)的知識遷移,使得應(yīng)用場景可以擴展到移動端。接下來我們具體看看知識蒸餾的整個過程。
1
原理
表面上看,大網(wǎng)絡(luò)應(yīng)該有更好的表達能力,或者說泛化能力。而小網(wǎng)絡(luò)節(jié)點數(shù)量和大網(wǎng)絡(luò)還有很大的差距,它如何能夠做到逼近大網(wǎng)絡(luò)的結(jié)果呢?首先,這與具體的應(yīng)用場景范圍有關(guān),在一定的場景下,小網(wǎng)絡(luò)可以接近大網(wǎng)絡(luò)的分類能力。這就好像對于某個更復(fù)雜的函數(shù),當(dāng)限定某個值域的時候,可以用一些簡單函數(shù)來逼近。其次,網(wǎng)絡(luò)分類器最終的結(jié)果是用概率來表示的,分類結(jié)果取決于概率最大的。因此最大概率是90%和最大概率是60%的最終分類結(jié)果是一樣的,這點就給了小網(wǎng)絡(luò)更靈活的表達方式。最后就是小網(wǎng)絡(luò)逼近大網(wǎng)絡(luò)的程度和大網(wǎng)絡(luò)的冗余程度有關(guān),這類似于對大網(wǎng)絡(luò)實行剪枝的結(jié)果。
那么如何訓(xùn)練一個小網(wǎng)絡(luò)呢?我們可以先考慮一下在數(shù)值分析中,用一個函數(shù)S(x)來逼近另外一個函數(shù)f(x),那么就可以通過最小化這兩個函數(shù)在每個點的平方和來實現(xiàn)。同理,訓(xùn)練小的網(wǎng)絡(luò)也必須使用大網(wǎng)絡(luò)的輸入和輸出作為訓(xùn)練集,而不能再使用訓(xùn)練大網(wǎng)絡(luò)的訓(xùn)練集了。原始訓(xùn)練集的標(biāo)注結(jié)果是絕對的(是和不是:1,0),而大網(wǎng)絡(luò)的輸出結(jié)果是一個概率向量,其包含了每一類的概率大小。這個結(jié)果不再僅僅只含有原始訓(xùn)練集的信息,它還包含了大網(wǎng)絡(luò)的信息。比如在原始圖片中,一張貓的圖片結(jié)果只有一個,但是經(jīng)過大網(wǎng)絡(luò)后,不僅僅有貓的結(jié)果,還有狗,房子,樹等每個類別的概率結(jié)果。其他類別的概率實際上告訴了我們不同類別之間存在的差異和共性,比如一張貓的圖片中是狗的概率可能就比是房子的概率大,因為貓和狗相對于貓和房子有更大的共性。
神經(jīng)網(wǎng)絡(luò)通常使用softmax函數(shù)來生成分類概率,這個函數(shù)形式為:
其中T是溫度,通常設(shè)置為1。使用較高的T可以產(chǎn)生更加softer的概率分布。更softer的概率分布提高網(wǎng)絡(luò)的泛化能力,有利于小網(wǎng)絡(luò)的訓(xùn)練。
寫到這里小編對softmax函數(shù)感到好奇,為什么神經(jīng)網(wǎng)絡(luò)都采用softmax來進行概率計算呢?學(xué)過熱力學(xué)的會發(fā)現(xiàn),這個softmax函數(shù)非常類似不同能級上粒子分布概率,位于能級E的粒子分布概率就是正比于:
而且溫度越高高能級粒子概率也越大,這與softmax函數(shù)也有同樣的結(jié)果。其實觀察他們的推導(dǎo)過程就會發(fā)現(xiàn),它們之所以有相同的形式來自于它們都是多分類問題,而且概率模型都屬于廣義線性模型。Softmax函數(shù)正是在廣義線性函數(shù)的假設(shè)上推導(dǎo)出來的。現(xiàn)在我們給出其傳統(tǒng)推導(dǎo),和基于熱力學(xué)統(tǒng)計的推導(dǎo)方法。
首先看什么是廣義線性模型,廣義線性模型是用于處理條件概率的一個基本模型,很多常見的分布模型(伯努利,高斯等)都屬于廣義線性模型。定義線性預(yù)測算子:
定義y基于x的條件概率分布,這個分布就是廣義線性模型:
分類問題就是求在給定輸入x的條件下,估計y值,即y屬于哪個類的問題??梢酝ㄟ^期望值來作為y的估計。容易得到這個期望值為:
因此一旦知道y的概率分布就知道了y的估計。這個估計就是回歸函數(shù)。現(xiàn)在我們來看softmax的傳統(tǒng)推導(dǎo)。
Y有多個可能的分類:
每種分類對應(yīng)著概率:
定義:
其中有:
于是得到廣義分布:
其中有,
然后可以求出:
求得估計值:
這就是softmax函數(shù)。
現(xiàn)在我們從統(tǒng)計熱力學(xué)角度來推導(dǎo)softmax函數(shù)。
神經(jīng)網(wǎng)絡(luò)的作用是對輸入進行特征提取,我們可以把這個提取過程表示為:
現(xiàn)在我們需要來理解E_i,這個應(yīng)該是表示從屬于特征i的程度,我們可以選擇一定函數(shù)f(E_i)來作為評價屬于特征i的程度?,F(xiàn)在我們假設(shè)特征1到k是可以涵蓋所有輸入的,即任何輸入都是由這些特征構(gòu)成的,特征值反應(yīng)了輸入屬于某個特征的量,那么所有這些特征的量之和應(yīng)該是所有輸入量的和,那么我們可以有:
我們現(xiàn)在需要求y屬于這個特征的概率,即:
現(xiàn)在我們假設(shè)有N個數(shù),這些數(shù)要分配不同的y值。這些數(shù)被分配是完全隨機的,但是受到每種y值的數(shù)量限制,對應(yīng)E_i的數(shù)量為N_i。那么將這N個數(shù)分配給k個不同類的分配方式可以得到:
我們來最大化W,即求最大似然函數(shù):
滿足約束條件:
我們利用拉格朗日對偶原理來求解極值:
我們可以得到類似玻爾茲曼分布的公式:
其中u就是溫度1/T。
現(xiàn)在回到正題,過于softer的代價函數(shù)可能會造成分類結(jié)果錯誤率低,為了平衡分類錯誤和小模型泛化能力,hinton提出使用兩個代價函數(shù)來進行訓(xùn)練,一個是T值較大,另外一個是T值為1。通過調(diào)節(jié)這兩個代價函數(shù)的比例來獲得滿意的訓(xùn)練結(jié)果。
2
實驗結(jié)果
Hinton的論文中分別在MINIST,語音識別上進行了實驗。我們僅僅看一下實驗結(jié)果,對知識蒸餾效果有個簡單印象。更深入的理解離不開實踐,只有真正去寫代碼去看結(jié)果,才能不會紙上談兵。
1) MINIST
大網(wǎng)絡(luò)含有2個隱含層,1200個激活單元,60000個訓(xùn)練集圖片。作者通過剪枝來將大網(wǎng)絡(luò)減小到只有800個激活單元,將溫度增加到20,相比于沒有regularization會減小很大錯誤率。
2) 語音識別
這里作者使用多個小網(wǎng)絡(luò)集合來作為教師網(wǎng)絡(luò),然后單個網(wǎng)絡(luò)作為學(xué)生網(wǎng)絡(luò)。每個網(wǎng)絡(luò)為8個隱含層,2560個激活單元,訓(xùn)練集有14000個標(biāo)注數(shù)據(jù)。結(jié)果如下:
其中WER為錯誤率。
總結(jié)
本文介紹了網(wǎng)絡(luò)壓縮算法,知識蒸餾。很多是小編個人理解,如有不同意見歡迎指正交流。更多可以參考hinton大神的知識蒸餾文獻。
-
算法
+關(guān)注
關(guān)注
23文章
4622瀏覽量
93101 -
函數(shù)
+關(guān)注
關(guān)注
3文章
4340瀏覽量
62793 -
網(wǎng)絡(luò)節(jié)點
+關(guān)注
關(guān)注
0文章
54瀏覽量
15930
原文標(biāo)題:【網(wǎng)絡(luò)壓縮三】知識蒸餾
文章出處:【微信號:FPGA-EETrend,微信公眾號:FPGA開發(fā)圈】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論