寫在前面
今天給大家?guī)硪黄抖嗾Z言SFT可以顯著提高LLM數(shù)學(xué)推理能力》,來自知乎@promise
Paper:?https://arxiv.org/abs/2310.20246 Github:?https://github.com/microsoft/MathOctopus/tree/main 知乎:https://zhuanlan.zhihu.com/p/664504560
近來,不少研究工作都集中于如何通過instruction tuning的方式來提高大模型(LLMs)的復(fù)雜數(shù)學(xué)推理能力。但是,這些基于的LLMs研究基本都集中于單語言,如何訓(xùn)練一個(gè)多語言數(shù)學(xué)推理大模型依然丞待解決。
因此,在這篇論文中研究者們基于LLaMA探索并構(gòu)建了一系列的多語言數(shù)學(xué)推理大模型:MathOctopus。MathOctopus不僅可以廣泛地提高LLMs在多語言上推理的平均性能,而且與單語訓(xùn)練的模型相比在其對(duì)應(yīng)的語言測(cè)試中依然可以取得更加優(yōu)越的表現(xiàn)。
主要貢獻(xiàn)如下:
為了解決當(dāng)前多語言數(shù)學(xué)推理任務(wù)上訓(xùn)練數(shù)據(jù)短缺的問題,本文將英文的GSM8K數(shù)據(jù)集翻譯成10種不同的語言,并使用了特定的規(guī)則來校對(duì)翻譯后的語料,以確保數(shù)據(jù)的質(zhì)量。最終生成的數(shù)據(jù)用來構(gòu)建多語言數(shù)學(xué)推理訓(xùn)練數(shù)據(jù)集:MGSM8KInstruct。
基于MGSM8KInstruct數(shù)據(jù)集, 并結(jié)合不同的SFT策略和多語言拒絕采樣的訓(xùn)練方法,本文構(gòu)建了一系列有效地多語言數(shù)學(xué)推理大模型:MathOctopus。
更近一步,為了全面地驗(yàn)證當(dāng)前模型在多語言數(shù)學(xué)推理任務(wù)上的魯棒性和通用性,文章基于SVAMP構(gòu)建了out-of-domain(OOD)的多語言測(cè)試數(shù)據(jù)集MSVAMP。
經(jīng)過大量的實(shí)驗(yàn),本文總結(jié)出以下結(jié)論:
MathOctopus在多語言數(shù)學(xué)推理任務(wù)中,表現(xiàn)出了強(qiáng)大的性能。MathOctopus-7B 可以將LLmMA2-7B在MGSM不同語言上的平均表現(xiàn)從22.6%提升到40.0%。更進(jìn)一步,MathOctopus-13B也獲得了比ChatGPT更好的性能。
與只在單語言上訓(xùn)練的LLMs相比,MathOctopus在他們對(duì)應(yīng)的訓(xùn)練語言測(cè)試中也取得了更加卓越的效果。比如,MathOctopus-7B和,在英語GSM8K上訓(xùn)練的LLaMA2-7B相比,準(zhǔn)確率從42.3%提升到了50.8%。
盡管拒絕采樣方法之前在單語數(shù)學(xué)推理中證明是十分有效的方法,但是在多語言數(shù)學(xué)推理任務(wù)中,使用拒絕采樣進(jìn)行數(shù)據(jù)增強(qiáng),對(duì)MathOctopus帶來的增益相對(duì)有限。
數(shù)據(jù)收集
MGSM8KInstruct訓(xùn)練集
在多語言數(shù)學(xué)推理任務(wù)中,面臨的問題是在low-resource語言中缺乏相應(yīng)高質(zhì)量的訓(xùn)練數(shù)據(jù)集,為此,本文使用ChatGPT將英文的GSM8K數(shù)據(jù)集翻譯成多種語言,其中包括孟加拉語(Bn),中文(Zh),法語(Fr),德語(de),日語(Ja),俄語(Ru),西班牙語(Es),斯瓦希里語(Sw)和泰語(Th),并對(duì)翻譯后的語料進(jìn)行校對(duì),確保數(shù)據(jù)的質(zhì)量。基于此,構(gòu)建了MGSM8KInstruct多語言數(shù)學(xué)推理訓(xùn)練數(shù)據(jù)集。
平行訓(xùn)練語料樣例
交叉訓(xùn)練語料樣例
翻譯
在這篇文章中,使用了ChatGPT將英文的GSM8K訓(xùn)練集和他們對(duì)應(yīng)的 chain-of-thought(COT)回答翻譯成了十種語言。為了保證翻譯的質(zhì)量,本文在翻譯時(shí)使用的提示詞(prompt)中遵循以下規(guī)則:
翻譯前后人物和地點(diǎn)的名字保持一致。
翻譯前后數(shù)學(xué)公式保持不變。
所有的數(shù)字都用阿拉伯?dāng)?shù)字表示。
對(duì)于每種語言,在提示詞(prompt)中提供了兩個(gè)翻譯的例子。
下面是完整的翻譯提示詞
校對(duì)
在翻譯問題與答案后,ChatGPT生成的句子通常沒有語言翻譯錯(cuò)誤,但存在數(shù)學(xué)公式在翻譯前后不一致的情況。為了確保翻譯前后的準(zhǔn)確性,本文采取了以下做法,首先,提取翻譯后答案中的所有數(shù)學(xué)公式,然后與原英文數(shù)據(jù)集中的公式進(jìn)行比較,如果它們匹配,就認(rèn)為翻譯是準(zhǔn)確的。如果某一數(shù)據(jù)連續(xù)五次出現(xiàn)翻譯錯(cuò)誤,將刪除該數(shù)據(jù)。這樣做有助于確保翻譯的準(zhǔn)確性。
MSVAMP測(cè)試集
為了更近一步測(cè)試當(dāng)前LLMs在多語言數(shù)學(xué)推理任務(wù)上的魯棒性,本文在現(xiàn)有的SVAMP數(shù)據(jù)集的基礎(chǔ)上構(gòu)建了out-of-domain(OOD)多語言數(shù)學(xué)推理測(cè)試集MSVAMP。
測(cè)試集語料樣例
翻譯
由于這個(gè)數(shù)據(jù)集的答案只包含最終的數(shù)字答案而不包括chain-of-thought(COT)過程,所以我們使用google翻譯系統(tǒng)僅對(duì)問題進(jìn)行翻譯,本文將SVAMP測(cè)試集中1000條數(shù)據(jù)翻譯成和訓(xùn)練集中對(duì)應(yīng)的語言。
校對(duì)
為了確保翻譯的質(zhì)量:首先,翻譯后的句子再次被翻譯回英文,以檢查是否存在翻譯上的差異。此外,還有三名專業(yè)的人員對(duì)翻譯前后的意思是否一致進(jìn)行了審查,進(jìn)一步確保翻譯的準(zhǔn)確性。
MathOctopus
本文基于MGSM8KInstruct,為了讓模型擁有更多樣化的能力,本文提出了兩種不同的訓(xùn)練方式。
為了使模型更好地理解問題與答案,本文提出的第一種方法是parallel-training,即問題與回答是相同的語言。
為了幫助模型融匯貫通不同的語言,本文提出的第二種方法是cross-training,即問題是英語,回答是別的語言,這可以使模型更好地解決多語言問題。
實(shí)驗(yàn)結(jié)果與分析
下圖是模型在MGSM測(cè)試集上的表現(xiàn),MathOctopusP 和 MathOctopusC 指的是模型訓(xùn)練方式分別為parallel-training和cross-training,xRFT 指的是多語言數(shù)學(xué)推理的拒絕采樣,LLaMA 指的是只在英語GSM8K上訓(xùn)練,RFT 指的是在英語GSM8K上訓(xùn)練后,進(jìn)行拒絕采樣。
下圖是模型在MGSM測(cè)試集上的表現(xiàn)
下圖是模型在MSVAMP測(cè)試集上的表現(xiàn)
根據(jù)實(shí)驗(yàn)結(jié)果,本文有以下發(fā)現(xiàn):
MathOctopus不論是在平行訓(xùn)練語料還是交叉訓(xùn)練語料上訓(xùn)練的結(jié)果都遠(yuǎn)超于其他開源的LLMs。例如,在7B模型上,MathOctopus在MGSM上的準(zhǔn)確率從22.6%提升到41.9%,MathOctopusP-13B在MGSM上的準(zhǔn)確率超過了ChatGPT。
MathOctopusP在in-domain測(cè)試集MGSM中表現(xiàn)效果更好,相反MathOctopusC在out-of-domain測(cè)試集MSVAMP中體現(xiàn)了更強(qiáng)的泛化能力。
多語言拒絕采樣在多語言數(shù)學(xué)推理任務(wù)中,對(duì)MathOctopus帶來的提升有限。
下圖展示了在GSM8K訓(xùn)練集上訓(xùn)練的LLaMA2和用MGSM8KInstruct訓(xùn)練的MathOctopus在GSM8K測(cè)試集和SVAMP上的表現(xiàn)
本文發(fā)現(xiàn),和只在單語言上訓(xùn)練的LLMs相比,MathOctopus在英語數(shù)據(jù)測(cè)試中也取得了更好的效果。為了進(jìn)一步探索在其他語言上是否有相同的現(xiàn)象,本文進(jìn)行了以下實(shí)驗(yàn):
隨機(jī)從訓(xùn)練集中挑選三種語言,分別是西班牙語,中文,泰語。使用它們對(duì)應(yīng)的訓(xùn)練語料分別訓(xùn)練三個(gè)模型,分別命名為 ES-LLaMA,CN-LLaMA,Th-LLaMA。下圖展示了這幾個(gè)模型在他們對(duì)應(yīng)訓(xùn)練語言下的測(cè)試結(jié)果。由圖可見,在單一語言上,MathOctopus的表現(xiàn)仍然超過了單語SFT模型的結(jié)果。這表明,在數(shù)學(xué)推理任務(wù)中,多語言訓(xùn)練比單語言訓(xùn)練有更好的效果。
多語言拒絕采樣
《Scaling relationship on learning mathematical reasoning with large language models》表明,拒絕采樣rejection sampling(RFT)可以大幅提升模型的表現(xiàn)。為了探究在多語言訓(xùn)練的場(chǎng)景下拒絕采樣對(duì)模型的提升效果,本文在得到多語言SFT模型后,采樣模型在MGSM8KInstruct數(shù)據(jù)集上的推理結(jié)果,對(duì)采樣到的推理過程進(jìn)行驗(yàn)證,如果符合要求則將其并入到原本的數(shù)據(jù)集。具體做法如下:
為了采樣到多樣化的推理答案,本文從MathOctopus-7B和MathOctopus-13B中分別采樣25條推理路徑,即每種語言總共采樣50次。
為了確保推理路徑的準(zhǔn)確性,本文提取推理路徑中的所有公式并對(duì)公式進(jìn)行驗(yàn)算,如果答案正確那么就認(rèn)為推理路徑是正確的。
為了確保推理路徑的多樣化,本文采用的策略是,只有當(dāng)前推理路徑和先前的路徑中沒有相同的公式時(shí),才將此路徑放入數(shù)據(jù)集中。
下圖展示了不同的采樣次數(shù)下,每種語言生成的不同推理路徑的個(gè)數(shù)
本文發(fā)現(xiàn),通過多語言拒絕采樣(xRFT)增加的數(shù)據(jù)對(duì)模型的提升效果有限,主要表現(xiàn)在以下幾點(diǎn):
在MGSM測(cè)試集上,多語言拒絕采樣只能提升MathOctopusP模型1%-2%的效果。
在MSVAMP測(cè)試集上,多語言拒絕采樣的提升效果不到1%。
多語言拒絕采樣對(duì)MathOctopusC的提升效果更小,在MGSM數(shù)據(jù)集上的表現(xiàn)反而有所下降。
為了探究xRFT生成的數(shù)據(jù)量對(duì)模型的影響,本文在三個(gè)不同的采樣次數(shù)(10,30,50)下分別探究對(duì)應(yīng)的模型在測(cè)試集上的表現(xiàn)。
下圖是不同采樣次數(shù)下模型在MGSM數(shù)據(jù)集上的表現(xiàn)
下圖是不同采樣次數(shù)下模型在MSVAMP數(shù)據(jù)集上的表現(xiàn)
可以發(fā)現(xiàn),在MGSM測(cè)試集上,當(dāng)拒絕采樣的次數(shù)越多,訓(xùn)練語料越多時(shí),MathOctopusP的表現(xiàn)也略微變好。與之相反,在MSVAMP數(shù)據(jù)集上,當(dāng)拒絕采樣的次數(shù)越多,訓(xùn)練語料越多時(shí),MathOctopusC的表現(xiàn)反而有所下降。
總結(jié)
目前僅研究到33B模型,將來還可以在LLaMA2-70B的基礎(chǔ)上探索更大的MathOctopus模型,除此之外,在這些更大的模型上使用多語言拒絕采樣也是將來的研究點(diǎn)之一。由于MathOctopus只有十種訓(xùn)練語言,更多的訓(xùn)練語言是否會(huì)給模型帶來更好的效果仍然有待研究。
編輯:黃飛
?
評(píng)論
查看更多