1 簡介
預(yù)訓(xùn)練模型BERT以及相關(guān)的變體自從問世以后基本占據(jù)了各大語言評(píng)測任務(wù)榜單,不斷刷新記錄,但是,BERT龐大的參數(shù)量所帶來的空間跟時(shí)間開銷限制了其在下游任務(wù)的廣泛應(yīng)用。基于此,人們希望能通過Bert得到一個(gè)更小規(guī)模的模型,同時(shí)基本具備Bert的能力,從而為下游任務(wù)的大規(guī)模應(yīng)用提供可能性。目前許多跟Bert相關(guān)的蒸餾方法被提出來,本章節(jié)就來分析下這若干蒸餾方法之間的細(xì)節(jié)以及差異。
知識(shí)蒸餾由兩個(gè)模型組成,teacher模型跟student模型,一般teacher模型規(guī)模跟參數(shù)量都比較龐大,所以能力更強(qiáng),而student模型規(guī)模比較小,如果直接訓(xùn)練的話效果比較有限,所以是先訓(xùn)練teacher模型,讓它學(xué)到充足的知識(shí),然后用student模型去學(xué)習(xí)teacher模型的行為,從而實(shí)現(xiàn)將知識(shí)從teacher模型轉(zhuǎn)移到student模型,使得student模型能在較小的參數(shù)量的同時(shí)具備接近大模型的能力。在蒸餾過程中,最常見的student模型部分的loss,就是對(duì)于同一個(gè)數(shù)據(jù),將teacher模型的預(yù)測的soft概率作為ground truth,讓teacher模型去學(xué)習(xí)從而預(yù)測得到相同的結(jié)果,這部分teacher模型跟student模型預(yù)測的概率之間距離就是蒸餾最常見的loss(通常是交叉熵)。蒸餾學(xué)習(xí)希望student模型學(xué)到teacher模型的能力,從而預(yù)測的結(jié)果跟teacher模型預(yù)測的soft概率足夠接近,也就是希望這部分的loss盡可能的小。
2 DualTrain+SharedProj
以往的知識(shí)蒸餾雖然可以有效的壓縮模型尺寸,但很難將teacher模型的能力蒸餾到一個(gè)更小詞表的student模型中,而DualTrain+SharedProj解決了這個(gè)難題。它主要針對(duì)Bert的詞表大小跟嵌入緯度做了縮簡,其余部分,包括模型結(jié)構(gòu)跟層數(shù)保持跟teacher模型(Bert Base)一致,從而實(shí)現(xiàn)將知識(shí)從teacher模型遷移到student模型中。
圖1: DualTrain+SharedProj框架
區(qū)別于其他蒸餾方法,DualTrain+SharedProj有兩個(gè)特別的地方,一個(gè)是Dual Training, 另一個(gè)是Shared Projection。Dual Training主要是為了解決teacher模型跟student模型不共用詞表的問題,在蒸餾過程中,對(duì)于teacher模型,會(huì)隨機(jī)選擇teacher模型或者student模型的詞表去進(jìn)行分詞,可以理解就是混合了teacher模型跟student模型的詞表,這種方式可以對(duì)齊兩個(gè)規(guī)模不同的的詞表。例如圖中左邊部分,I和machine用的是teacher模型的分詞結(jié)果而其余token用的是student模型的分詞結(jié)果。第二部分是Shared Projection,這部分很好理解,因?yàn)閟tudent模型嵌入層緯度縮小了,導(dǎo)致每個(gè)transformer層的緯度都縮小了,但是我們希望student模型跟teacher模型的transformer層的參數(shù)足夠接近,所以這里需要一個(gè)可訓(xùn)練的矩陣將兩個(gè)不同維度的transformer層參數(shù)縮放到同一個(gè)維度才能進(jìn)行比較。如果是對(duì)teacher模型的參數(shù)進(jìn)行縮放,就叫做down projection,如果是對(duì)student模型參數(shù)進(jìn)行的縮放,就叫做up projection。同時(shí),12層的transformer參數(shù)共用同一個(gè)縮放矩陣,所以叫做shared projection。例如下圖,下標(biāo)t,s分別代表teacher模型跟student模型。
圖2: up projection損失
圖3: DualTrain+SharedProj的損失函數(shù)
在蒸餾過程中,會(huì)將teacher模型跟student模型都在監(jiān)督數(shù)據(jù)上進(jìn)行訓(xùn)練,將兩個(gè)模型預(yù)測結(jié)果的損失加上兩個(gè)模型之間transformer層的參數(shù)之間的距離的損失作為最終損失,去更新student模型的參數(shù)。最終實(shí)驗(yàn)效果可也表明,隨著student模型的隱藏層緯度縮減得越厲害,模型的效果也會(huì)逐漸變差。
圖4: DualTrain+SharedProj的實(shí)驗(yàn)效果
DualTrain+SharedProj是很少見的student模型跟teacher模型不共享詞表的一種蒸餾方式,通過縮小詞表跟縮減嵌入層緯度,可以很大程度的減少模型的尺寸。同時(shí)也要注意,尺寸縮小得厲害,student模型的效果也下降地越厲害。另外有一點(diǎn)我不太理解,只通過一個(gè)dual training過程就可以對(duì)齊兩個(gè)詞表了嗎?是不是要蒸餾開始之前先對(duì)teacher模型,混合兩個(gè)詞表的分詞結(jié)果做下預(yù)訓(xùn)練會(huì)更加合理?
3DistillBERT
DistilBERT是通過一種比較常規(guī)的蒸餾方法得到的,它的teacher模型依舊是Bert Base,DistilBERT沿用了Bert的結(jié)構(gòu),但是transfromer層數(shù)只有6層(Bert Base有12層),同時(shí)還將嵌入層token-type embedding跟最后的pooling層移除。為了讓DistilBERT有一個(gè)更加合理的初始化,DistilBERT的transformer參數(shù)來源于Bert Base,每隔兩層transformer取其中一層的參數(shù)來作為DistilBERT的參數(shù)初始化。
在蒸餾過程中,除了常規(guī)的蒸餾部分的loss,還加入了一個(gè)自監(jiān)督訓(xùn)練的loss(MLM任務(wù)的loss),除此之外,實(shí)驗(yàn)還發(fā)現(xiàn)加入一個(gè)詞嵌入的loss有利于對(duì)齊teacher模型跟student模型的隱藏層表征。
DistilBERT是一種常見的通過蒸餾得到的方法,基本上是通過減少transformer的層數(shù)來減少模型尺寸,同時(shí)加速模型推理的。
4LSTM
蒸餾學(xué)習(xí)并不要求teacher模型跟student模型要隸屬于同一種模型架構(gòu),于是就有人腦洞大開,想用BiLSTM作為student模型來承載Bert Base龐大的能力。這里的teacher模型依舊是Bert Base,student模型分為三個(gè)部分,第一部分是詞嵌入層,第二部分是雙向LSTM+pooling,這里會(huì)將BiLSTM得到的隱藏層狀態(tài)通過max pooling生成句子的表征,第三部分是全連接層,直接輸出各個(gè)類別的概率。
在蒸餾開始之前,需要先在特定任務(wù)的監(jiān)督數(shù)據(jù)集上對(duì)teacher模型進(jìn)行微調(diào),因?yàn)槭欠诸惾蝿?wù),所以Bert Base跟后面的全連接層會(huì)一起更新參數(shù),從而讓teacher模型適配下游任務(wù)。在蒸餾過程中,student模型的損失分為三部分,第一部分依舊是常規(guī)的根據(jù)teacher模型預(yù)測的soft概率跟student模型預(yù)測的概率之間的交叉熵?fù)p失。第二部分是在監(jiān)督數(shù)據(jù)下student模型預(yù)測的結(jié)果跟真實(shí)標(biāo)簽結(jié)果之間的交叉熵?fù)p失。第三部分是teacher模型跟student模型生成表征之間的KL距離,也就是BiLSTM+pooling跟Bert base最后一層狀態(tài)輸出之間的距離,但是由于這兩者可能維度不一樣,所以這里也需要引入一個(gè)全連接層來縮放。
圖5: BiLSTM的蒸餾過程
圖6: BiLSTM蒸餾的效果對(duì)比
可以看得到通過蒸餾得到的BiLSTM明顯優(yōu)于直接finetune的,這里證明了蒸餾學(xué)習(xí)的有效性。除此之外,BiLSTM本身的準(zhǔn)確率就很高了,說明任務(wù)比較簡單(要不然蒸餾過后的BiLSTM準(zhǔn)確率比teacher模型Bert Base還高不是很詭異嘛?),所以并不能說明把Bert Base蒸餾到BiLSTM是個(gè)合適的選擇。LSTM本身結(jié)構(gòu)的局限性導(dǎo)致了很難完全學(xué)習(xí)到transformer的知識(shí)跟能力,筆者以前也在一些比較難的數(shù)據(jù)集上嘗試過類似的做法,但是最終作為student模型的LSTM的效果跟teacher模型的之間的差距還是比較大,并且泛化能力比較差。
5 PDK
PKD想通過蒸餾學(xué)習(xí)將Bert Base的transformer層數(shù)進(jìn)行壓縮,但是常規(guī)的方式只學(xué)習(xí)teacher模型最后一層的結(jié)果,雖然能在訓(xùn)練集上取得可以媲美teacher模型的效果,但是在測試集的表現(xiàn)很快就收斂了。這種現(xiàn)象看起來像是在訓(xùn)練集上過擬合了,從而影響了student模型的泛化能力?;诖?,PKD在原本的基礎(chǔ)上加上了新的約束項(xiàng),驅(qū)使student模型去學(xué)習(xí)模仿teacher模型的中間過程。具體的有兩種可能方式,第一種就是讓student模型去學(xué)習(xí)teacher模型transformer每隔幾層的結(jié)果,第二種是讓student模型去學(xué)習(xí)teacher模型最后幾層transformer的結(jié)果。
蒸餾過程的損失函數(shù)包括三個(gè)部分,第一部分還是常規(guī)的teacher模型預(yù)測的soft概率和student模型預(yù)測結(jié)果之間的交叉熵?fù)p失,第二部分是student模型預(yù)測概率跟真實(shí)標(biāo)簽之間的交叉熵?fù)p失,第三部分就是teacher模型跟student模型之間中間狀態(tài)的距離,這里用的[CLS]位置的表征。
6TinyBert
TinyBert的特別之處在于它的蒸餾過程分為兩個(gè)階段。第一階段是通用蒸餾,teacher model是預(yù)訓(xùn)練好的Bert, 可以幫助TinyBert學(xué)習(xí)到豐富的知識(shí),具備強(qiáng)大的通用能力,第二階段是特定任務(wù)蒸餾,teacher moder是經(jīng)過finetune的Bert, 使得TinyBert學(xué)習(xí)到特定任務(wù)下的知識(shí)。兩個(gè)蒸餾環(huán)節(jié)的設(shè)計(jì),能保證TinyBert強(qiáng)大的通用能力跟特定任務(wù)下的提升。
在每個(gè)蒸餾環(huán)節(jié)下,student模型的蒸餾分為三個(gè)部分,Embedding-layer Distillation,Transformer-layer Distillation, Prediction-layer Distillation。Embedding-layer Distillation是詞嵌入層的蒸餾,使得TinyBert更小維度的embedding輸出結(jié)果盡可能的接近Bert的embedding輸出結(jié)果。Transformer-layer Distillation是其中transformer層的蒸餾,這里的蒸餾采用的是隔k層蒸餾的方式。也就是,假如teacher model的Bert的transformer有12層,如果TinyBert的transformer設(shè)計(jì)有4層,那么就是就是每隔3層蒸餾,TinyBert的第1,2,3,4層transformer分別學(xué)習(xí)的是Bert的第3,6,9,12層transformer層的輸出。Prediction-layer Distillation主要是對(duì)齊TinyBert跟Bert在預(yù)測層的輸出,這里學(xué)習(xí)的是預(yù)測層的logit,也就是概率值。前面兩部分的損失都是MSE計(jì)算,因?yàn)閠eacher模型跟student模型在嵌入層跟隱藏層的維度不一致,所以這里需要相應(yīng)的線性映射將student模型的中間輸出映射到跟teacher 模型一樣的維度,最后一部分的損失是通過交叉熵?fù)p失計(jì)算的。通過這三部分的學(xué)習(xí),能保證TinyBert在中間層跟最后預(yù)測層都學(xué)習(xí)到Bert相應(yīng)的結(jié)果,進(jìn)而保證準(zhǔn)確率。
圖7: TinyBert框架
TinyBert的兩階段蒸餾過程能驅(qū)使student模型能學(xué)到teacher模型的通用知識(shí)和特定領(lǐng)域知識(shí),保證student模型在下游任務(wù)的表現(xiàn),是很值得借鑒的一種訓(xùn)練技巧。
7 MOBILEBERT
MOBILEBERT可能是目前性價(jià)比最高的一種蒸餾方式了(可能是筆者眼界有限),無論是從學(xué)習(xí)的目標(biāo),還是整個(gè)訓(xùn)練的方式,考慮都很周全。MOBILEBERT的student模型跟teacher模型的網(wǎng)絡(luò)層數(shù)保持一致,相關(guān)的模型結(jié)構(gòu)有所變化,首先是student模型跟teacher模型都新增了bottleneck,用于縮放內(nèi)部表示尺寸,在后面loss部分會(huì)展開介紹,其次是student模型里將FFN改成堆疊的FFN,最后是移除了layer normalization跟將激活函數(shù)由gelu換成relu.
在蒸餾過程中,student模型的損失包括兩個(gè)部分。第一個(gè)部分是student模型和teacher模型之間的feature map的距離,這里的feature map指的是每一層transformer輸出的結(jié)果。在這里,為了能讓student模型的隱藏層維度比teacher模型的隱藏層維度更小從而實(shí)現(xiàn)模型壓縮,這里的student模型跟teacher模型的transformer結(jié)構(gòu)都加入了bottleneck,也就是圖中綠色梯形的部分,通過這些bottleneck可以對(duì)文本表征尺寸進(jìn)行縮放,從而實(shí)現(xiàn)teacher模型跟student模型各自在每一個(gè)transformer內(nèi)部表示尺寸不同,但是輸入和輸出尺寸一致,所以就可能用內(nèi)部表示尺寸小的student模型去學(xué)習(xí)內(nèi)部表示尺寸大的teacher模型的能力跟知識(shí)。第二部分是兩個(gè)模型每一層transformer中attention的距離,這部分loss是為了利用self attention從teacher模型中學(xué)習(xí)到相關(guān)內(nèi)容從而更好得學(xué)習(xí)到第一部分的feature map。
圖8: MOBILEBERT相關(guān)的網(wǎng)絡(luò)結(jié)構(gòu)
MOBILEBERT的蒸餾過程是漸近式的,在蒸餾學(xué)習(xí)第L層的參數(shù)時(shí)會(huì)固定L層以下的參數(shù),一層一層的學(xué)習(xí)teacher模型的,直到學(xué)完全部層數(shù)。
圖9: MOBILEBERT的漸近式知識(shí)遷移過程
在完成蒸餾學(xué)習(xí)后,MOBILEBERT還會(huì)在做進(jìn)一步的預(yù)訓(xùn)練,預(yù)訓(xùn)練有三部分的loss,第一部分跟第二部分是BERT預(yù)訓(xùn)練的MLM跟NSP任務(wù)的loss,第三部分是teacher模型跟student模型在[MASK]位置的預(yù)測概率之間的交叉熵?fù)p失。
8總結(jié)
為了直觀的對(duì)比上面提及的蒸餾方法的壓縮效率和模型效果,我們匯總了若干種模型的具體信息以及在MRPC數(shù)據(jù)集上的表現(xiàn)。總體來說,有以下一些相關(guān)結(jié)論。
a)壓縮效率越高往往會(huì)伴隨著模型效果的持續(xù)下降。
b)Student模型的上限就是teacher模型。對(duì)于同一個(gè)student模型,并不是teacher模型越大student模型效果就會(huì)越好。因?yàn)樵酱蟮膖eacher模型,意味著更大的壓縮效率,也意味著更嚴(yán)重的性能下降。
c)只學(xué)習(xí)teacher模型最后的預(yù)測的soft概率是遠(yuǎn)遠(yuǎn)不夠的,需要對(duì)teacher模型中間的表征或者參數(shù)也進(jìn)行學(xué)習(xí),才能進(jìn)一步保證student模型的效果。
d)縮減transformer層數(shù)或者縮減隱藏層狀態(tài)緯度都可以壓縮模型,對(duì)于縮減隱藏層狀態(tài)維度,用MOBILEBERT那種bottleneck的方式優(yōu)于常規(guī)的通過一個(gè)額外的映射來對(duì)齊模型尺寸的方式??s減隱藏層狀態(tài)維度的方式的模型壓縮效率的上限更高。
e)漸進(jìn)性學(xué)習(xí)方式是有效的。也就是固定下層的參數(shù),只更新當(dāng)前層的參數(shù),依次迭代直至更新完student模型全部層。
f)分階段蒸餾是有效的。先學(xué)習(xí)通用的teacher模型,然后再學(xué)習(xí)特定任務(wù)下finetune的teacher模型。
g)跨模型結(jié)構(gòu)的蒸餾是有效的。用BiLSTM來學(xué)習(xí)Bert Base的能力比直接finetune BiLSTM的效果要好。
Model | type | Compress Factor | MRPC(f1) |
Bert Base | 1 | 88.9 | |
DualTrain+SharedProjUp |
192 96 48 |
5.74 19.41 61.94 |
84.9 84.9 79.3 |
DistilBERT | 1.67 | 87.5 | |
PKD |
6 3 |
1.64 2.40 |
85.0 80.7 |
TinyBert | 4 | 7.50 | 86.4 |
MOBILEBERT | 4.30 | 88.8 |
參考文獻(xiàn)
1.(2020) EXTREME LANGUAGE MODEL COMPRESSION WITH OPTIMAL SUBWORDS AND SHARED PROJECTIONS
https://openreview.net/pdf?id=S1x6ueSKPr
2. (2020) DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter
https://arxiv.org/abs/1910.01108
3. (2020) DISTILLING BERT INTO SIMPLE NEURAL NETWORKS WITH UNLABELED TRANSFER DATA
https://arxiv.org/pdf/1910.01769.pdf
4. (2019) Patient Knowledge Distillation for BERT Model Compression
https://arxiv.org/pdf/1908.09355.pdf
5.(2020)TINYBERT: DISTILLING BERT FOR NATURAL LAN- GUAGE UNDERSTANDING
https://openreview.net/attachment?id=rJx0Q6EFPB&name=original_pdf
6. (2020) MOBILEBERT: TASK-AGNOSTIC COMPRESSION OF BERT BY PROGRESSIVE KNOWLEDGE TRANSFER
https://openreview.net/pdf?id=SJxjVaNKwB
審核編輯 :李倩
-
模型
+關(guān)注
關(guān)注
1文章
3255瀏覽量
48898 -
LSTM
+關(guān)注
關(guān)注
0文章
59瀏覽量
3767
原文標(biāo)題:Bert系列之知識(shí)蒸餾
文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論