本文主要介紹了 FP8 數(shù)據(jù)格式在大型模型訓(xùn)練中的應(yīng)用、挑戰(zhàn)及最佳實(shí)踐,展示了 FP8 在提升訓(xùn)練速度和效率方面的潛力和實(shí)際效果。
FP8 格式
在介紹 FP8 格式之前,我們需要回答一個(gè)問題:為什么需要討論 FP8?從圖中可以看出,近年來大模型所需的算力急劇增長(zhǎng),從 GPT-1 到 GPT-3,再到類似 GPT-4 的 GPT MOE 1.8T,算力需求增長(zhǎng)了數(shù)萬倍。這種增長(zhǎng)速度的背后是硬件算力的提升。訓(xùn)練過程中的一個(gè)重要指標(biāo)是訓(xùn)練時(shí)間。如果訓(xùn)練一個(gè)模型需要半年甚至一年,這在實(shí)際操作中是不可行的,因?yàn)閷?shí)際訓(xùn)練時(shí)間可能是理論值的兩到三倍。因此,算力基礎(chǔ)設(shè)施的提升是大模型迅速發(fā)展的基礎(chǔ)。
從算力角度來看,近年來 GPU 的單卡算力提升了大約一千倍,這包括工藝制程的改進(jìn)、硬件結(jié)構(gòu)的優(yōu)化以及更低的訓(xùn)練精度。隨著 FP8 的引入,其 Tensor Core 算力是 FP16 的兩倍,為探索更大規(guī)模的模型提供了算力支持。
具體來說,F(xiàn)P8 的優(yōu)勢(shì)包括:對(duì)于計(jì)算密集型算子,F(xiàn)P8 的 Tensor Core 相對(duì)于 BF16/FP16 能提供兩倍的算力,從而大大縮短計(jì)算時(shí)間;對(duì)于 Memory Bound 的算子,F(xiàn)P8 格式所需的數(shù)據(jù)量更少,可以節(jié)省緩存量,加快計(jì)算;如果將通信算子中的數(shù)據(jù)類型也替換成 FP8,也可以獲得一定的加速。最后,F(xiàn)P8 訓(xùn)練的模型可以更好地與推理相結(jié)合,因?yàn)槿绻P驮谟?xùn)練時(shí)的精度是 FP8,那么可以更快地部署到推理側(cè),而不需要額外的 PTQ 量化過程。
FP8 數(shù)據(jù)格式包含兩種:E5M2 和 E4M3。E 代表指數(shù)位,M 代表尾數(shù)位。E5M2 包含五個(gè)指數(shù)位和兩個(gè)尾數(shù)位,而 E4M3 包含四個(gè)指數(shù)位和三個(gè)尾數(shù)位。E5M2 由于有更多的指數(shù)位,動(dòng)態(tài)范圍更大;E4M3 有更多的尾數(shù)位,數(shù)值精度更好。這兩種數(shù)據(jù)格式在訓(xùn)練時(shí)都有各自的應(yīng)用場(chǎng)景。
在大模型訓(xùn)練中使用 FP8
FP8 帶來了更快的訓(xùn)練速度,但也對(duì)訓(xùn)練精度提出了挑戰(zhàn)。接下來將介紹在大模型訓(xùn)練中如何兼顧模型精度和訓(xùn)練速度。
在介紹 FP8 之前,我們先回顧一下 16 位精度訓(xùn)練中如何通過混合精度訓(xùn)練來維持精度。這里列出了四種混合精度訓(xùn)練的方法。第一種和最后一種嚴(yán)格來說不算混合精度,因?yàn)榈谝环N是純 FP32 訓(xùn)練,精度最好;最后一種是純 16 位精度訓(xùn)練,速度最快。為了兼顧速度和精度,我們列出了額外的兩種模式:AMP-O1 和 AMP-O2。AMP-O1 相對(duì)于 O0 的不同點(diǎn)在于它會(huì)維護(hù)一份白名單,白名單中的 OP 會(huì)以低精度進(jìn)行計(jì)算,如矩陣乘法和卷積算法,其他算子仍用高精度計(jì)算和存儲(chǔ)。AMP-O2 方案與 AMP-O3 更接近,不同點(diǎn)在于它會(huì)保留一些 unsafe 的 OP,這些 OP 會(huì)以 FP32 精度進(jìn)行存儲(chǔ)和計(jì)算,如 LayerNorm 和 Softmax 等。此外,還會(huì)保留一份 FP32 類型的 Master Weight,因?yàn)樵谀P陀?xùn)練后期,參數(shù)更新通常較慢,梯度值較小,容易出現(xiàn)大數(shù)加小數(shù)的問題,小數(shù)被吃掉,所以需要保留一份 FP32 的 Master Weight。目前 16 位精度訓(xùn)練基本上都是采用 AMP-O2 的混合精度方法來訓(xùn)練的。
FP8 訓(xùn)練可以認(rèn)為是一種 O1+O2 的混合模式。上邊這幅圖包含了前向和反向計(jì)算過程中的一些算子。紅色連接線表示高精度數(shù)據(jù),綠色連接線表示低精度數(shù)據(jù)。無論是前向還是反向,整體訓(xùn)練流程的精度仍然是 BF16 的,但會(huì)維護(hù)一份白名單,白名單中的 OP 以 FP8 精度計(jì)算,如 Linear Layer 中的矩陣乘法和 GELU。FP8 的 FMHA 目前在功能上是支持的,但在實(shí)際訓(xùn)練過程中通常還是用高精度的 FMHA,以保證更好的收斂性。對(duì)于 FP8,我們看到它是以 O1 的模式嵌入 BF16 訓(xùn)練的,BF16 訓(xùn)練本身又是一個(gè) O2 的混合精度方法,所以我們稱它為一種 O1+O2 的混合精度模式。
在訓(xùn)練過程中,前向和反向采用了不同的數(shù)據(jù)精度。前向用 E4M3,因?yàn)榍跋驎r(shí)數(shù)值的動(dòng)態(tài)范圍變化不大,用 E4M3 提供更好的精度;反向的梯度需要更大的動(dòng)態(tài)范圍,所以反向用 E5M2 數(shù)據(jù)格式。這個(gè)流程圖中,藍(lán)色框表示從 BF16 到 FP8 的 Cast 過程。這個(gè)過程不像 FP32 到 BF16 那樣簡(jiǎn)單直接。
接下來將詳細(xì)介紹 Cast 過程是如何實(shí)現(xiàn)的。因?yàn)?FP8 只有 4 位或 5 位指數(shù)位,小于 BF16,所以我們?yōu)榱吮苊庖绯龅那闆r,在 Cast 過程中需要做量化。因?yàn)?FP8 的動(dòng)態(tài)范圍有限,不足以表示模型的所有 Tensor,所以我們需要做 Per-tensor 的 Scaling。這與 FP16 不同,在 FP16 訓(xùn)練中我們做的是全局的量化。圖中形象地表示了 Scaling 的過程。綠色中括號(hào)表示 E4M3 的動(dòng)態(tài)范圍,能表示 2e-6 到 2e-8 次方范圍內(nèi)的值。紫色中括號(hào)表示當(dāng)前 Tensor 的數(shù)值分布。顯然,如果將 Tensor 從 BF16 直接轉(zhuǎn)換到 FP8,會(huì)有相當(dāng)一部分值被直接 Flush 為 0,這些信息被丟棄,造成精度損失。我們的處理方式是給 Tensor 乘上一個(gè)系數(shù),使 Tensor 的所有值向右平移,直到落到 E4M3 的表達(dá)范圍內(nèi)。這樣,BF16 類型的 Tensor 就可以比較安全地 Cast 到 FP8。這個(gè)就是 Per-tensor Scaling 的過程,這個(gè)系數(shù)我們稱為 Scaling factor。
接下來面臨的問題是我們?cè)趺磥泶_定 Scaling factor。一種直接的方式是我們?cè)谟?jì)算得到一個(gè)高精度結(jié)果之后,通過類似于 torch.max() 這樣的一個(gè)算子,找到 Tensor 的最大值,通過這個(gè)最大值來計(jì)算 Scaling factor,然后再量化高精度的 Tensor 為 FP8 的輸出。但這種方法的問題是因?yàn)?Tensor 的 Shape 通常都會(huì)比較大,我們是沒有辦法把這個(gè) Tensor 全部放到 GPU 的片上緩存 Shared Memory 中的。所以這個(gè)過程必須要借助 Global Memory 來進(jìn)行數(shù)據(jù)的中轉(zhuǎn),這就會(huì)帶來額外的一個(gè)訪存開銷。如果我們能提前知道 Scaling factor 的值,量化過程就可以提前到片上緩存 Shared Memory 中去完成。這時(shí)我們不需要等 Find Maximum 的值,F(xiàn)ind Maximum 和 Scale 操作可以同時(shí)在片上緩存完成,從而避免額外的訪存開銷。
這里我們提前獲取 Scaling factor 的方式是 Delayed Scaling Recipe。這種方式的思想是通過當(dāng)前 Tensor 的歷史迭代步信息來估計(jì)當(dāng)前 Tensor 的最大值。
具體的,我們會(huì)建立一個(gè) Amax History Buffer,記錄一個(gè) Tensor 在歷史迭代步中的最大值。當(dāng)需要當(dāng)前 Tensor 的 Scaling factor 時(shí),會(huì)從 History Buffer 中選出一個(gè)最大值,作為當(dāng)前 Tensor 最大值的估計(jì)。有了最大值之后,可以計(jì)算 Scaling factor,從而對(duì)當(dāng)前 Tensor 進(jìn)行 FP8 量化。
另一方面,當(dāng)要輸出 FP8 Tensor 時(shí),我們會(huì)統(tǒng)計(jì)當(dāng)前 Tensor 真實(shí)的最大值。將真實(shí)的最大值 New Amax 追加到 History Buffer 中。因?yàn)?History Buffer 是有長(zhǎng)度的,所以當(dāng)新的 Amax 追加到 History Buffer 末尾后,最前面的信息會(huì)被丟棄掉。這樣,可以一直用最近的歷史信息來估計(jì)當(dāng)前 Tensor 的最大值。
接下來我們把 Delayed Scaling Recipe 過程放到一個(gè)真實(shí)的場(chǎng)景來介紹它是如何工作的:
圖中左邊部分是一個(gè) Activation OP,輸入是一個(gè)高精度的 Activation,輸出是一個(gè) FP8 的 Tensor。右邊是一個(gè) Tensor Core 的 OP,輸入是 FP8 的 Tensor,輸出是一個(gè)高精度的值。這兩個(gè) OP 可以類比到 Transformer Layer 里面的 Layer Norm 和 FC1。對(duì)于 Activation OP 來說,它的輸入和計(jì)算過程都是高精度的。當(dāng)我們得到一個(gè)高精度的結(jié)果之后,我們會(huì)做兩件事:
第一件事,會(huì)統(tǒng)計(jì)當(dāng)前 Tensor 的一個(gè)最大值,并將其追加到 History Buffer 中。同時(shí)另外一件事,我們會(huì)從 History Buffer 中選出一個(gè)最大值,作為當(dāng)前 Tensor 最大值的估計(jì),并計(jì)算出 Scaling factor,繼而將當(dāng)前的 FP16 類型的 Tensor 量化到 FP8 進(jìn)行輸出。對(duì)于 Tensor Core 的 OP 來說,它的 Activation 的輸入已經(jīng)是 FP8 的 Tensor,權(quán)重也是用 Delayed Scaling Recipe 的方式來將其量化到 FP8。
這樣,我們將 GEMM 的所有的輸入都轉(zhuǎn)換成了 FP8,就可以用 FP8 的 Tensor Core 來進(jìn)行計(jì)算,計(jì)算的結(jié)果是一個(gè)高精度的結(jié)果。
在輸出最終的結(jié)果之前,我們需要一個(gè)反量化的過程。這是因?yàn)?GEMM 的輸入對(duì) Activation 和 Weight 都做了量化,所以它的值都被相應(yīng)的左移或者右移。這就是 Megatron Core 框架里現(xiàn)在集成 FP8 訓(xùn)練的一個(gè)方式。
FP8 訓(xùn)練性能
接下來會(huì)介紹 FP8 訓(xùn)練的性能結(jié)果:
這里使用的軟件鏡像是 NeMo Framework v24.01。我們可以看到在 Llama 模型上,F(xiàn)P8 訓(xùn)練在訓(xùn)練吞吐上的加速比在 33%-45% 范圍內(nèi)。通過觀察 GPU 上的 Nsight System Report,發(fā)現(xiàn) FP8 訓(xùn)練的 Timeline 里面 Kernel 之間很容易出現(xiàn)氣泡。這個(gè)問題出現(xiàn)的原因是我們將代碼中最耗時(shí)的矩陣乘 Kernel 換成了 FP8,雖然它的計(jì)算時(shí)間減半了,但因?yàn)?FP8 的 Delayed Scaling Recipe 引入了一些和 Amax 以及 Scaling factor 相關(guān)的操作,引入了額外的 Kernel,所以導(dǎo)致 Host 端 launch Kernel 的 overhead 變大。此消彼長(zhǎng),使得 Kernel launch 跟不上 Kernel 計(jì)算的速度,產(chǎn)生 Launch bound 的問題。
如何解決 Launch bound 問題,我們將在后續(xù)的內(nèi)容中介紹。
接下來是在另一張 GPU 上的測(cè)試結(jié)果,同樣也是在 Llama 模型上進(jìn)行預(yù)訓(xùn)練,鏡像是 NeMo Framework v24.01??梢杂^察到 FP8 相對(duì)于 BF6 的加速比大約是 60%-73%。
最后是 MOE 模型上的一些 Benchmark 結(jié)果,模型是 Mixtral 8x7B,軟件是基于 Megatron-Core v0.7 開發(fā)的 FP8 版本。在這個(gè)版本上 ,F(xiàn)P8 的加速比達(dá)到了 63%。
再分享下性能上的最佳實(shí)踐:
FP8 這部分的性能問題并不多,比較常見的是 Kernel 之間的氣泡問題。為了解決這個(gè)問題,首先我們可以從減小 Host 端 Kernel launch overhead 的角度出發(fā),盡可能地將這些 Kernel fuse 起來,減小 Kernel launch 的次數(shù)。比如我們可以將 Amax 以及 Scaling factor 相關(guān)的 Kerner fuse 起來,也可以把 Rotary Potential Embedding 這部分的 Kernel fuse 起來,以及 Swiglu 的 Fusion。除此之外,還可以利用 CUDA Graph 來將這些 Kernel 合并為一次 Graph 的 launch,來減小 Kernel launch 的開銷。另外,我們?cè)诖a中要盡量避免 Host 端與 Device 端同步,這個(gè)同步會(huì)強(qiáng)制阻塞 Host 端操作,從而加重 Launch bound 問題。我們?cè)谄綍r(shí)寫代碼的過程中用 Torch 的 OP 可能就會(huì)不經(jīng)意引入同步,我們?cè)谇跋蛴?jì)算完成之后可能會(huì)對(duì) Loss 進(jìn)行一些處理,比如檢查一些 NaN 之類的,就會(huì)引入 Host 端與 Device 端同步。
最后是關(guān)于超參調(diào)整的建議:在顯存允許的情況下,我們會(huì)推薦嘗試用更多的 PP 而不是 TP,因?yàn)?PP 的通信粒度會(huì)更粗一些,引入的 Kernel 會(huì)更少。另一方面可以調(diào)整訓(xùn)練的超參使得梯度累加的次數(shù)變少,當(dāng) GPU 的數(shù)量和問題規(guī)模比如 Global Batch Size 與 Synchronize 不變的情況下,一個(gè) Global Step 的計(jì)算量是恒定的,當(dāng)梯度累加次數(shù)越少時(shí),意味著每次梯度累加所分到的計(jì)算量就越大。同時(shí),每次梯度累加,Host 端 Kernel launch 的開銷是恒定的,所以當(dāng) Device 端的計(jì)算與 Host 端 launch 的開銷的比例達(dá)到一定程度,Launch bound 的問題就可以被減輕甚至直接消除掉。
這里關(guān)于超參調(diào)整的建議是從減小 Kernel 之間氣泡的角度出發(fā)的,實(shí)際在訓(xùn)練過程中做超參調(diào)整時(shí),要考慮的因素要更多,要做全盤的考慮。
FP8 訓(xùn)練過程中的收斂性
接下來介紹 FP8 收斂性相關(guān)的信息,收斂性的結(jié)果我們按照大語言模型訓(xùn)練的不同階段分別進(jìn)行介紹:
首先是預(yù)訓(xùn)練階段。我們訓(xùn)練了一個(gè)模型結(jié)構(gòu)和 Llama2-7B 相同的模型,數(shù)據(jù)集采用的是開源的 RedPajama,處理之后的數(shù)據(jù)量是 1.4T Token,超參用的是 TP2_PP1_DP128。
這里給出了從 300B 到 1.4T Token 之間的 Loss 曲線,以及訓(xùn)練末期 FP8 和 BF16 Loss 的差值。
通過這兩幅圖可以看到 FP8 和 BF16 的 Loss 曲線是很接近的,二者之間 Loss 差值在10-3量級(jí)。
接下來是下游任務(wù)的結(jié)果,這里選取了 MMLU 和 LM-Harness 等幾個(gè)任務(wù),對(duì)訓(xùn)練好的模型進(jìn)行評(píng)測(cè)。
結(jié)果顯示 FP8 和 BF16 訓(xùn)練的模型在不同任務(wù)上的得分會(huì)有高低,但總體相差不大。所以,對(duì)于預(yù)訓(xùn)練來說,F(xiàn)P8 訓(xùn)練的模型無論是 Loss 曲線還是下游任務(wù)都可以和 BF16 匹配的很好。
接下來的場(chǎng)景是模型的增量訓(xùn)練,比如從外部獲取到一個(gè)預(yù)訓(xùn)練模型,我們需要給它注入新的知識(shí),對(duì) LLaMa 系列模型添加中文知識(shí)等等。我們選擇了從 HuggingFace 下載的 Llama2-7B 模型,并使用 Open-Web-Math 數(shù)據(jù)集進(jìn)行訓(xùn)練。訓(xùn)練配置包括 TP1_PP2_DP4,我們提供了 FP8 和 BF16 的 Loss 曲線及其差值,可以看到同樣它們之間的 Loss 差值也非常小。
在下游任務(wù)評(píng)測(cè)中,我們選擇了 GSM8K 任務(wù),并使用 OpenCompass 工具進(jìn)行評(píng)測(cè)。我們每隔 500 步對(duì)訓(xùn)練過程進(jìn)行評(píng)分,以追蹤下游任務(wù)的變化。結(jié)果顯示,F(xiàn)P8 和 BF16 訓(xùn)練的下游任務(wù)得分總體走勢(shì)一致,且都顯示出一定的波動(dòng)性。兩次訓(xùn)練中,F(xiàn)P8 和 BF16 的最高得分相近,表明 FP8 在增量訓(xùn)練場(chǎng)景下的表現(xiàn)與 BF16 類似。
在 SFT(Supervised Fine-Tuning)結(jié)果中,我們選擇了 Llama2 系列的三個(gè)模型,并使用開源的三個(gè)數(shù)據(jù)集混合進(jìn)行訓(xùn)練。評(píng)測(cè)任務(wù)為 MT-Bench。從 Loss 曲線和下游任務(wù)結(jié)果來看,這三個(gè)模型的 FP8 Loss 曲線和下游任務(wù)得分都能與 BF16 對(duì)齊,證明了 FP8 在 SFT 訓(xùn)練中的可行性。
在實(shí)際訓(xùn)練過程中,我們沒有 BF16 baseline,因此需要通過其他方法判斷 FP8 是否處于正確的收斂路徑。一種方法是參考 01-AI 的做法,定期用 BF16 跑一定步數(shù)(如 100 到 200 步),作為 BF16 reference。通過比較 FP8 和 BF16 reference 的 Loss 曲線及下游任務(wù)得分,如果它們接近,則認(rèn)為 FP8 訓(xùn)練正確。如果差距較大,則用 BF16 替代 FP8 完成該期間的訓(xùn)練。另一種方法是不一定需要 BF16 baseline,只要 FP8 訓(xùn)練的 Loss 曲線持續(xù)下降,下游任務(wù)得分持續(xù)提升,就認(rèn)為 FP8 處于正確的收斂路徑。
在對(duì)比 FP8 和 BF16 的下游任務(wù)得分時(shí),應(yīng)正確看待二者在下游任務(wù)上的得分差異。
Meta 最近發(fā)表的論文選取了上百個(gè)模型,這些模型除了初始化的隨機(jī)數(shù)種子不同外,其他配置和環(huán)境相同。論文統(tǒng)計(jì)了這些模型在不同下游任務(wù) Benchmark 上的統(tǒng)計(jì)值,包括均值、方差、95% 的置信區(qū)間及單調(diào)性等。從方差及 95% 置信區(qū)間來看,不同隨機(jī)數(shù)種子對(duì)下游任務(wù)得分影響較大。例如,AGIEval 的平均得分是 23.44,95% 置信區(qū)間是 1.63,意味著模型得分在 21.8 到 25 之間都是合理的。GSM8K 的平均得分是 4.1,置信區(qū)間是 0.87,意味著模型得分在 3.2 到 4.9 之間都是合理的。盡管 FP8 和 BF16 之間的區(qū)別與隨機(jī)數(shù)種子的影響不同,但仍有助于設(shè)定 FP8 訓(xùn)練的下游任務(wù)預(yù)期。當(dāng) FP8 的下游任務(wù)得分落在 BF16 的 95% 置信區(qū)間范圍內(nèi),應(yīng)認(rèn)定 FP8 訓(xùn)練的模型與 BF16 匹配。
在收斂性過程中遇到問題的 Debugging Practices,可以將問題分為幾類。
首先是非 FP8 的問題,嘗試用 BF16 進(jìn)行 Resume training,如果損失曲線與 FP8 一致,問題可能與 FP8 無關(guān),而是與數(shù)據(jù)集或其他模塊相關(guān)。
第二類是軟件相關(guān)的 Bug,可以嘗試最新軟件?;蚯袚Q到 Transformer Engine/Megatron Core 的最近的幾個(gè)穩(wěn)定版進(jìn)行調(diào)試。
第三類是 Scaling factor 的問題,可以嘗試更保守的 Recipe,如 just-in-time 的 Scaling factor,如 Current scaling,來消除 Scaling factor 引起的誤差。另外,嘗試用 BF16 替代 FP8 來定位是哪個(gè) GEMM 引起的問題。
最后是 Evaluation 過程中的問題,F(xiàn)P8 訓(xùn)練的模型用 BF16 推理可能得到偏低分?jǐn)?shù)。因?yàn)?FP8 數(shù)值格式的原因,一些精度會(huì)被丟棄,用 BF16 推理時(shí)這些信息會(huì)被重新引入,反而成為噪聲。對(duì)于精度要求高的任務(wù),如 MMLU,會(huì)產(chǎn)生較大影響。推薦大家在用 BF16 推理遇到精度問題時(shí)不妨試試用 FP8 進(jìn)行推理。如果訓(xùn)練時(shí)的混合精度是 FP8 加 BF16,推理時(shí)則不能轉(zhuǎn)為 FP16 精度,因?yàn)?FP16 動(dòng)態(tài)范圍有限,又沒有應(yīng)用 Per-tensor scaling,容易出現(xiàn)上溢問題,所以需要推理精度仍然保持為 BF16。最后推薦用多點(diǎn)采樣的方式避免 Evaluation 過程中的噪聲問題。
展望
最后是展望和思考:
除了 Delayed Scaling 外,我們還進(jìn)行了其他實(shí)驗(yàn),如 Current Scaling 和 Block Scaling。
Current Scaling 使用當(dāng)前 Tensor 的最大值計(jì)算 Scaling factor,再將其 Cast 到 FP8,使用 Just-in-time 的 Scaling factor,對(duì)當(dāng)前 Tensor 有更好的表示,不會(huì)出現(xiàn)上溢情況。
除此之外,我們還嘗試了更細(xì)粒度的 Scaling Recipe,如 Block Scaling 和 Per-channel Scaling,以每個(gè) Block 或每行為一組,計(jì)算各自的 Amax 及 Scaling factor,保證不出現(xiàn)上溢情況下,減小下溢比例。因?yàn)?FP8 Tensor 附帶多組 Scaling factor,常規(guī) FP8 GEMM 不支持這種情況,需要對(duì) FP8 GEMM 進(jìn)行改造,定制化開發(fā)來支持這種情況。
最后回顧下低精度訓(xùn)練的發(fā)展過程與思考:
首先大語言模型經(jīng)歷了從 32 位精度訓(xùn)練到 16 位精度訓(xùn)練的轉(zhuǎn)變,遇到了訓(xùn)練不穩(wěn)定性的問題。各個(gè)公司提出了不同解決方案,如 Google 通過跳過一定 Data batch 來消除 Loss spike,Meta 則通過修改 Learning Rate、Weight Decay 及模型結(jié)構(gòu)等來優(yōu)化。最終發(fā)現(xiàn)將數(shù)據(jù)類型從 FP16 換成 BF16 可以很好地解決這些問題,因?yàn)?BF16 具有更大的動(dòng)態(tài)范圍,被廣泛應(yīng)用于各大公司的大模型訓(xùn)練上。
現(xiàn)在正在經(jīng)歷 16 位精度到 8 比特精度轉(zhuǎn)化的過程,嘗試更穩(wěn)定的 Scaling Recipe,如 Current Scaling 或 Block Scaling,或通過修改模型結(jié)構(gòu)提高訓(xùn)練穩(wěn)定性。
總結(jié)
我們選擇更低精度的出發(fā)點(diǎn)是為了加快訓(xùn)練速度,更快的訓(xùn)練速度意味著可以用更多數(shù)據(jù)訓(xùn)練更大模型,根據(jù) Scaling Law 得到更好模型效果,或者在更短的時(shí)間內(nèi)訓(xùn)練出性能相當(dāng)?shù)哪P?。另一方面,低精度?xùn)練格式天然對(duì)模型訓(xùn)練效果有影響,因此需要找到方法使 FP8 訓(xùn)練在絕大多數(shù) Case 下穩(wěn)定收斂,達(dá)到與高精度訓(xùn)練相近的模型效果。
現(xiàn)在的 Delayed Scaling Recipe 在絕大多數(shù)場(chǎng)景下都可以很好的 Work,但仍有改進(jìn)空間,無論是使用更魯棒的 Scaling Recipe,還是針對(duì) FP8 訓(xùn)練的特點(diǎn)調(diào)整模型結(jié)構(gòu),NVIDIA 技術(shù)團(tuán)隊(duì)都在持續(xù)探索。無論如何,低精度訓(xùn)練是大模型訓(xùn)練的趨勢(shì),NVIDIA 技術(shù)團(tuán)隊(duì)將持續(xù)探索更好的 Scaling Recipe,讓大家更好地使用 FP8 訓(xùn)練,相應(yīng)進(jìn)展會(huì)不定期的分享給大家。
作者簡(jiǎn)介
劉宏斌
NVIDIA 加速計(jì)算專家,2020 年加入 NVIDIA DevTech 團(tuán)隊(duì),專注于 GPU 上深度學(xué)習(xí)模型的優(yōu)化加速。目前主要負(fù)責(zé)生成式人工智能模型的訓(xùn)練階段的加速優(yōu)化。
-
NVIDIA
+關(guān)注
關(guān)注
14文章
4986瀏覽量
103046 -
數(shù)據(jù)格式
+關(guān)注
關(guān)注
0文章
30瀏覽量
8893 -
模型
+關(guān)注
關(guān)注
1文章
3243瀏覽量
48836 -
算力
+關(guān)注
關(guān)注
1文章
977瀏覽量
14809
原文標(biāo)題:FP8 訓(xùn)練的挑戰(zhàn)及最佳實(shí)踐
文章出處:【微信號(hào):NVIDIA-Enterprise,微信公眾號(hào):NVIDIA英偉達(dá)企業(yè)解決方案】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論