0
  • 聊天消息
  • 系統消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內不再提示

半小時學會PyTorch快速圖片分類

DPVg_AI_era ? 來源:lq ? 2019-07-13 07:57 ? 次閱讀

通過本教程,讀者將能夠在選擇的任何圖像數據集上,構建和訓練圖像識別器,同時充分了解底層模型架構和訓練過程。教程內容包括數據提取、數據可視化、CNN、ResNets、遷移學習、結果解釋、微調等。

這是一篇長文教程,建議大家讀不完的話一定要收藏,利用閑暇時光將其讀完!更加歡迎將本文轉發(fā)給同學、朋友、同事等。

本文的目標是能夠讓你可以在任何圖像數據集上構建和訓練圖像識別器,同時充分了解底層模型架構和培訓過程。

目標讀者:任何研究圖像識別、或對此領域感興趣的初學者

教程目錄:

數據提取

數據可視化

模型訓練

結果解釋

模型層的凍結和解凍

微調

教程所使用的Jupyter notebook:

https://github.com/SalChem/Fastai-iNotes-iTutorials/blob/master/Image_Recognition_Basics.ipynb

更簡單直接的方式是登錄Google Colab:

https://colab.research.google.com/github/SalChem/Fastai-iNotes-iTutorials/blob/master/Image_Recognition_Basics.ipynb

注意:使用Google Colab之前,確保你做了如下設置

Runtime -> Change runtime type -> Hardware Accelerator -> GPU

設置IPython內核并初始化

加載依賴庫

初始化

其中,bs 代表batch size,意為每次送入模型的訓練圖像的數量。每次batch迭代后都會更新模型參數。

比如我們有640個圖像,那么bs=64;參數將在1 epoch的過程中更新10次。

如果你運行教程過程中提示內存不足,可以使用較小的bs,按照2的倍數增減即可。

使用特定值初始化上面的偽隨機數生成器可使系統穩(wěn)定,從而產生可重現的結果。

數據提取

數據集來自Oxford-IIIT Pet Dataset,可以使用fastai數據集對模塊進行檢索。

URLs.PETS 是數據集的url。這里提供了12個品種的貓和25個品種的狗。untar_data 解壓并下載數據文件到 path。

PosixPath('/home/jupyter/.fastai/data/oxford-iiit-pet/images/scottish_terrier_119.jpg')

每個圖像的標簽都包含在圖像文件名中,需要使用正則表達式提取。模式如下:

創(chuàng)建訓練并驗證數據集:

ImageDataBunch 根據路徑 path_img 中的圖像創(chuàng)建訓練數據集 train_ds 和驗證數據集 valid_ds。

from_name_re 使用在編譯表達式模式 pat 后獲得的正則表達式從文件名 fnames 列表中獲取標簽。

df_tfms 是即時應用于圖像的轉換。在這里,圖像將調整為 224x224,居中,裁剪和縮放。

這種轉換是數據增強的實例,不會更改圖像內部的內容,但會更改其像素值以獲得更好的模型概括。

normalize 使用ImageNet圖像的標準偏差和平均值對數據進行標準化。

數據可視化

訓練數據樣本表示為

(Image (3, 224, 224), Category scottish_terrier)

Image里是RGB數值,Category 是圖像標簽。對應的圖像如下:

len(data.train_ds)和len(data.valid_ds)分別輸出訓練樣本5912和驗證樣本數量1478。

data.c和data.classes分別輸出類及其標簽的數量。下面的標簽共有37個類別:

['Abyssinian', 'Bengal', 'Birman', 'Bombay', 'British_Shorthair', 'Egyptian_Mau', 'Maine_Coon', 'Persian', 'Ragdoll', 'Russian_Blue', 'Siamese', 'Sphynx', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle','boxer', 'chihuahua', 'english_cocker_spaniel', 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger', 'miniature_pinscher', 'newfoundland', 'pomeranian', 'pug', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier']

show_batch 顯示一些batch里的圖片。

模型訓練

cnn_learner 使用來自給定架構的預訓練模型構建CNN學習器、來自預訓練模型的學習參數用于初始化模型,允許更快的收斂和高精度。我們使用的CNN架構是ResNet34。下圖是一個典型的CNN架構。

ResNet34后面的數字可以隨意更改,比如改成ResNet50。數字越大,GPU內存消耗越高。

讓我們繼續(xù),現在可以在數據集上訓練模型了!

fit_one_cycle會按預設epoch數訓練模型,比如4個epoch。

epoch數表示模型查看整個圖像集的次數。但是,在每個epoch中,隨著數據的增加,同一張圖像都會與上個epoch略有不同。

通常,度量誤差將隨著epoch的增加而下降。只要驗證集的精度不斷提高,增加epoch數量就是個好辦法。然而,epoch過多可能導致模型學習了特定的圖像,而不是一般的類,要避免這種情況出現。

剛才提到的訓練就是我們所說的“特征提取”,所以只對模型的頭部(最底下的幾層)的參數進行了更新。接下來將嘗試對全部層的參數進行微調。

恭喜!模型已成功訓練,可以識別貓和狗了。識別準確率大約是93.5%。

還能進步嗎?這要等到微調之后了。

我們保存當前的模型參數,以便重新加載時使用。

對預測結果的解釋

現在我們看看如何正確解釋當前的模型結果。

ClassificationInterpretation提供錯誤分類圖像的可視化實現。

plot_top_losses顯示最高損失的圖像及其:預測標簽/實際標簽/損失/實際圖像類別的概率

高損失意味著對錯誤答案出現的高信度。繪制最高損失是可視化和解釋分類結果的好方法。

具有最高損失的錯誤分類圖像

分類混淆矩陣

在混淆矩陣中,對角線元素表示預測標簽與真實標簽相同的圖像的數量,而非對角線元素是由分類器錯誤標記的元素。

most_confused只突出顯示預測分類和實際類別中最混亂的組合,換句話說,就是分類最常出錯的那些組合。從圖中可以看到,模型經常將斯塔??ざ放He誤分類為美國斗牛犬,它們實際上看起來非常像。

[('Siamese', 'Birman', 6), ('american_pit_bull_terrier', 'staffordshire_bull_terrier', 5), ('staffordshire_bull_terrier', 'american_pit_bull_terrier', 5), ('Maine_Coon', 'Ragdoll', 4), ('beagle', 'basset_hound', 4), ('chihuahua', 'miniature_pinscher', 3), ('staffordshire_bull_terrier', 'american_bulldog', 3), ('Birman', 'Ragdoll', 2), ('British_Shorthair', 'Russian_Blue', 2), ('Egyptian_Mau', 'Abyssinian', 2), ('Ragdoll', 'Birman', 2), ('american_bulldog', 'staffordshire_bull_terrier', 2), ('boxer', 'american_pit_bull_terrier', 2), ('chihuahua', 'shiba_inu', 2), ('miniature_pinscher', 'american_pit_bull_terrier', 2), ('yorkshire_terrier', 'havanese', 2)]

網絡層的凍結和解凍

在默認情況下,在fastai中,使用預訓練的模型對較早期的層進行凍結,使網絡只能更改最后一層的參數,如上所述。凍結第一層,僅訓練較深的網絡層可以顯著降低計算量。

我們總是可以調用unfreeze函數來訓練所有網絡層,然后再使用fit或fit_one_cycle。這就是所謂的“微調”,這是在調整整個網絡的參數。

現在的準確度比以前略差。這是為什么?

這是因為我們以相同的速度更新了所有層的參數,這不是我們想要的,因為第一層不需要像最后一層那樣需要做太多變動??刂茩嘀馗铝康某瑓捣Q為“學習率”,也叫步長。它可以根據損失的梯度調整權重,目的是減少損失。例如,在最常見的梯度下降優(yōu)化器中,權重和學習率之間的關系如下:

順便說一下,梯度只是一個向量,它是導數在多變量領域的推廣。

因此,對模型進行微調的更好方法是對較低層和較高層使用不同的學習率,通常稱為差異或判別學習率。

本教程中可以互換使用參數和權重。更準確地說,參數是權重和偏差。但請注意,超參數和參數不一樣,超參數無法在訓練中進行估計。

對預測模型的微調

為了找到最適合微調模型的學習率,我們使用學習速率查找器,可以逐漸增大學習速率,并且在每個batch之后記錄相應的損失。在fastai庫通過lr_find來實現。

首先加載之前保存的模型,并運行l(wèi)r_find

recorder.plot可用于繪制損失與學習率的關系圖。當損失開始發(fā)散時,停止運行。

從得到的圖中,我們一致認為適當的學習率約為1e-4或更小,超過這個范圍,損失就開始增大并失去控制。我們將最后一層的學習速率設為1e-4,更早期的層設為1e-6。同樣,這是因為早期的層已經訓練得很好了,用來捕獲通用特征,不需要那么頻繁的更新。

我們之前的實驗中使用的學習率為0.003,這是該庫的默認設置。

在我們使用這些判別性學習率訓練我們的模型之前,讓我們揭開fit_one_cycle和fitmethods之間的差異,因為兩者都是訓練模型的合理選擇。這個討論對于理解訓練過程非常有價值,但可以直接跳到結果。

fit_one_cycle vs fit:

簡而言之,二者之間不同之處在于fit_one_cycle實現了Leslie Smith 循環(huán)策略,而沒有使用固定或逐步降低的學習率來更新網絡的參數,而是在兩個合理的較低和較高學習速率范圍之間振蕩。

訓練中的學習率超參數

在微調深度神經網絡時,良好的學習率超參數是至關重要的。使用較高的學習率可以讓網絡更快地學習,但是學習率太高可能使模型無法收斂。另一方面,學習率太小會使訓練速度過于緩慢。

不同水平的學習率對模型收斂性的影響

在本文的實例中,我們通過查看不同學習率下記錄的損失,估算出合適的學習率。在更新網絡參數時,可以將此學習率作為固定學習率。換句話說,就是對所有訓練迭代使用相同的學習率,可以使用learn.fit來實現。一種更好的方法是,隨著訓練的進行逐步改變學習率。有兩種方法可以實現,即學習率規(guī)劃(設定基于時間的衰減,逐步衰減,指數衰減等),以及自適應學習速率法(Adagrad,RMSprop,Adam等)。

簡單的1cycle策略

1cycle策略是一種學習率調度器,讓學習率在合理的最小和最大邊界之間振蕩。制定這兩個邊界有什么價值呢?上限是我們從學習速率查找器獲得的,而最小界限可以小到上限的十分之一。這種方法的優(yōu)點是可以克服局部最小值和鞍點,這些點是平坦表面上的點,通常梯度很小。事實證明,1cycle策略比其他調度或自適應學習方法更快、更準確。Fastai在fit_one_cycle中實現了cycle策略,在內部調用固定學習率方法和OneCycleScheduler回調。

1cycle的一個周期長度

下圖顯示了超收斂方法如何在Cifar-10的迭代次數更少的情況下達到比典型(分段常數)訓練方式更高的精度,兩者都使用56層殘余網絡架構。

超收斂精度測試與Cifar-10上具有相同架構模型的典型訓練機制

揭曉真相的時刻到了

在選擇了網絡層的判別學習率之后,就可以解凍模型,并進行相應的訓練了。

Slice函數將網絡的最后一層學習率設為1e-4,將第一層學習率設為1e-6。中間各層在此范圍內以相等的增量設定學習率。

結果,預測準確度有所提升,但提升的并不多,我們想知道,這時是否需要對模型進行微調?

在微調任何模型之前始終要考慮的兩個關鍵因素就是數據集的大小及其與預訓練模型的數據集的相似性。在我們的例子中,我們使用“寵物”數據集類似于ImageNet中的圖像,數據集相對較小,所以我們從一開始就實現了高分類精度,而沒有對整個網絡進行微調。

盡管如此,我們仍然能夠對精度結果進行改進,并從中學到很多東西。

下圖說明了使用和微調預訓練模型的三種合理方法。在本教程中,我們嘗試了第一個和第三個策略。第二個策略在數據集較小,但與預訓練模型的數據集不同,或者數據集較大,但與預訓練模型的數據集相似的情況下也很常見。

在預訓練模型上微調策略

恭喜,我們已經成功地使用最先進的CNN覆蓋了圖像分類任務,網絡的基礎結構和訓練過程都打下了堅實的基礎。

至此,你已經可以自己的數據集上構建圖像識別器了。如果你覺得還沒有準備好,可以從Google Image抓取一部分圖片組成自己的數據集。

開始體驗吧!

聲明:本文內容及配圖由入駐作者撰寫或者入駐合作網站授權轉載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網立場。文章及其配圖僅供工程師學習之用,如有內容侵權或者其他違規(guī)問題,請聯系本站處理。 舉報投訴
  • 數據集
    +關注

    關注

    4

    文章

    1208

    瀏覽量

    24754
  • cnn
    cnn
    +關注

    關注

    3

    文章

    353

    瀏覽量

    22268
  • pytorch
    +關注

    關注

    2

    文章

    808

    瀏覽量

    13292

原文標題:從零開始,半小時學會PyTorch快速圖片分類

文章出處:【微信號:AI_era,微信公眾號:新智元】歡迎添加關注!文章轉載請注明出處。

收藏 人收藏

    評論

    相關推薦

    PMP10215工作半小時溫度就高達127.3度,這樣怎么做實際應用呢?

    按照PMP10215 Rev_D Test Results上的說法,PMP10215工作半小時溫度就高達127.3度,這樣怎么做實際應用呢?這么高的溫度不是會很快使變壓器老化,不能正常工作了嗎?
    發(fā)表于 10-18 06:34

    半小時的不眠不休,終于搞定~~~

    經過半小時的不眠不休,終于是搞 定了這個SDT的兼容性問題,究其緣由,則是由于自己新建元器件庫的時候太不規(guī)則了~~以后注意注意~~雖說創(chuàng)新,但得守規(guī)則~~~{:4_95:}
    發(fā)表于 12-06 09:41

    如何快速學會AD?

    最近看完了AD視頻教程,怎么感覺一點都沒用呢?求教大家,如何才能快速入手AD,學會畫板子?(是不是方法有問題,感覺學的很迷茫?。。?/div>
    發(fā)表于 08-15 09:36

    寫個單片機小程序。按鍵每按一次,時間增加半小時 c51

    求大神幫助,寫個單片機小程序。按鍵每按一次,時間增加半小時。在線等。。。
    發(fā)表于 07-17 20:14

    labview2014中在一個while循環(huán)里調用dll運行半小時后就崩潰了該怎么解決?

    我嘗試寫一個新的dll,放在while循環(huán)里運行半小時又崩潰了,但是調用window自帶的dll運行多久都沒事,是不是我寫的dll運行時會在labview里面產生一些緩存,運行半小時緩存滿了
    發(fā)表于 09-06 09:45

    半小時開發(fā)基于 STM32 的室內智能環(huán)境監(jiān)測儀

    半小時開發(fā)基于 STM32 的室內智能環(huán)境監(jiān)測儀
    發(fā)表于 09-06 22:28

    1小時學會C語言(51單片機)

    1小時學會C語言(51單片機)
    發(fā)表于 03-04 09:43

    PyTorch10的基礎教程

    PyTorch 10 基礎教程(4):訓練分類
    發(fā)表于 06-05 17:42

    6小時學會labview

    6小時學會labview, LabVIEW Six Hour Course – Instructor Notes  
    發(fā)表于 08-02 13:52 ?33次下載

    PyTorch官網教程PyTorch深度學習:60分鐘快速入門中文翻譯版

    PyTorch 深度學習:60分鐘快速入門”為 PyTorch 官網教程,網上已經有部分翻譯作品,隨著PyTorch1.0 版本的公布,這個教程有較大的代碼改動,本人對教程進行重新翻
    的頭像 發(fā)表于 01-13 11:53 ?1w次閱讀

    textCNN論文與原理——短文本分類

    是處理圖片的torchvision,而處理文本的少有提及,快速處理文本數據的包也是有的,那就是torchtext[1]。下面還是結合上一個案例:【深度學習】textCNN論文與原理——短文本分類(基于
    的頭像 發(fā)表于 12-31 10:08 ?2548次閱讀
    textCNN論文與原理——短文本<b class='flag-5'>分類</b>

    10小時輕松學會C語言及其編程

    10小時輕松學會C語言及其編程
    發(fā)表于 03-30 15:43 ?15次下載
    10<b class='flag-5'>小時</b>輕松<b class='flag-5'>學會</b>C語言及其編程

    PyTorch教程4.2之圖像分類數據集

    電子發(fā)燒友網站提供《PyTorch教程4.2之圖像分類數據集.pdf》資料免費下載
    發(fā)表于 06-05 15:41 ?0次下載
    <b class='flag-5'>PyTorch</b>教程4.2之圖像<b class='flag-5'>分類</b>數據集

    PyTorch教程4.3之基本分類模型

    電子發(fā)燒友網站提供《PyTorch教程4.3之基本分類模型.pdf》資料免費下載
    發(fā)表于 06-05 15:43 ?0次下載
    <b class='flag-5'>PyTorch</b>教程4.3之基本<b class='flag-5'>分類</b>模型

    PyTorch教程4.6之分類中的泛化

    電子發(fā)燒友網站提供《PyTorch教程4.6之分類中的泛化.pdf》資料免費下載
    發(fā)表于 06-05 15:39 ?0次下載
    <b class='flag-5'>PyTorch</b>教程4.6之<b class='flag-5'>分類</b>中的泛化