學(xué)習(xí)率(learning rate)是調(diào)整深度神經(jīng)網(wǎng)絡(luò)最重要的超參數(shù)之一,本文作者Pavel Surmenok描述了一個(gè)簡單而有效的辦法來幫助你找尋合理的學(xué)習(xí)率。
我正在舊金山大學(xué)的 fast.ai 深度學(xué)習(xí)課程中學(xué)習(xí)相關(guān)知識。目前這門課程還沒有對公眾開放,但是現(xiàn)在網(wǎng)絡(luò)上有去年的版本。
學(xué)習(xí)率如何影響訓(xùn)練?
深度學(xué)習(xí)模型通常由隨機(jī)梯度下降算法進(jìn)行訓(xùn)練。隨機(jī)梯度下降算法有許多變形:例如 Adam、RMSProp、Adagrad 等等。這些算法都需要你設(shè)置學(xué)習(xí)率。學(xué)習(xí)率決定了在一個(gè)小批量(mini-batch)中權(quán)重在梯度方向要移動(dòng)多遠(yuǎn)。
如果學(xué)習(xí)率很低,訓(xùn)練會(huì)變得更加可靠,但是優(yōu)化會(huì)耗費(fèi)較長的時(shí)間,因?yàn)槌驌p失函數(shù)最小值的每個(gè)步長很小。
如果學(xué)習(xí)率很高,訓(xùn)練可能根本不會(huì)收斂,甚至?xí)l(fā)散。權(quán)重的改變量可能非常大,使得優(yōu)化越過最小值,使得損失函數(shù)變得更糟。
訓(xùn)練應(yīng)當(dāng)從相對較大的學(xué)習(xí)率開始。這是因?yàn)樵陂_始時(shí),初始的隨機(jī)權(quán)重遠(yuǎn)離最優(yōu)值。在訓(xùn)練過程中,學(xué)習(xí)率應(yīng)當(dāng)下降,以允許細(xì)粒度的權(quán)重更新。
有很多方式可以為學(xué)習(xí)率設(shè)置初始值。一個(gè)簡單的方案就是嘗試一些不同的值,看看哪個(gè)值能夠讓損失函數(shù)最優(yōu),且不損失訓(xùn)練速度。我們可以從 0.1 這樣的值開始,然后再指數(shù)下降學(xué)習(xí)率,比如 0.01,0.001 等等。當(dāng)我們以一個(gè)很大的學(xué)習(xí)率開始訓(xùn)練時(shí),在起初的幾次迭代訓(xùn)練過程中損失函數(shù)可能不會(huì)改善,甚至?xí)龃?。?dāng)我們以一個(gè)較小的學(xué)習(xí)率進(jìn)行訓(xùn)練時(shí),損失函數(shù)的值會(huì)在最初的幾次迭代中從某一時(shí)刻開始下降。這個(gè)學(xué)習(xí)率就是我們能用的最大值,任何更大的值都不能讓訓(xùn)練收斂。不過,這個(gè)初始學(xué)習(xí)率也過大了:它不足以訓(xùn)練多個(gè) epoch,因?yàn)殡S著時(shí)間的推移網(wǎng)絡(luò)將需要更加細(xì)粒度的權(quán)重更新。因此,開始訓(xùn)練的合理學(xué)習(xí)率可能需要降低 1-2 個(gè)數(shù)量級。
一定有更好的方法
Leslie N. Smith?在 2015 年的論文「Cyclical Learning Rates for Training Neural Networks」的第 3.3 節(jié),描述了一種為神經(jīng)網(wǎng)絡(luò)選擇一系列學(xué)習(xí)率的強(qiáng)大方法。
訣竅就是從一個(gè)低學(xué)習(xí)率開始訓(xùn)練網(wǎng)絡(luò),并在每個(gè)批次中指數(shù)提高學(xué)習(xí)率。
在每個(gè)小批量處理后提升學(xué)習(xí)率
為每批樣本記錄學(xué)習(xí)率和訓(xùn)練損失。然后,根據(jù)損失和學(xué)習(xí)率畫圖。典型情況如下:
一開始,損失下降,然后訓(xùn)練過程開始發(fā)散
首先,學(xué)習(xí)率較低,損失函數(shù)值緩慢改善,然后訓(xùn)練加速,直到學(xué)習(xí)速度變得過高導(dǎo)致?lián)p失函數(shù)值增加:訓(xùn)練過程發(fā)散。
我們需要在圖中找到一個(gè)損失函數(shù)值降低得最快的點(diǎn)。在這個(gè)例子中,當(dāng)學(xué)習(xí)率在 0.001 和 0.01 之間,損失函數(shù)快速下降。
另一個(gè)方式是觀察計(jì)算損失函數(shù)變化率(也就是損失函數(shù)關(guān)于迭代次數(shù)的導(dǎo)數(shù)),然后以學(xué)習(xí)率為 x 軸,以變化率為 y 軸畫圖。
損失函數(shù)的變化率
上圖看起來噪聲太大,讓我們使用簡單移動(dòng)平均線(SMA)來做平緩化處理。
使用 SMA 平緩化處理后的損失函數(shù)變化率
這樣看起來就好多了。在這個(gè)圖中,我們需要找到最小值位置??雌饋恚咏趯W(xué)習(xí)率為 0.01 這個(gè)位置。
實(shí)現(xiàn)代碼教程
Jeremy Howard 和他在 USF 數(shù)據(jù)研究所的團(tuán)隊(duì)開發(fā)了 fast.ai。這是一個(gè)基于 PyTorch 的高級抽象的深度學(xué)習(xí)庫。fast.ai 是一個(gè)簡單而強(qiáng)大的工具集,可以用于訓(xùn)練最先進(jìn)的深度學(xué)習(xí)模型。Jeremy 在他最新的深度學(xué)習(xí)課程()中使用了這個(gè)庫。
fast.ai 提供了學(xué)習(xí)率搜索器的一個(gè)實(shí)現(xiàn)。你只需要寫幾行代碼就能繪制模型的損失函數(shù)-學(xué)習(xí)率的圖像(來自 GitHub:plot_loss.py):
# learn is an instance of Learnerclass or one of derived classes like ConvLearner
learn.lr_find()
learn.sched.plot_lr()
庫中并沒有提供代碼繪制損失函數(shù)變化率的圖像,但計(jì)算起來非常簡單(plot_change_loss.py):
def plot_loss_change(sched, sma=1, n_skip=20, y_lim=(-0.01,0.01)):
"""
Plots rate of change of the loss function.
sched - learning rate scheduler, an instance of LR_Finder class.
sma - number of batches for simple moving average to smooth out the curve.
n_skip - number of batches to skip on the left.
y_lim - limits for the y axis.
"""
derivatives = [0] * (sma + 1)
for i in range(1 + sma, len(learn.sched.lrs)):
derivative = (learn.sched.losses[i] - learn.sched.losses[i - sma]) / sma
derivatives.append(derivative)
plt.ylabel("d/loss")
plt.xlabel("learning rate (log scale)")
plt.plot(learn.sched.lrs[n_skip:], derivatives[n_skip:])
plt.xscale('log')
plt.ylim(y_lim)
plot_loss_change(learn.sched, sma=20)
請注意:只在訓(xùn)練之前選擇一次學(xué)習(xí)率是不夠的。訓(xùn)練過程中,最優(yōu)學(xué)習(xí)率會(huì)隨著時(shí)間推移而下降。你可以定期重新運(yùn)行相同的學(xué)習(xí)率搜索程序,以便在訓(xùn)練的稍后時(shí)間查找學(xué)習(xí)率。
使用其他庫實(shí)現(xiàn)本方案
我還沒有準(zhǔn)備好將這種學(xué)習(xí)率搜索方法應(yīng)用到諸如 Keras 等其他庫中,但這應(yīng)該不是什么難事。只需要做到:
多次運(yùn)行訓(xùn)練,每次只訓(xùn)練一個(gè)小批量;
在每次分批訓(xùn)練之后通過乘以一個(gè)小的常數(shù)的方式增加學(xué)習(xí)率;
當(dāng)損失函數(shù)值高于先前觀察到的最佳值時(shí),停止程序。(例如,可以將終止條件設(shè)置為「當(dāng)前損失 > *4 最佳損失」)
學(xué)習(xí)計(jì)劃
選擇學(xué)習(xí)率的初始值只是問題的一部分。另一個(gè)需要優(yōu)化的是學(xué)習(xí)計(jì)劃(learning schedule):如何在訓(xùn)練過程中改變學(xué)習(xí)率。傳統(tǒng)的觀點(diǎn)是,隨著時(shí)間推移學(xué)習(xí)率要越來越低,而且有許多方法進(jìn)行設(shè)置:例如損失函數(shù)停止改善時(shí)逐步進(jìn)行學(xué)習(xí)率退火、指數(shù)學(xué)習(xí)率衰退、余弦退火等。
我上面引用的論文描述了一種循環(huán)改變學(xué)習(xí)率的新方法,它能提升卷積神經(jīng)網(wǎng)絡(luò)在各種圖像分類任務(wù)上的性能表現(xiàn)。?
評論