編者按:本文作者Sara Robinson在Medium上發(fā)布了一個有趣的項目,她自制了一款APP,能自動識別歌手Taylor Swift。這與我們之前介紹的尋找威利項目很像。該教程非常詳細,有興趣的同學可以學習一下,動手做一個自己的圖像識別工具哦~本文已獲作者授權,以下是對原文的編譯。
注:由于寫作本文時TensorFlow沒有Swift庫,我用Swift構建了針對我的模型的預測請求的APP。
以下就是我們創(chuàng)建的APP:
TensorFlow物體檢測API能讓你識別出一張圖片中特定物體的位置,這可以應用到許多有趣的程序上。不過我平常拍人比較多,所以就想把這一技術應用到人臉識別上。結果發(fā)現模型表現得非常好!也就是上圖我創(chuàng)建的Taylor Swift檢測器。
本文將列出模型的構建步驟,從收集Taylor Swift的照片到模型的訓練:
對圖像進行預處理,改變大小、貼標簽、將它們分成訓練和測試兩部分,并修改成Pascal VOC格式;
將圖片轉化成TFRecords文件以符合物體檢測API;
利用MobileNet在谷歌Cloud ML Engine上訓練模型;
導出訓練好的模型并將其部署到ML Engine上進行服務;
構建一個iOS前端,根據訓練好的模型做出預測請求(使用Swift)。
下面是各部分如何結合在一起的架構圖:
在開始之前,首先要解釋一下我們即將用到的技術和術語:TensorFlow物體檢測API是一個構建在TensorFlow上的框架,用于識別圖像中特定的對象。例如,你可以用很多貓的照片訓練它,一旦訓練完畢,你可以輸入一張貓的圖像,它就會輸出一個方框列表,認為圖像中有一只貓。雖然它的名字中含有API,但是你可以將它更多地想象成用于遷移學習的一套便利的工具。
但是,訓練模型識別圖像中的對象是個費時費力的活。物體檢測最酷的地方就是它支持五個預訓練模型的遷移學習(transfer learning),那么什么是遷移學習呢?比如,當兒童學習第一門語言時,他們會接觸大量的例子,如果有錯就會立刻被糾正過來。例如當孩子們學習識別貓時,他們的父母會指著圖片上的貓,并說出“貓”這個詞,這種重復增強了他們的腦回路。當它們學習如何識別一只狗時,無需從頭開始,這一過程與貓的識別類似,只是學習對象不同。這就是遷移學習的工作原理。
但我沒有時間尋找并標記數千個Taylor Swift的圖像,但是我可以通過修改最后幾個圖層、在數百萬張圖像上訓練的模型中提取特征,應用于TSwift的檢測。
第一步:預處理圖像
首先要感謝Dat Tran寫的關于浣熊檢測器的博客,地址:https://towardsdatascience.com/how-to-train-your-own-object-detector-with-tensorflows-object-detector-api-bec72ecfe1d9
首先,我從谷歌圖片上下載了200張Taylor Swift的照片,這里安利一個Chrome插件:Fatkun Batch Download Image,可以下載所有圖片搜索結果。在打標簽之前,我把圖片分為兩類:訓練和測試。另外,我寫了一個調整圖片大小的腳本(https://github.com/sararob/tswift-detection/blob/master/resize.py),確保每張圖的寬度不超過600px。
由于檢測器會告訴我們圖中的對象位置,所以你不能直接把圖像和標簽作為訓練數據。你需要用邊框將對象圈出來,以及將表框打上標簽(在我們的數據集中,只需要一個標簽tswift)。
打邊框工具依然使用LabelImg,這是一個基于Python的程序,你只需輸入帶標簽的圖像,它就會輸出一個xml文件,將每張照片都打上邊框同時還有相關標簽(不到一上午我就處理好200張圖片了)。下面是它如何工作的(標簽輸入為tswift):
然后LabelImg生成一個xml文件:
現在我有了一張帶有邊框和標簽的圖片了,但是我還要把它轉換成TensorFlow可接受的方式——一個數據的二進制表示TFRecord。關于這一方法可以在GitHub上查看。要運行我的腳本,你需要先下載一個tensorflow/models,從tensorflow/models/research本地直接運行腳本,帶上以下參數(運行兩次:一次用于訓練數據,一次用于測試數據)
python convert_labels_to_tfrecords.py
--output_path=train.record
--images_dir=path/to/your/training/images/
--labels_dir=path/to/training/label/xml/
第二步:訓練檢測器
我可以在筆記本電腦上訓練這個模型,但是時間會很長,而且占用大量的資源。并且一旦我需要用電腦做別的事,訓練就會中斷。所以,我選擇了云!我們可以利用云來運行多個跨核心的訓練,幾個小時內就能完成整個工作,并且用Cloud ML engine的速度比GPU還要快。
設置Cloud ML Engine
我準備將所有TFRecord格式的數據上傳到云并開始訓練。首先,我在谷歌云端控制臺中創(chuàng)建了一個項目,并啟用了Cloud ML Engine:
然后,我將創(chuàng)建一個云存儲bucket來打包模型的所有資源。確保在指定區(qū)域進行存儲(不要選擇多個區(qū)域):
我將在這個bucket中/data子目錄來放置訓練和測試TFRecord的文件:
目標對象檢測API還需要一個將標簽映射到整數ID的pbtxt文件。由于我們只有一個標簽,這個是非常短的:
item {
id: 1
name: 'tswift'
}
添加MobileNet檢查點進行遷移學習
因為我并非從零開始訓練這個模型,所以當我運行訓練時,我需要指向我將要建立的預訓練模型。我選擇使用MobileNet模型——它是針對移動設備優(yōu)化的一系列小模型。雖然我不會直接在移動設備上訓練模型,但MobileNet將會快速訓練,并允許更快的預測請求。我下載了這個MobileNet檢查點用于訓練,檢查點是一個二進制文件,包含訓練過程中特定點的TensorFlow模型的狀態(tài)。下載并解壓縮后,你可以看到它包含的三個文件:
以上所有都要用來訓練模型,所以我將它們放在云存儲bucket中的同一個data/目錄中。
在開始訓練之前,還需要添加一個文件。對象檢測腳本需要一種方法查找模型的檢查點、標簽映射和訓練數據。我們將用配置文件處理這一點。TF對象檢測為五個預訓練模型采集了樣本配置文件。我們在這里為MobileNet使用一個,并且在云存儲bucket的相應路徑中更新了所有PATH_TO_BE_CONFIGURED占位符。除了將我的模型連接到云存儲中的數據外,此文件還為我的模型配置了幾個超參數,如卷積大小、激活函數和步驟。
以下是開始訓練之前云存儲bucket中我的/data中的所有文件:
我還會在bucket中創(chuàng)建train/和eval/子目錄——這是TensorFlow在訓練和評估時書寫模型檢查點文件的地方。
現在已經準備好訓練了,通過執(zhí)行gcloud命令開始。請注意,你需要在本地復制tensorflow/models/research并從該目錄運行此訓練腳本:
# Run this script from tensorflow/models/research:
gcloud ml-engine jobs submit training ${YOUR_TRAINING_JOB_NAME}
--job-dir=${YOUR_GCS_BUCKET}/train
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz
--module-name object_detection.train
--region us-central1
--config object_detection/samples/cloud/cloud.yml
--runtime-version=1.4
--
--train_dir=${YOUR_GCS_BUCKET}/train
--pipeline_config_path=${YOUR_GCS_BUCKET}/data/ssd_mobilenet_v1_coco.config
訓練的同時,我也開始了評估工作。我會使用之前從未見過的數據來評估模型的準確性:
# Run this script from tensorflow/models/research:
gcloud ml-engine jobs submit training ${YOUR_EVAL_JOB_NAME}
--job-dir=${YOUR_GCS_BUCKET}/train
--packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz
--module-name object_detection.eval
--region us-central1
--scale-tier BASIC_GPU
--runtime-version=1.4
--
--checkpoint_dir=${YOUR_GCS_BUCKET}/train
--eval_dir=${YOUR_GCS_BUCKET}/eval
--pipeline_config_path=${YOUR_GCS_BUCKET}/data/ssd_mobilenet_v1_coco.config
你可以通過在云端控制臺導航到ML Engine的“作業(yè)”部分來驗證您的任務是否正確運行,并檢查日志以查找特定作業(yè):
第三步:部署預測模型
為了將模型部署到ML Engine,我需要將模型檢查點轉換為ProtoBuf。在我的train/bucket中,可以看到從幾處保留的檢查點文件:
文件的第一行告訴我最新的檢查點路徑——我應從該檢查點本地下載3個文件。每個檢查點應該有一個.index,.meta,和.data文件。將它們保存在本地目錄中后,我可以使用對象檢測的export_inference_graph腳本將它們轉換為ProtoBuf。要運行以下腳本,你需要定義MobileNet配置文件的本地路徑、訓練時下載的模型檢查點編號以及要導出的圖形目錄名稱:
# Run this script from tensorflow/models/research:
python object_detection/export_inference_graph.py
--input_type encoded_image_string_tensor
--pipeline_config_path ${LOCAL_PATH_TO_MOBILENET_CONFIG}
--trained_checkpoint_prefix model.ckpt-${CHECKPOINT_NUMBER}
--output_directory ${PATH_TO_YOUR_OUTPUT}.pb
這個腳本運行后,你將會在.pb輸出目錄中看到一個saved_model/目錄。將saved_model.pb文件上傳到你的云存儲/data目錄中(不要擔心生成其他文件)。
現在你已經準備好將模型部署到ML Engine上了。首先,用gcloud創(chuàng)建你的模型:
gcloud ml-engine models create tswift_detector
然后,通過將模型指向剛剛上傳到云存儲的已保存的ProtoBuf來創(chuàng)建第一個模型版本:
gcloud ml-engine versions create v1 --model=tswift_detector --origin=gs://${YOUR_GCS_BUCKET}/data --runtime-version=1.4
模型部署好后,我將用ML Engine的線上預測API生成新的預測圖像。
第四步:使用Firebase函數和Swift構建預測客戶端
我在Swift中編寫了一個iOS客戶端來對我的模型進行預測請求。Swift客戶端將圖像上傳到云存儲,云存儲觸發(fā)Firebase函數,在Node.js中發(fā)起預測請求,并將生成的預測圖像和數據保存到云存儲和Firebase中。
首先,在我的Swift客戶端中,我添加了一個按鈕,供用戶訪問設備的圖片庫。用戶選擇照片后,會觸發(fā)將圖像上傳到云端存儲的操作:
let firestore = Firestore.firestore()
func imagePickerController(_ picker: UIImagePickerController, didFinishPickingMediaWithInfo info: [String : Any]) {
let imageURL = info[UIImagePickerControllerImageURL] as? URL
let imageName = imageURL?.lastPathComponent
let storageRef = storage.reference().child("images").child(imageName!)
storageRef.putFile(from: imageURL!, metadata: nil) { metadata, error in
if let error = error {
print(error)
} else {
print("Photo uploaded successfully!")
// TODO: create a listener for the image's prediction data in Firestore
}
}
}
dismiss(animated: true, completion: nil)
}
接下來,我編寫了在上傳到云存儲時觸發(fā)的Firebase函數(https://github.com/sararob/tswift-detection/blob/master/firebase/functions/index.js)。下面的代碼也包含了我向ML Engine預測API發(fā)出請求的函數部分:
function cmlePredict(b64img, callback) {
return new Promise((resolve, reject) => {
google.auth.getApplicationDefault(function (err, authClient, projectId) {
if (err) {
reject(err);
}
if (authClient.createScopedRequired && authClient.createScopedRequired()) {
authClient = authClient.createScoped([
'https://www.googleapis.com/auth/cloud-platform'
]);
}
var ml = google.ml({
version: 'v1'
});
const params = {
auth: authClient,
name: 'projects/sara-cloud-ml/models/tswift_detector',
resource: {
instances: [
{
"inputs": {
"b64": b64img
}
}
]
}
};
ml.projects.predict(params, (err, result) => {
if (err) {
reject(err);
} else {
resolve(result);
}
});
});
});
}
在ML Engine的反應中,我們得到:
detection_boxes:可以用來標出Taylor Swift周圍的邊框;
detection_scores:為每個檢測框架返回一個置信度值,其中只包括分數高于70%的檢測;
detection_classes:告訴我們與檢測相關的ID。在這種情況下,因為只有一個標簽所以該值總為1。
在函數中,如果檢測到Taylor,則用detection_boxes在圖像中繪制一個邊框以及生成置信度分數。然后將新的帶有邊框的圖像保存到云中,將圖像的文件路徑寫入Cloud Firestore,一邊在iOS應用程序中讀取路徑并下載新圖像:
const admin = require('firebase-admin');
admin.initializeApp(functions.config().firebase);
const db = admin.firestore();
let outlinedImgPath = `outlined_img/${filePath.slice(7)}`;
let imageRef = db.collection('predicted_images').doc(filePath);
imageRef.set({
image_path: outlinedImgPath,
confidence: confidence
});
bucket.upload('/tmp/path/to/new/image', {destination: outlinedImgPath});
最后,在iOS應用程序中,我們可以監(jiān)測圖像Firestore路徑的更新。如果檢測到目標,我會下載這張圖片并在應用程序中顯示這張圖以及可信度分數。這個函數將替換上一個代碼片段中的注釋:
self.firestore.collection("predicted_images").document(imageName!)
.addSnapshotListener { documentSnapshot, error in
if let error = error {
print("error occurred(error)")
} else {
if (documentSnapshot?.exists)! {
let imageData = (documentSnapshot?.data())
self.visualizePrediction(imgData: imageData)
} else {
print("waiting for prediction data...")
}
}
}
好了!現在我們有一款Taylor Swift檢測器了!注意,由于模型只用了140張圖像進行訓練,所以準確度不夠高,可能會把其他人誤認為是Taylor。但是,如果有時間的話,我會收集更多貼有標簽的圖片,并更新模型,發(fā)布到應用商店里。
-
圖像識別
+關注
關注
9文章
520瀏覽量
38304 -
SWIFT
+關注
關注
0文章
116瀏覽量
23813 -
tensorflow
+關注
關注
13文章
329瀏覽量
60558
原文標題:教程帖:用TensorFlow自制Taylor Swift識別器
文章出處:【微信號:jqr_AI,微信公眾號:論智】歡迎添加關注!文章轉載請注明出處。
發(fā)布評論請先 登錄
相關推薦
評論