0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內不再提示

如果項目的模型遇到瓶頸,用這些Tricks就對了

電子設計 ? 來源:電子設計 ? 作者:電子設計 ? 2020-12-10 14:33 ? 次閱讀
來源:AI人工智能初學者
作者:ChaucerG
其實圖像分類研究取得的大部分進展都可以歸功于訓練過程的改進,如數(shù)據(jù)增加和優(yōu)化方法的改變。但是,大多數(shù)改進都沒有比較詳細的說明。因此作者在本文中測試實現(xiàn)了這些改進的方法,并通過消融實驗來評估這些Tricks對最終模型精度的影響。作者通過將這些改進結合在一起,同時改進了各種CNN模型。在ImageNet上將ResNet-50的Top-1驗證精度從75.3%提高到79.29%。同時還將證明了提高圖像分類精度會在其他應用領域(如目標檢測和語義分割)也可以帶來更好的遷移學習性能。

1、Introduction

近年來ImageNet的榜單一直在被刷新,從2012年的AlexNet,再到VGG-Net、NiN、Inception、ResNet、DenseNet以及NASNet;Top-1精度也從62.5%(AlexNet)->82.7%(NASNet-A);但是這么大精度的提升也不完全是由模型的架構改變所帶來的,其中 訓練的過程也有會起到很大的作用,比如,損失函數(shù)的改進、數(shù)據(jù)的預處理方式的改變、以及優(yōu)化方法的選擇等;但是這也是很容易被忽略的部分,因此這篇文章在這里也會著重討論這個問題。

2、Efficient Training

近年來硬件發(fā)展迅速,特別是GPU。因此,許多與性能相關的權衡的最佳選擇也會隨之發(fā)生變化。例如,在訓練中使用較低的數(shù)值精度和較大的Batch/_Size更有效。
在本節(jié)中將在不犧牲模型精度的情況下實現(xiàn)低精度和大規(guī)模批量訓練的各種技術。有些技術甚至可以提高準確性和訓練速度。

2.1、Large-batch training

Mini-Batch SGD將多個樣本分組到一個小批量中,以增加并行性,降低傳輸成本。然而,使用Large Batch-size可能會減慢訓練進度。對于凸優(yōu)化問題,收斂率隨著批量大小的增加而降低。類似的經驗結論已經被發(fā)表。

換句話說,在相同的epoch數(shù)量下,使用Large Batch-size的訓練會與使用較小批次的訓練相比,模型的驗證精度降低。很多研究提出了啟發(fā)式搜索的方法來解決這個問題。下面將研究4種啟發(fā)式方法,可以在單臺機器訓練中擴大Batch-size的規(guī)模。

1)Linear scaling learning rate
在Mini-Batch SGD中,由于樣本是隨機選取的,所以梯度下降也是一個隨機的過程。增加批量大小不會改變隨機梯度的期望,但會減小隨機梯度的方差。換句話說,大的批量降低了梯度中的噪聲,因此我們可以通過提高學習率來在梯度相反的方向上取得更大的進展。

Goyal等人提出對于ResNet-50訓練,經驗上可以根據(jù)批大小線性增加學習率。特別是,如果選擇0.1作為批量大小256的初始學習率,那么當批量大小b變大時可以將初始學習率提高到:

2)Learning rate Warmup
在訓練開始時,所有參數(shù)通常都是隨機值,因此離最優(yōu)解很遠。使用過大的學習率可能導致數(shù)值不穩(wěn)定。在Warmup中,在一開始使用一個比較小的學習率,然后當訓練過程穩(wěn)定時切換回初始設置的學習率base/_lr。

Goyal等人提出了一種Gradual Warmup策略,將學習率從0線性地提高到初始學習率。換句話說,假設將使用前m批(例如5個數(shù)據(jù)epoch)進行Warmup,并且初始學習率為,那么在第批時將學習率設為i/=m。

3)Zero
一個ResNet網絡由多個殘差塊組成,而每個殘差塊又由多個卷積層組成。給定輸入,假設是Last Layer的輸出,那么這個殘差塊就輸出。注意,Block的最后一層可以是批處理標準化層。
BN層首先標準化它的輸入用表示,然后執(zhí)行一個scale變換。兩個參數(shù)、都是可學習的,它們的元素分別被初始化為1s和0s。在零初始化啟發(fā)式中,剩余塊末端的所有BN層初始化了。因此,所有的殘差塊只是返回它們的輸入,模擬的網絡層數(shù)較少,在初始階段更容易訓練。

4)No bias decay
權值衰減通常應用于所有可學習參數(shù),包括權值和偏差。它等價于應用L2正則化到所有參數(shù),使其值趨近于0。但如Jia等所指出,建議僅對權值進行正則化,避免過擬合。無偏差衰減啟發(fā)式遵循這一建議,它只將權值衰減應用于卷積層和全連通層中的權值。其他參數(shù),包括偏差和和以及BN層,都沒有進行正則化。
LARS提供了分層自適應學習率,并且對大的Batch-size(超過16K)有效。本文中單機訓練的情況下,批量大小不超過2K通常會導致良好的系統(tǒng)效率。

2.2、Low-precision training

神經網絡通常是用32位浮點(FP32)精度訓練的。也就是說,所有的數(shù)字都以FP32格式存儲,輸入和輸出以及計算操作都是FP32類型參與的。然而,新的硬件可能已經增強了新的算術邏輯單元,用于較低精度的數(shù)據(jù)類型。
例如,前面提到的Nvidia V100在FP32中提供了14個TFLOPS,而在FP16中提供了超過100個TFLOPS。如下表所示,在V100上從FP32切換到FP16后,整體訓練速度提高了2到3倍。

盡管有性能上的好處,降低的精度有一個更窄的范圍,使結果更有可能超出范圍,然后干擾訓練的進展。Micikevicius等人提出在FP16中存儲所有參數(shù)和激活,并使用FP16計算梯度。同時,F(xiàn)P32中所有的參數(shù)都有一個用于參數(shù)更新的副本。此外,損失值乘以一個比較小的標量scaler以更好地對齊精度范圍到FP16也是一個實際的解決方案。

2.3、Experiment Results

3、Model Tweaks

模型調整是對網絡架構的一個小調整,比如改變一個特定卷積層的stride。這樣的調整通常不會改變計算復雜度,但可能會對模型精度產生不可忽略的影響。

3.1、ResNet Tweaks

回顧了ResNet的兩個比較流行的改進,分別稱之為ResNet-B和ResNet-C。在此基礎上,提出了一種新的模型調整方法ResNet-D。

1)ResNet-B
ResNet-B改變的下采樣塊。觀察到路徑A中的卷積忽略了輸入feature map的四分之三,因為它使用的內核大小為1×1,Stride為2。ResNet-B切換路徑A中前兩個卷積的步長大小,如圖a所示,因此不忽略任何信息。由于第2次卷積的kernel大小為3×3,路徑a的輸出形狀保持不變。

2)ResNet-C
卷積的計算代價是卷積核的寬或高的二次項。一個7×7的卷積比3×3的卷積的計算量更大。因此使用3個3x3的卷積替換1個7x7的卷積,如圖b所示,與第1和第2個卷積block的channel=32,stride=2,而最后卷積使用64個輸出通道。

3)ResNet-D
受ResNet-B的啟發(fā),下采樣塊B路徑上的1x1卷積也忽略了輸入feature map的3/4,因此想對其進行修改,這樣就不會忽略任何信息。通過實驗發(fā)現(xiàn),在卷積前增加一個平均為2x2的avg pooling層,將其stride改為1,在實踐中效果很好,同時對計算成本的影響很小。

4、Training Refinements

4.1、Cosine Learning Rate Decay

Loshchilov等人提出了一種余弦退火策略。一種簡化的方法是通過遵循余弦函數(shù)將學習率從初始值降低到0。假設批次總數(shù)為T(忽略預熱階段),那么在批次T時,學習率tm計算為:

可以看出,余弦衰減在開始時緩慢地降低了學習速率,然后在中間幾乎變成線性減少,在結束時再次減緩。與step衰減相比,余弦衰減從一開始就對學習進行衰減,但一直持續(xù)到步進衰減將學習率降低了10倍,從而潛在地提高了訓練進度。

importtorch  
  
optim=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max,eta_min=0,last_epoch=-1)  

4.2、Label Smoothing

對于輸出預測的標簽不可能像真是的label一樣真是,因此這里進行一定的平滑策略,具體的Label Smoothing平滑規(guī)則為:

#-*-coding:utf-8-*-  
  
"""  
qi=1-smoothing(ifi=y)  
qi=smoothing/(self.size-1)(otherwise)#所以默認可以fill這個數(shù),只在i=y的地方執(zhí)行1-smoothing  
另外KLDivLoss和crossentroy的不同是前者有一個常數(shù)  
predict=torch.FloatTensor([[0,0.2,0.7,0.1,0],  
  
[0,0.9,0.2,0.1,0],  
  
[1,0.2,0.7,0.1,0]])  
對應的label為  
tensor([[0.0250,0.0250,0.9000,0.0250,0.0250],  
[0.9000,0.0250,0.0250,0.0250,0.0250],  
[0.0250,0.0250,0.0250,0.9000,0.0250]])  
區(qū)別于one-hot的  
tensor([[0.,0.,1.,0.,0.],  
[1.,0.,0.,0.,0.],  
[0.,1.,0.,0.,0.]])  
"""  
importtorch  
importtorch.nnasnn  
fromtorch.autogradimportVariable  
importmatplotlib.pyplotasplt  
importnumpyasnp  
  
classLabelSmoothing(nn.Module):  
"Implementlabelsmoothing.size表示類別總數(shù)"  
  
def__init__(self,size,smoothing=0.0):  
super(LabelSmoothing,self).__init__()  
self.criterion=nn.KLDivLoss(size_average=False)  
#self.padding_idx=padding_idx  
self.confidence=1.0-smoothing#ifi=y的公式  
self.smoothing=smoothing  
self.size=size  
self.true_dist=None  
  
defforward(self,x,target):  
"""  
x表示輸入(N,M)N個樣本,M表示總類數(shù),每一個類的概率logP  
target表示label(M,)  
"""  
assertx.size(1)==self.size  
true_dist=x.data.clone()#先深復制過來  
#printtrue_dist  
true_dist.fill_(self.smoothing/(self.size-1))#otherwise的公式  
#printtrue_dist  
#變成one-hot編碼,1表示按列填充,  
#target.data.unsqueeze(1)表示索引,confidence表示填充的數(shù)字  
true_dist.scatter_(1,target.data.unsqueeze(1),self.confidence)  
self.true_dist=true_dist  
returnself.criterion(x,Variable(true_dist,requires_grad=False))  
  
if__name__:  
#Exampleoflabelsmoothing.  
  
crit=LabelSmoothing(size=5,smoothing=0.1)  
#predict.shape35  
predict=torch.FloatTensor([[0,0.2,0.7,0.1,0],  
[0,0.9,0.2,0.1,0],  
[1,0.2,0.7,0.1,0]])  
  
v=crit(Variable(predict.log()),  
Variable(torch.LongTensor([2,1,0])))  
#Showthetargetdistributionsexpectedbythesystem.  
plt.imshow(crit.true_dist)  

4.3、Knowledge Distillation

在訓練過程中增加了一個蒸餾損失,以懲罰Teacher模型和Student模型的softmax輸出之間的差異。給定一個輸入,設p為真概率分布,z和r分別為學生模型和教師模型最后全連通層的輸出。損失改進為:

4.4、Mixup Training

在Mixup中,每次我們隨機抽取兩個例子和。然后對這2個sample進行加權線性插值,得到一個新的sample:

其中

importnumpyasnp  
importtorch  
  
  
defmixup_data(x,y,alpha=1.0,use_cuda=True):  
ifalpha>0.:  
lam=np.random.beta(alpha,alpha)  
else:  
lam=1.  
batch_size=x.size()[0]  
ifuse_cuda:  
index=torch.randperm(batch_size).cuda()  
else:  
index=torch.randperm(batch_size)  
  
mixed_x=lam*x+(1-lam)*x[index,:]#自己和打亂的自己進行疊加  
y_a,y_b=y,y[index]  
returnmixed_x,y_a,y_b,lam  
  
defmixup_criterion(y_a,y_b,lam):  
returnlambdacriterion,pred:lam*criterion(pred,y_a)+(1-lam)*criterion(pred,y_b)  

4.5、Experiment Results


審核編輯 黃昊宇

聲明:本文內容及配圖由入駐作者撰寫或者入駐合作網站授權轉載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網立場。文章及其配圖僅供工程師學習之用,如有內容侵權或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • 人工智能
    +關注

    關注

    1792

    文章

    47514

    瀏覽量

    239236
  • 模型
    +關注

    關注

    1

    文章

    3279

    瀏覽量

    48974
收藏 人收藏

    評論

    相關推薦

    [轉帖]JAVA私塾關于實訓項目的總結

    真實項目,不能是精簡以后的,不能脫離實際應用系統(tǒng)3、在開發(fā)時要和企業(yè)的開發(fā)保持一致4、在做項目的時候不應該有參考代碼 長話短說就是以上幾點,如果你想要更多的了解,可以繼續(xù)往后看
    發(fā)表于 01-03 12:04

    labview軟件自帶項目的問題

    的是labview2012版的,上面有一個連續(xù)采集系統(tǒng)的自帶項目,我現(xiàn)在做的項目就是連續(xù)采集。想用軟件自帶的項目,但是有些模塊用不了,求大神指導啊,已經卡了好幾天了。還有,你們做連
    發(fā)表于 08-14 20:53

    為什么 采集的電壓 不對 如果 ad 直接接轉換的電壓就對了

    為什么 采集的電壓 不對如果 ad 直接接轉換的電壓就對了
    發(fā)表于 08-07 17:46

    畢業(yè)設計遇到瓶頸,求各位大大幫助

    畢業(yè)設計遇到瓶頸,求各位大大幫助! 還有一個月項目時間結束了, 我使用的是NI公司USB6003采集卡,現(xiàn)在遇到一個棘手的問題就是----- 1.如何利用DAQ采集到的模擬輸出信號來
    發(fā)表于 06-01 22:16

    keil在編譯51項目和stm32項目的警告區(qū)別?

    為什么keil在編譯51項目的時候,遇到沒有調用的函數(shù)就會提示WARNING L16但是在編譯stm32項目的時候,遇到沒有調用的函數(shù)就不會有任何提示?
    發(fā)表于 05-25 17:04

    Android學習路上會遇到的各種瓶頸總結

    完全掌握的。 克服了以上瓶頸后,估計實習生也該到了畢業(yè)轉正的時間了,進階路上還有新的瓶頸。新瓶頸有新的玩法:這種玩法需要雙手操作,如果另一
    發(fā)表于 11-13 11:12

    激光振鏡項目的改進

    的根本原因,于是通過飛線實現(xiàn)了電流反饋控制,效果接近原來的參考設備,說明思路走對了。進一步試驗發(fā)現(xiàn),現(xiàn)有的電路,力還不是很大,于是縮小驅動電阻,這個時候導致驅動IC發(fā)熱很高,這個驅動IC,我的是MOS管驅動
    發(fā)表于 10-24 14:34

    IT項目的質量控制

    IT項目的特點IT項目的生命周IT項目管理的重要環(huán)節(jié)IT項目質量控制的基本程序(介紹一個項目質量控制的實例)信息化工程
    發(fā)表于 07-13 00:22 ?0次下載

    Programming Tricks for Higher Conversion Speeds Utilizing De

    PROGRAMMING TRICKS FOR HIGHER CONVERSION SPEEDS UTILIZING DELTA SIGMA CONVERTERS:編程更高的轉換利用Δ-Σ轉換速度把戲
    發(fā)表于 06-01 18:05 ?28次下載

    分布式項目開發(fā)模型Chiefr分析

    項目的目的是在項目成員之間共享和去中心化項目不同部分的開發(fā)和維護。Chiefr的靈感來自于Linux內核及其get_contributors.pl腳本的貢獻
    發(fā)表于 09-28 14:43 ?0次下載
    分布式<b class='flag-5'>項目</b>開發(fā)<b class='flag-5'>模型</b>Chiefr分析

    全年開源項目的盤點和總結

    如果你們這些.NET 開發(fā)者們想要學一點機器學習知識來補充現(xiàn)有的技能,你會怎么做?現(xiàn)在就有一個完美的開源項目可以助你開始實施這一想法!這個完美的開源項目就是微軟的一個
    的頭像 發(fā)表于 01-17 11:18 ?3361次閱讀

    機器學習模型部署到ML項目的過程

    在構建一個大的機器學習系統(tǒng)時,有很多事情需要考慮。但作為數(shù)據(jù)科學家,我們常常只擔心項目的某些部分。
    的頭像 發(fā)表于 05-04 11:56 ?2159次閱讀

    圖像分類任務的各種tricks

    計算機視覺主要問題有圖像分類、目標檢測和圖像分割等。針對圖像分類任務,提升準確率的方法路線有兩條,一個是模型的修改,另一個是各種數(shù)據(jù)處理和訓練的tricks。
    的頭像 發(fā)表于 09-14 16:42 ?1188次閱讀

    物聯(lián)網項目的原因、時間和方式

    環(huán)境清理項目的蜂窩解決方案、農場灌溉的水優(yōu)化解決方案和智能城市照明的改造解決方案有什么共同之處?如果你猜到了“物聯(lián)網技術”,你是對的。如果您猜測這些解決方案都是針對需要解決的昂貴問題而
    的頭像 發(fā)表于 10-13 10:39 ?1649次閱讀
    物聯(lián)網<b class='flag-5'>項目的</b>原因、時間和方式

    肖特基二極管,你真的對了嗎?

    肖特基二極管,你真的對了嗎?
    的頭像 發(fā)表于 12-07 14:27 ?601次閱讀
    肖特基二極管,你真的<b class='flag-5'>用</b><b class='flag-5'>對了</b>嗎?