目錄
簡要介紹PyTorch、張量和NumPy
為什么選擇卷積神經(jīng)網(wǎng)絡(luò)(CNNs)?
識別服裝問題
使用PyTorch實現(xiàn)CNNs
1.簡要介紹PyTorch、張量和NumPy
讓我們快速回顧一下第一篇文章中涉及的內(nèi)容。我們討論了PyTorch和張量的基礎(chǔ)知識,還討論了PyTorch與NumPy的相似之處。
PyTorch是一個基于python的庫,提供了以下功能:
用于創(chuàng)建可序列化和可優(yōu)化模型的TorchScript
以分布式訓(xùn)練進行并行化計算
動態(tài)計算圖,等等
PyTorch中的張量類似于NumPy的n維數(shù)組,也可以與gpu一起使用。在這些張量上執(zhí)行操作幾乎與在NumPy數(shù)組上執(zhí)行操作類似。這使得PyTorch非常易于使用和學(xué)習(xí)。
在本系列的第1部分中,我們構(gòu)建了一個簡單的神經(jīng)網(wǎng)絡(luò)來解決一個案例研究。使用我們的簡單模型,我們在測試集中獲得了大約65%的基準準確度?,F(xiàn)在,我們將嘗試使用卷積神經(jīng)網(wǎng)絡(luò)來提高這個準確度。
2.為什么選擇卷積神經(jīng)網(wǎng)絡(luò)(CNNs)?
在我們進入實現(xiàn)部分之前,讓我們快速地看看為什么我們首先需要CNNs,以及它們是如何工作的。
我們可以將卷積神經(jīng)網(wǎng)絡(luò)(CNNs)看作是幫助從圖像中提取特征的特征提取器。
在一個簡單的神經(jīng)網(wǎng)絡(luò)中,我們把一個三維圖像轉(zhuǎn)換成一維圖像,對吧?讓我們看一個例子來理解這一點:
你能認出上面的圖像嗎?這似乎說不通?,F(xiàn)在,讓我們看看下面的圖片:
我們現(xiàn)在可以很容易地說,這是一只狗。如果我告訴你這兩個圖像是一樣的呢?相信我,他們是一樣的!唯一的區(qū)別是第一個圖像是一維的,而第二個圖像是相同圖像的二維表示
空間定位
人工神經(jīng)網(wǎng)絡(luò)也會丟失圖像的空間方向。讓我們再舉個例子來理解一下:
你能分辨出這兩幅圖像的區(qū)別嗎?至少我不能。由于這是一個一維的表示,因此很難確定它們之間的區(qū)別?,F(xiàn)在,讓我們看看這些圖像的二維表示:
在這里,圖像某些定位已經(jīng)改變,但我們無法通過查看一維表示來識別它。
這就是人工神經(jīng)網(wǎng)絡(luò)的問題——它們失去了空間定位。
大量參數(shù)
神經(jīng)網(wǎng)絡(luò)的另一個問題是參數(shù)太多。假設(shè)我們的圖像大小是28283 -所以這里的參數(shù)是2352。如果我們有一個大小為2242243的圖像呢?這里的參數(shù)數(shù)量為150,528。
這些參數(shù)只會隨著隱藏層的增加而增加。因此,使用人工神經(jīng)網(wǎng)絡(luò)的兩個主要缺點是:
丟失圖像的空間方向
參數(shù)的數(shù)量急劇增加
那么我們?nèi)绾翁幚磉@個問題呢?如何在保持空間方向的同時減少可學(xué)習(xí)參數(shù)?
這就是卷積神經(jīng)網(wǎng)絡(luò)真正有用的地方。CNNs有助于從圖像中提取特征,這可能有助于對圖像中的目標進行分類。它首先從圖像中提取低維特征(如邊緣),然后提取一些高維特征(如形狀)。
我們使用濾波器從圖像中提取特征,并使用池技術(shù)來減少可學(xué)習(xí)參數(shù)的數(shù)量。
在本文中,我們不會深入討論這些主題的細節(jié)。如果你希望了解濾波器如何幫助提取特征和池的工作方式,我強烈建議你從頭開始學(xué)習(xí)卷積神經(jīng)網(wǎng)絡(luò)的全面教程。
3.問題:識別服裝
理論部分已經(jīng)鋪墊完了,開始寫代碼吧。我們將討論與第一篇文章相同的問題陳述。這是因為我們可以直接將我們的CNN模型的性能與我們在那里建立的簡單神經(jīng)網(wǎng)絡(luò)進行比較。
你可以從這里下載“識別”Apparels問題的數(shù)據(jù)集。
https://datahack.analyticsvidhya.com/contest/practice-problem-identify-the-apparels/?utmsource=blog&utmmedium=building-image-classification-models-cnn-pytorch
讓我快速總結(jié)一下問題陳述。我們的任務(wù)是通過觀察各種服裝形象來識別服裝的類型。我們總共有10個類可以對服裝的圖像進行分類:
數(shù)據(jù)集共包含70,000張圖像。其中60000張屬于訓(xùn)練集,其余10000張屬于測試集。所有的圖像都是大小(28*28)的灰度圖像。數(shù)據(jù)集包含兩個文件夾,一個用于訓(xùn)練集,另一個用于測試集。每個文件夾中都有一個.csv文件,該文件具有圖像的id和相應(yīng)的標簽;
準備好開始了嗎?我們將首先導(dǎo)入所需的庫:
加載數(shù)據(jù)集
現(xiàn)在,讓我們加載數(shù)據(jù)集,包括訓(xùn)練,測試樣本:
該訓(xùn)練文件包含每個圖像的id及其對應(yīng)的標簽
另一方面,測試文件只有id,我們必須預(yù)測它們對應(yīng)的標簽
樣例提交文件將告訴我們預(yù)測的格式
我們將一個接一個地讀取所有圖像,并將它們堆疊成一個數(shù)組。我們還將圖像的像素值除以255,使圖像的像素值在[0,1]范圍內(nèi)。這一步有助于優(yōu)化模型的性能。
讓我們來加載圖像:
如你所見,我們在訓(xùn)練集中有60,000張大小(28,28)的圖像。由于圖像是灰度格式的,我們只有一個單一通道,因此形狀為(28,28)。
現(xiàn)在讓我們研究數(shù)據(jù)和可視化一些圖像:
以下是來自數(shù)據(jù)集的一些示例。我鼓勵你去探索更多,想象其他的圖像。接下來,我們將把圖像分成訓(xùn)練集和驗證集。
創(chuàng)建驗證集并對圖像進行預(yù)處理
我們在驗證集中保留了10%的數(shù)據(jù),在訓(xùn)練集中保留了10%的數(shù)據(jù)。接下來將圖片和目標轉(zhuǎn)換成torch格式:
同樣,我們將轉(zhuǎn)換驗證圖像:
我們的數(shù)據(jù)現(xiàn)在已經(jīng)準備好了。最后,是時候創(chuàng)建我們的CNN模型了!
4.使用PyTorch實現(xiàn)CNNs
我們將使用一個非常簡單的CNN架構(gòu),只有兩個卷積層來提取圖像的特征。然后,我們將使用一個完全連接的Dense層將這些特征分類到各自的類別中。
讓我們定義一下架構(gòu):
現(xiàn)在我們調(diào)用這個模型,定義優(yōu)化器和模型的損失函數(shù):
這是模型的架構(gòu)。我們有兩個卷積層和一個線性層。接下來,我們將定義一個函數(shù)來訓(xùn)練模型:
最后,我們將對模型進行25個epoch的訓(xùn)練,并存儲訓(xùn)練和驗證損失:
可以看出,隨著epoch的增加,驗證損失逐漸減小。讓我們通過繪圖來可視化訓(xùn)練和驗證的損失:
啊,我喜歡想象的力量。我們可以清楚地看到,訓(xùn)練和驗證損失是同步的。這是一個好跡象,因為模型在驗證集上進行了很好的泛化。
讓我們在訓(xùn)練和驗證集上檢查模型的準確性:
訓(xùn)練集的準確率約為72%,相當不錯。讓我們檢查驗證集的準確性:
正如我們看到的損失,準確度也是同步的-我們在驗證集得到了72%的準確度。
為測試集生成預(yù)測
最后是時候為測試集生成預(yù)測了。我們將加載測試集中的所有圖像,執(zhí)行與訓(xùn)練集相同的預(yù)處理步驟,最后生成預(yù)測。
所以,讓我們開始加載測試圖像:
現(xiàn)在,我們將對這些圖像進行預(yù)處理步驟,類似于我們之前對訓(xùn)練圖像所做的:
最后,我們將生成對測試集的預(yù)測:
用預(yù)測替換樣本提交文件中的標簽,最后保存文件并提交到排行榜:
你將在當前目錄中看到一個名為submission.csv的文件。你只需要把它上傳到問題頁面的解決方案檢查器上,它就會生成分數(shù)。鏈接:https://datahack.analyticsvidhya.com/contest/practice-problem-identify-the-apparels/?utmsource=blog&utmmedium=building-image-classification-models-cnn-pytorch
我們的CNN模型在測試集上給出了大約71%的準確率,這與我們在上一篇文章中使用簡單的神經(jīng)網(wǎng)絡(luò)得到的65%的準確率相比是一個很大的進步。
5.結(jié)尾
在這篇文章中,我們研究了CNNs是如何從圖像中提取特征的。他們幫助我們將之前的神經(jīng)網(wǎng)絡(luò)模型的準確率從65%提高到71%,這是一個重大的進步。
你可以嘗試使用CNN模型的超參數(shù),并嘗試進一步提高準確性。要調(diào)優(yōu)的超參數(shù)可以是卷積層的數(shù)量、每個卷積層的濾波器數(shù)量、epoch的數(shù)量、全連接層的數(shù)量、每個全連接層的隱藏單元的數(shù)量等。
-
python
+關(guān)注
關(guān)注
56文章
4805瀏覽量
84922 -
pytorch
+關(guān)注
關(guān)注
2文章
808瀏覽量
13316
發(fā)布評論請先 登錄
相關(guān)推薦
評論