作者:Winnie
今天為大家介紹一個(gè)新技術(shù)—Medusa,它旨在加速大型語(yǔ)言模型(LLM)的生成。盡管其設(shè)計(jì)簡(jiǎn)單,但 Medusa能夠?qū)LM的生成效率提高約2倍。讓我們看看它是怎么做到的吧!
為什么LLM生成低效?
LLM在生成時(shí)的效率問(wèn)題主要是由內(nèi)存讀/寫(xiě)操作帶來(lái)的延遲,而這個(gè)問(wèn)題源自自動(dòng)回歸解碼過(guò)程的順序性特點(diǎn)。每次的前向傳播都需要頻繁地移動(dòng)模型參數(shù),盡管這只產(chǎn)生一個(gè)結(jié)果,但卻沒(méi)有完全利用現(xiàn)代硬件的計(jì)算潛能。傳統(tǒng)的解決方式(如增大批次大?。┰贚LM的場(chǎng)景下卻不再適用,因?yàn)檫@不僅會(huì)增加延遲,還會(huì)引發(fā)內(nèi)存問(wèn)題。
不僅如此,這種低效還帶來(lái)了額外的生成成本。例如,GPT-4的生成成本比僅僅處理prompt高了兩倍,Claude2則大約高出3倍。因此,加速LLM的低效生成是一個(gè)亟待解決的問(wèn)題。
Medusa來(lái)了!
面對(duì)推測(cè)性解碼的復(fù)雜性,研究人員推出了Medusa技術(shù),這個(gè)框架回歸了Transformer模型的本質(zhì),減少了復(fù)雜度,增強(qiáng)了效率,讓每個(gè)生成階段都能快速產(chǎn)出結(jié)果。當(dāng)將Medusa與基于樹(shù)的注意機(jī)制結(jié)合時(shí),生成速度提高了2到3倍。
接下來(lái),讓我們看一看Msdusa都做了哪些改進(jìn)吧!
Medusa總體框架
Medusa的核心在于它在LLM的最后隱藏層上增加的多個(gè)Heads,使它們并行工作,預(yù)測(cè)接下來(lái)的內(nèi)容。
當(dāng)將Medusa Heads加入模型時(shí),你會(huì)發(fā)現(xiàn),原始模型保持不變,而只有這些Medusa Heads進(jìn)行微調(diào)。在真正使用時(shí),每個(gè)Medusa Head都會(huì)為其位置產(chǎn)生預(yù)測(cè),這些預(yù)測(cè)會(huì)被組合、處理,最終給出最佳結(jié)果。
通過(guò)同時(shí)接受更多的tokens來(lái)增強(qiáng)解碼過(guò)程的效率,從而減少了所需的解碼步驟數(shù)量。
Medusa Heads
Medusa Heads與原有的語(yǔ)言模型頭相似,但卻擁有一個(gè)獨(dú)特的優(yōu)勢(shì):它們可以預(yù)測(cè)多個(gè)即將出現(xiàn)的token,而不僅僅是下一個(gè)。這種方法從Blockwise Parallel Decoding方法中汲取靈感,將每個(gè)Medusa頭設(shè)計(jì)為一個(gè)單層的前饋網(wǎng)絡(luò),且增強(qiáng)了殘差連接。
訓(xùn)練這些Medusa Heads非常方便!你可以使用用于訓(xùn)練原始模型的同一語(yǔ)料庫(kù),或者使用模型本身生成一個(gè)新的語(yǔ)料庫(kù)來(lái)訓(xùn)練它們。在訓(xùn)練階段,原始的模型保持靜態(tài),僅Medusa Heads進(jìn)行微調(diào)。這種有針對(duì)性的訓(xùn)練產(chǎn)生了一個(gè)參數(shù)效率極高的過(guò)程,可以迅速實(shí)現(xiàn)收斂—尤其是與speculative decoding方法中訓(xùn)練單獨(dú)的draft model的計(jì)算密集度相比。
Medusa Heads的表現(xiàn)相當(dāng)出色,它在預(yù)測(cè)“下一個(gè)”token時(shí)的top-1準(zhǔn)確率約為60%。但這僅僅是個(gè)開(kāi)始,它還有很大的提升空間。
通過(guò)Medusa Heads的測(cè)試,研究人員發(fā)現(xiàn):雖然預(yù)測(cè)“下下一個(gè)”token的top-1準(zhǔn)確率僅約為60%,但top-5準(zhǔn)確率卻飆升至超過(guò)80%。這一顯著的提高表明,如果我們可以巧妙地利用Medusa Heads做出的多個(gè)top排名預(yù)測(cè),就可以顯著增加每個(gè)解碼步驟生成的tokens數(shù)量。
實(shí)現(xiàn)這一目標(biāo)的方式是首先構(gòu)造一個(gè)候選集,這個(gè)集合由每個(gè)Medusa Head的預(yù)測(cè)結(jié)果的笛卡爾積形成。然后,依賴(lài)圖被編碼到注意力機(jī)制中,允許多個(gè)候選項(xiàng)目并行處理,這是受到圖神經(jīng)網(wǎng)絡(luò)思想的啟發(fā)。例如,在一個(gè)實(shí)際應(yīng)用中,可以從第一個(gè)Medusa頭部獲取前兩個(gè)預(yù)測(cè),從第二個(gè)頭部獲取前三個(gè)預(yù)測(cè),并將它們組合成一個(gè)多層樹(shù)結(jié)構(gòu)。在這種結(jié)構(gòu)中,一個(gè)注意力掩碼被實(shí)施,僅限制注意力于一個(gè)token的前一個(gè)token,從而保持歷史上下文。通過(guò)這種方式,可以同時(shí)處理多個(gè)候選項(xiàng),而無(wú)需增加批次大小。
下圖是Tree attention機(jī)制用于并行處理多個(gè)候選項(xiàng)目的一個(gè)可視化示例。在一個(gè)示例中,來(lái)自第一個(gè)Medusa頭部的前兩個(gè)預(yù)測(cè)和來(lái)自第二個(gè)頭部的前三個(gè)預(yù)測(cè)產(chǎn)生了2*3=6個(gè)候選項(xiàng)。這些候選項(xiàng)中的每一個(gè)都對(duì)應(yīng)于樹(shù)結(jié)構(gòu)中的一個(gè)不同分支。為了保證每個(gè)token只能訪(fǎng)問(wèn)其前面的token,注意力掩碼,該掩碼僅允許注意力從當(dāng)前token流向其前面的token。位置編碼的位置指數(shù)將根據(jù)這種結(jié)構(gòu)進(jìn)行調(diào)整。通過(guò)這種方式,可以確保歷史上下文的完整性和連貫性,同時(shí)提高解碼步驟的效率和準(zhǔn)確性。
值得注意的是,與一些獨(dú)立的研究相比,該方法傾向于使用簡(jiǎn)化形式的樹(shù)狀注意力,其中樹(shù)的模式在推斷期間是規(guī)則和固定的,這允許預(yù)處理樹(shù)狀注意力掩碼,進(jìn)而提高效率。通過(guò)創(chuàng)新這種解碼方法,它不僅提供了一個(gè)新的解碼路徑,而且為更精確和高效的未來(lái)預(yù)測(cè)打開(kāi)了新的可能性。
在早期關(guān)于投機(jī)解碼的研究中,重要性采樣技術(shù)用于產(chǎn)生與原始模型預(yù)測(cè)緊密相符的多樣化輸出。但隨后的研究表明,隨著“creativity dial”或采樣溫度的增加,這種方法的效率會(huì)降低。簡(jiǎn)而言之,如果一個(gè)draft model與原始模型一樣優(yōu)秀,理論上應(yīng)接受其所有輸出,使過(guò)程極為高效。但是重要性采樣可能會(huì)在某個(gè)階段拒絕這種方案。
實(shí)際上,人們常常僅調(diào)整采樣溫度來(lái)控制模型的創(chuàng)造力,而不是嚴(yán)格匹配原始模型的分布。那么為什么不只是接受看似合理的候選項(xiàng)呢?Typical acceptance策略受到截?cái)嗖蓸拥膯l(fā),目的是選取根據(jù)原始模型被視為足夠可能的候選項(xiàng)。通過(guò)設(shè)置基于原始模型預(yù)測(cè)概率的閾值,如果候選項(xiàng)超過(guò)這個(gè)閾值,則將其接受。
在技術(shù)語(yǔ)言中,我們采用硬閾值和依賴(lài)于熵的閾值中的最小值來(lái)決定是否接受一個(gè)候選項(xiàng),如截?cái)嗖蓸又兴?。這確保在解碼期間選擇了有意義的標(biāo)記和合理的延續(xù)。第一個(gè)標(biāo)記總是通過(guò)貪婪解碼被接受,確保每一步至少生成一個(gè)標(biāo)記。最終輸出是通過(guò)接受測(cè)試的最長(zhǎng)序列。這種方法的優(yōu)點(diǎn)在于其適應(yīng)性。如果將采樣溫度設(shè)置為零,它將簡(jiǎn)單地退化為最有效的形式——貪婪解碼。提高溫度會(huì)使方法變得更加高效,允許更長(zhǎng)的接受序列,這一點(diǎn)已通過(guò)嚴(yán)格測(cè)試得到驗(yàn)證。
性能測(cè)試
在Vicuna模型上測(cè)試了Medusa,這些模型是特別為聊天應(yīng)用優(yōu)化和調(diào)整的羊駝模型,其大小不同,參數(shù)數(shù)量分別為7B、13B和33B。目標(biāo)是衡量Medusa在現(xiàn)實(shí)世界的聊天機(jī)器人環(huán)境中能夠多大程度上加速這些模型的運(yùn)行。
訓(xùn)練Medusa頭部選擇了簡(jiǎn)單的方式,使用了公開(kāi)的ShareGPT數(shù)據(jù)集,這是最初用于訓(xùn)練Vicuna模型的數(shù)據(jù)的一個(gè)子集,只進(jìn)行了一個(gè)時(shí)代的訓(xùn)練。
這里的重點(diǎn)是——整個(gè)訓(xùn)練過(guò)程可以在幾小時(shí)到一天之內(nèi)完成,具體取決于模型的大小,全部在單個(gè)A100-80G GPU上完成。顯著的是,Medusa可以與量化基模型輕松結(jié)合,從而減少內(nèi)存需求。為了利用這一優(yōu)勢(shì),在訓(xùn)練33B模型時(shí)使用了8位量化。
為模擬現(xiàn)實(shí)環(huán)境,采用了MT測(cè)試臺(tái)進(jìn)行評(píng)估。結(jié)果是令人鼓舞的:Medusa借助其簡(jiǎn)單的設(shè)計(jì),在各種用例中穩(wěn)定實(shí)現(xiàn)了約2倍的實(shí)際運(yùn)行時(shí)間加速。顯著的是,有了Medusa的優(yōu)化,33B參數(shù)的Vicuna模型可以與13B模型一樣快速運(yùn)行。
結(jié)語(yǔ)
Medusa技術(shù)致力于通過(guò)多層頭部預(yù)測(cè)方法來(lái)加速LLM的語(yǔ)言生成速度。該研究中引入了多個(gè)Medusa頭和Tree attention機(jī)制,通過(guò)預(yù)測(cè)多個(gè)即將出現(xiàn)的標(biāo)記而非一個(gè)來(lái)優(yōu)化生成速度,同時(shí)還保持了高準(zhǔn)確率。此外,研究還提出了Typical acceptance方案,它基于原始模型的預(yù)測(cè)概率來(lái)選擇候選項(xiàng),而不是依賴(lài)重要性抽樣,使得創(chuàng)意輸出更為高效和自適應(yīng)。
在實(shí)際測(cè)試中,Medusa成功地將Vicuna模型的運(yùn)行速度提高了大約兩倍,證明了其在現(xiàn)實(shí)世界的聊天機(jī)器人環(huán)境中的實(shí)用性和效果。整體來(lái)看,Medusa為開(kāi)發(fā)更快、更有效的聊天機(jī)器人開(kāi)辟了新的可能,顯示出在語(yǔ)言模型生成領(lǐng)域的巨大潛力。
編輯:黃飛
評(píng)論
查看更多