該論文介紹了一種名為 ReMax 的新算法,專為基于人類反饋的強化學(xué)習(xí)(RLHF)而設(shè)計。ReMax 在計算效率(約減少 50% 的 GPU 內(nèi)存和 2 倍的訓(xùn)練速度提升)和實現(xiàn)簡易性(6 行代碼)上超越了最常用的算法 PPO,且性能沒有損失。
論文鏈接:https://arxiv.org/abs/2310.10505
作者:李子牛,許天,張雨舜,俞揚,孫若愚,羅智泉
機構(gòu):香港中文大學(xué)(深圳),深圳市大數(shù)據(jù)研究院,南京大學(xué),南棲仙策
開源代碼:https://github.com/liziniu/ReMax
如未額外說明,所有圖片來自于論文。 背景 今年,以 ChatGPT 為首的大語言模型(Large Language Models, LLMs) 在各個方面大放光彩,由此引發(fā)了學(xué)術(shù)界和商業(yè)界對 GPU 等計算資源的需求劇增。
比如監(jiān)督訓(xùn)練地調(diào)優(yōu) (supervised fine-tuning, SFT) 一個 Llama2-7B 的模型,需要消耗 80GB 以上的內(nèi)存。而這往往不夠,為了和人類對齊(alignment),大語言模型還要經(jīng)過 RLHF (reinforcement learning from human feedback) 的訓(xùn)練。RLHF 的 GPU 消耗往往是 SFT 的 2 倍以上,訓(xùn)練時間更能達到 6 倍以上。 近日,美國政府宣布限制英偉達 GPU 產(chǎn)品 H100, H800等進入中國市場。這項條款無疑為中國發(fā)展大語言模型(LLMs) 和人工智能增添了很多阻力。減小 RLHF 的訓(xùn)練成本(GPU 消耗和訓(xùn)練時間)對 LLMs 的發(fā)展非常重要。 動機 RLHF 包含三個階段: 1. 監(jiān)督式地調(diào)優(yōu)(Supervised Fine-Tuning, SFT)。 2. 從對比數(shù)據(jù)中學(xué)習(xí)獎勵模型(reward model)。 3. 利用強化學(xué)習(xí)(RL)算法來最大化獎勵。
圖片來源自 InstructGPT 論文 我們發(fā)現(xiàn) RLHF 的主要計算開銷來源于第三階段(獎勵最大化)。這一點可以從 DeepSpeed-Chat 的報告里看到,第三階段的訓(xùn)練時間是前兩個階段時間總和的 4 倍以上。而且,根據(jù)我們的經(jīng)驗,第三階段的 GPU 消耗是前兩階段的 2 倍以上。
圖片來自 DeepSpeed-Chat 技術(shù)報告 目前 RLHF 第 3 階段的主要計算瓶頸是什么? 我們發(fā)現(xiàn)該階段的計算瓶頸主要來源用來目前使用的 RL 算法:PPO 算法。PPO 算法是用來解決普適 RL 問題的最流行的算法之一,有非常多成功的案例。我們在這里省略 PPO 的技術(shù)細節(jié),著重介紹 PPO 的一個關(guān)鍵組件:價值模型 (The value model)。價值模型是一個需要被訓(xùn)練的神經(jīng)網(wǎng)絡(luò),能夠有效地估計給定策略的預(yù)期長期回報。盡管價值模型為 PPO 帶來了良好的性能,但它在 RLHF 任務(wù)中也引入了沉重的計算開銷。例如,為了更好地與人類偏好對齊,PPO 中的價值模型通常與 LLM 大小相似,這使存儲需求翻了一番。此外,價值模型的訓(xùn)練需要存儲其梯度、激活和優(yōu)化器狀態(tài),這進一步增加了近 4 倍的 GPU 存儲需求??偨Y(jié)來說,PPO 和它的價值模型(以及其訓(xùn)練相關(guān)部分)已成為 RLHF 獎勵最大化階段的主要計算障礙。
相比 PPO,ReMax 是輕量級算法 思路是否有可能找到比 PPO 更適配 RLHF 的算法? 我們得出的答案是肯定的。這是因為 PPO 和價值模型是為通用 RL 問題設(shè)計的,而不是針對像 RLHF 這樣的特定問題(RLHF 只是 RL 問題中的一個子類)。有趣的是,我們發(fā)現(xiàn) RLHF 具有三個在 PPO 中未使用的重要結(jié)構(gòu): 1. 快速模擬(fast simulation): 軌跡(即 LLM 中的整個響應(yīng))可以在很短的時間內(nèi)迅速執(zhí)行(小于 1s),幾乎沒有時間開銷。 2. 確定性轉(zhuǎn)移(deterministic transitions):上下文確定性依賴于過去的標(biāo)記和當(dāng)前生成的標(biāo)記。 3. 軌跡級獎勵(trajectory-level rewards):獎勵模型只在響應(yīng)完成時提供一個獎賞值。 通過這三個觀察,我們不難發(fā)現(xiàn) value model 在 RLHF 的問題中是 “冗余” 的。這是因為 value model 設(shè)計的初衷是為了隨機環(huán)境下的樣本效率和慢仿真環(huán)境的計算效率。然而這在 RLHF 中是不需要的。
ReMax 是針對 RLHF 設(shè)計的算法,PPO 則是為通用 RL 設(shè)計的算法 方法ReMax ReMax 算法基于一個古老的策略梯度算法 REINFORCE,REINFORCE 使用的策略梯度估計器如下圖所示:
REINFORCE 梯度估計器
REINFORCE可以在計算層面利用好RLHF任務(wù)的三個性質(zhì),因為REINFORCE直接利用一個響應(yīng)的獎勵來進行優(yōu)化,不需要像一般的RL算法一樣需要知道中間步驟的獎勵和值函數(shù)。然而,由于策略的隨機性, REINFORCE梯度估計器存在高方差問題(在Richard Sutton的RL書里有指出),這一問題會影響模型訓(xùn)練的有效性,因此REINFORCE在RLHF任務(wù)中的效果較差,見下面兩張圖片。
REINFORCE 的計算代價小,但性能差
REINFORCE 的(隨機)梯度值遠遠大于 ReMax 為解決這一問題,ReMax 使用貪婪生成的回答(greedy response)的獎勵作為基準(zhǔn)值(baseline value)來構(gòu)建梯度估計器,具體公式如下:
ReMax 梯度估計器 注意到,貪婪回復(fù)的獎勵可以看作為期望獎勵的好的近似。在理想情形下(),對于隨機變量,
,因此我們能夠期望估計器具有更小的方差。 ? ? 下圖展示了 ReMax 的算法流程,紅色方框中的是核心算法改變。 ?
ReMax 算法流程 理論保證 我們證明了 ReMax 使用的梯度估計器仍然是真實策略梯度的一個無偏估計器。 詳細理論介紹見論文。 算法優(yōu)點
ReMax 的核心部分可以用 6 行代碼來實現(xiàn)。相比之下,PPO 要額外引入重要性采樣(importance sampling),廣義優(yōu)勢估計(generalized advantage estimation,GAE),價值模型學(xué)習(xí)等額外模塊。
ReMax 的超參數(shù)很少。相比之下,PPO 有額外的超參數(shù),例如重要性采樣剪切閾值(importance sampling clipping ratio)、GAE 系數(shù)、價值模型學(xué)習(xí)率,離策略訓(xùn)練輪次(off-policy training epoch)等,這些超參數(shù)都需要花大量時間去調(diào)優(yōu)。
ReMax 能理論上節(jié)省約 50% 內(nèi)存。相比于 PPO,ReMax 成功移除了所有和價值模型相關(guān)的部件,大大減小了內(nèi)存開銷。通過計算,我們發(fā)現(xiàn)相比于 PPO,ReMax 能節(jié)省約 50% 內(nèi)存。
效果有效性
ReMax 可以像 PPO 一樣有效地最大化獎勵
在 OPT-1.3B 上,ReMax 可以有效地最大化獎勵
在 OPT-1.3B 上,ReMax 的訓(xùn)練非常穩(wěn)定
在 GPT-4 評估下(LIMA Test Questions),ReMax 得到的策略比 SFT 和 PPO 會更好
GPT4 打分顯示 ReMax 得到的模型會更好 高效性
ReMax 能節(jié)省近 50% 的 GPU 內(nèi)存。ReMax 移除掉了價值模型和它的訓(xùn)練部分(梯度,優(yōu)化器,激活值),從而極大節(jié)省了 GPU 內(nèi)存需求??紤] Llama2-7B,PPO 無法在 8xA100-40GB 的機器上跑起來,但是 ReMax 可以。
在 Llama2-7B 上,ReMax 可以節(jié)省近 50% 的 GPU 內(nèi)存
ReMax 能加快 2 倍的訓(xùn)練速度。在每一輪中,ReMax 調(diào)用 2 次生成(generation),1 次反向傳播(backpropagation);而 PPO 使用 1 次生成,2 次反向傳播。對于大模型而言,生成會比反向傳播的時間小,從而 ReMax 可以實現(xiàn)理論上接近 2 倍的訓(xùn)練加速。
通用性 除了 RLHF 任務(wù),作為一個 RL 算法,ReMax 對于經(jīng)典的 NLP 任務(wù)也適用。本文考慮了在 GPT-2 上進行一個電影評論續(xù)寫的任務(wù),這里獎勵模型不是從對比數(shù)據(jù)學(xué)習(xí)的。實驗觀測到,ReMax 可以實現(xiàn) 2.2 倍的訓(xùn)練加速和 60% 的 GPU 內(nèi)存節(jié)省。
在經(jīng)典的 NLP 任務(wù)(文本續(xù)寫)上,ReMax 相比 PPO 實現(xiàn)了 2.2 倍加速 總結(jié) 最后,我們從實驗中簡要總結(jié)了 ReMax 相對于 PPO 的主要優(yōu)勢。
更簡單的實現(xiàn): ReMax 的核心部分 6 行代碼即可實現(xiàn)。這與 PPO 中的眾多復(fù)雜的代碼構(gòu)建塊形成鮮明對比。
更少的內(nèi)存開銷:由于移除了價值模型及其全部訓(xùn)練組件,相比 PPO,ReMax 節(jié)省了大約 50% 的 GPU 內(nèi)存。
更少的超參數(shù): ReMax 成功移除了所有和價值模型訓(xùn)練相關(guān)的超參數(shù),其中包括:GAE 系數(shù)、價值模型學(xué)習(xí)率、重要性采樣時期、小批量(mini-batch)大小。這些超參數(shù)往往對問題敏感且難以調(diào)整。我們相信 ReMax 對 RLHF 研究者更加友好。
更快的訓(xùn)練速度:在 GPT2(137M)的實驗中,我們觀察到 ReMax 在真實運行時間方面相比于 PPO 有 2.2 倍的加速。加速來自 ReMax 每次迭代中較少的計算開銷。通過我們的計算,該加速優(yōu)勢在更大的模型上也能維持(假設(shè)在足夠大的內(nèi)存下 PPO 可以被成功部署)。
優(yōu)異的性能:如前所示,ReMax在中等規(guī)模實驗中與PPO實現(xiàn)了相當(dāng)?shù)男阅?,并且有時甚至超越它(可能是由于 ReMax 更容易找到合適的超參數(shù))。我們推測這種良好的性能可以拓展到更大規(guī)模的模型中。
-
語言模型
+關(guān)注
關(guān)注
0文章
526瀏覽量
10277 -
ChatGPT
+關(guān)注
關(guān)注
29文章
1562瀏覽量
7722 -
大模型
+關(guān)注
關(guān)注
2文章
2465瀏覽量
2761
原文標(biāo)題:在RTX 4090被限制的時代下,讓大模型使用RLHF更高效的方法來了
文章出處:【微信號:zenRRan,微信公眾號:深度學(xué)習(xí)自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論