你有沒有想過,你經常使用的深度學習網絡在看圖像的什么部分進行分類?
例如下圖:
如果深度學習網絡將此圖像分類為“圓號”,你認為圖片的哪個部分對分類最重要?
MathWorks Computer Vision System Toolbox 開發(fā)工程師Birju Patel專注于深度學習,設計了如下案例進行解答這一問題:
我們使用預訓練好的 ResNet-50 網絡進行此實驗。
* He, Kaiming, Zhang, Xiangyu, Ren, Shaoqing, Sun, Jian. "Deep Residual Learning for Image Recognition." In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770-778. 2016.
獲取 MATLAB 中 ResNet-50 網絡的方法是啟動 Add-On Explorer(MATLAB 的 HOME 選項卡)并搜索 resnet。
net = resnet50;
我們需要注意 ResNet-50 需要輸入特定尺寸的圖像。網絡的初始層提供了這一信息:
sz = net.Layers(1).InputSize(1:2)sz = 224 224
所需的圖像尺寸可以直接傳遞給 imresize 函數。)
在網絡中調用 classify ,查看圖片可能的分類:
classify(net,rgb)ans = categorical French horn
ResNet-50 認為這是圓號。
Birju 在一篇關于卷積神經網絡可視化技術的論文中,了解到遮擋敏感性的概念。如果阻擋或遮擋圖像的一部分,將如何影響網絡的預測得分?遮擋不同的部分又將如何影響結果?
Birju 做了如下嘗試:
rgb2 = rgb; rgb2((1:71)+77,(1:71)+108,:) = 128; imshow(rgb2)
classify(net,rgb2)ans = categorical notebook
Hmm...估計網絡“認為”灰色方塊看起來像筆記本。被遮擋的區(qū)域對于圖像分類來說應該很重要。再試試不同的遮擋位置:
rgb3 = rgb;rgb3((1:71)+15,(1:71)+80,:) = 128;imshow(rgb3)
classify(net,rgb3)ans = categorical French horn
好吧,腦袋并不重要。
Birju 編寫了一些 MATLAB 代碼來系統(tǒng)地量化不同圖像區(qū)域對分類結果的相對重要性。他使用 MATLAB 構建了大量圖像,并對遮擋不同區(qū)域的圖像進行批處理。對于遮擋的不同位置,記錄預期類(本例為“法國號”)的概率得分。
我們制作一批帶有 71x71 遮擋區(qū)域的圖像。首先計算所有遮擋模塊的頂點,用 (X1,Y1) 和 (X2,Y2) 表示。
mask_size = [71 71]; [H,W,~] = size(rgb); X = 1:W; Y = 1:H; [X1, Y1] = meshgrid(X, Y); X1 = X1(:) - (mask_size(2)-1)/2; Y1 = Y1(:) - (mask_size(1)-1)/2; X2 = X1 + mask_size(2) - 1; Y2 = Y1 + mask_size(1) - 1;
注意不要讓遮擋區(qū)域的頂點偏離圖像邊界。
X1 = max(1, X1); Y1 = max(1, Y1); X2 = min(W, X2); Y2 = min(H, Y2);
批處理:
batch = repmat(rgb,[1 1 1 size(X1,1)]); for i = 1:size(X1,1) c = X1(i):X2(i); r = Y1(i):Y2(i); batch(r,c,:,i) = 128; % gray mask. end
注意:這一批包含 50,000 多張圖像。你需要大量的 RAM 才能同時創(chuàng)建和處理如此大量的圖像。
這里有一些遮擋的圖像:
現在,我們將使用 predict(而不是 classify)來獲取每個圖像在每個類別中的預測分數。MiniBatchSize 參數是用來限制 GPU 內存的使用,意味著 predict 函數將一次發(fā)送 64 個圖像到 GPU 進行處理。
s = predict(net, batch, 'MiniBatchSize',64);size(s)ans = 50176 1000
我們獲得了很多的概率得分!其中 51,529 個圖像,共有 1,000 個類別。矩陣 s 具有每個類別和每個圖像的預測分數。
我們重點關注預測原始圖像類別的預測分數:
scores = predict(net,rgb); [~,horn_idx] = max(scores);
這里是每一個圓號類別中的圖像預測分數:
s_horn = s(:,horn_idx);
將圓號類別的分數轉換為圖像顯示:
S_horn = reshape(s_horn,H,W); imshow(-S_horn,[]) colormap(gca,'parula')
最亮的區(qū)域表示遮擋對概率得分影響最大的遮擋區(qū)間。
下面我們找到了最影響圓號概率得分的遮擋位置:
[min_score,min_idx] = min(s_horn); rgb_min_score = batch(:,:,:,min_idx); imshow(rgb_min_score)
結果可見,識別圓號的關鍵在于螺旋形管身和閥鍵,而不是號嘴。
-
gpu
+關注
關注
28文章
4760瀏覽量
129127 -
圖像分類
+關注
關注
0文章
90瀏覽量
11942 -
深度學習
+關注
關注
73文章
5510瀏覽量
121329
發(fā)布評論請先 登錄
相關推薦
評論