本期 AI Adventure 中,Yufeng 會(huì)帶領(lǐng)我們會(huì)按照之前分享的最佳實(shí)踐來(lái)試著完整走一遍機(jī)器學(xué)習(xí)的整個(gè)流程。工作量有點(diǎn)大,但是聰明的你應(yīng)該沒(méi)問(wèn)題。
使用 MNIST 數(shù)據(jù)* 來(lái)訓(xùn)練模型常常被看作是機(jī)器學(xué)習(xí)界的「Hello World」例子(使用標(biāo)準(zhǔn)的 MNIST 數(shù)據(jù)訓(xùn)練手寫字符的識(shí)別模型),今天我們跟著 Yufeng 一起,使用更“時(shí)尚”的數(shù)據(jù)開(kāi)啟機(jī)器學(xué)習(xí)的 Hello World 之門。
* 段注:MNIST 是一個(gè)手寫數(shù)字圖像的數(shù)據(jù)集,每幅圖像都由一個(gè)整數(shù)標(biāo)記。它主要用于機(jī)器學(xué)習(xí)算法的性能對(duì)標(biāo)。
“潮”起來(lái)的 Machine Learning
Zalando(來(lái)自德國(guó)的電子商務(wù)公司)決意要讓 MNIST 再“火”一把,前段時(shí)間 Zalando 旗下的研究部門發(fā)布了叫做 Fashion-MNIST 的一個(gè)數(shù)據(jù)集。這是一個(gè)和 MNIST 具有相同格式的數(shù)據(jù)集,唯一的不同在于手寫字符被替換成了服飾、鞋子、挎包等等內(nèi)容。它仍然有 10 個(gè)種類,圖像也仍然是 28x28 像素。
在 GitHub 查看更多對(duì)Fashion-MNIST數(shù)據(jù)集的介紹(中文):
https://github.com/zalandoresearch/fashion-mnist/blob/master/README.zh-CN.md
我們一起訓(xùn)練一個(gè)模型,然后用它來(lái)甄別所屬的服飾品類吧!
線性 Classifier
我們先從構(gòu)建一個(gè)線性的 classifier 開(kāi)始,來(lái)看看怎么操作。同以往一樣,我們用 TensorFlow 的評(píng)估器框架(鏈接參見(jiàn)段后) 來(lái)簡(jiǎn)化編程和維護(hù)?;貞浺幌?,我們會(huì)經(jīng)歷加載數(shù)據(jù)、創(chuàng)建 classifier,然后運(yùn)行訓(xùn)練和評(píng)估等操作。另外還會(huì)用本地模型直接做一些預(yù)測(cè),官方文檔參考:
https://tensorflow.google.cn/get_started/get_started_for_beginners?hl=zh-CN
下面從創(chuàng)建模型開(kāi)始,我們首先把數(shù)據(jù)集中的圖像從 28x28 的像素排布轉(zhuǎn)為 1x784 的形式,然后將之稱為特征列 pixels。此操作類似于 AIA 第三期:無(wú)需數(shù)學(xué)知識(shí),輕松搞定鳶尾花辨識(shí)模型中出現(xiàn)的 flower_features。
feature_columns = [ tf.feature_column.numeric_column( "pixels", shape=784)]classifier = tf.estimator.LinearClassifier( feature_columns=feature_columns, n_classes=10, model_dir=logdir)
下一步創(chuàng)建線性的 classifier。我們有 10 種品類需要做標(biāo)記,而不是之前鳶尾花案例中的三種。
要開(kāi)始訓(xùn)練,我們需要配置數(shù)據(jù)集和輸入函數(shù)。TensorFlow 有內(nèi)置的函數(shù)接受一個(gè) NumPy 型的數(shù)組用于生成輸入函數(shù),此處我們就用它來(lái)簡(jiǎn)化一下。
tf.estimator.inputs.numpy_input_fn( x={'pixels': X}, y=Y, batch_size=batch_size, num_epochs=epochs, shuffle=shuffle)DATA_SETS = input_data.read_data_sets( "/tmp/fashion-mnist")
接著用 input_data 模塊把數(shù)據(jù)集載入,將函數(shù)參數(shù)指向數(shù)據(jù)集下載的位置。
然后通過(guò)調(diào)用 classifier.train() 把 classifier、輸入函數(shù)和數(shù)據(jù)集都結(jié)合起來(lái)。
classifier.train( input_fn=train_input_fn, steps=num_steps)accuracy_score = classifier.evaluate( input_fn=eval_input_fn)['accuracy']
最終,我們進(jìn)行一次評(píng)估來(lái)看看模型表現(xiàn)如何。使用經(jīng)典 MNIST 數(shù)據(jù)集時(shí),此模型常常得到 91% 左右的準(zhǔn)確度。然后,由于時(shí)尚版 MNIST 有更復(fù)雜的數(shù)據(jù)集,所以只得到了略高于 80% 的精確度,甚至有時(shí)更低一些。
怎樣才能改善呢?如 AIA第六期:通過(guò)深度神經(jīng)網(wǎng)絡(luò)再識(shí) Estimator 中提到的那樣進(jìn)行就好了。
轉(zhuǎn)為深度模型
切換到 DNNClassifier 就是換一行代碼的功夫,現(xiàn)在重新開(kāi)始訓(xùn)練,然后評(píng)估看看是否深度模型會(huì)比線性的好一些。
classifier = tf.estimator.DNNClassifier( feature_columns=feature_columns, n_classes=10, hidden_units=[100, 75, 50], model_dir=logdir )
正如第五期:通過(guò) TensorBoard 將模型可視化 中討論的那樣,我們應(yīng)當(dāng)用 TensorBoard 來(lái)橫向并且比較一下兩個(gè)模型。
tensorboard --logdir=models/fashion_mnist/
瀏覽器打開(kāi) http://localhost:6006
TensorBoard
看看 Tensorboard,似乎深度模型并沒(méi)有比線性模型好到哪里去!這很可能是對(duì)超參數(shù)的微調(diào)不到位導(dǎo)致的,參見(jiàn) AIA 第二期:機(jī)器學(xué)習(xí)常見(jiàn)的七個(gè)步驟。
看起來(lái)好像是要一路飆到底…
也許是我們的模型需要更大一些來(lái)容納如此搞復(fù)雜度的模型?抑或訓(xùn)練應(yīng)該更少一些?我們來(lái)試試看。經(jīng)過(guò)屢次調(diào)試微參數(shù),模型的失真度突破性降低了,并且比線性模型得到的精度更高。
深度模型(藍(lán)色對(duì)比線性的紅色線)的失真度保持較低狀態(tài)
達(dá)到這一精度之前在訓(xùn)練中多了些步驟,但是最終得到更高精度又使得這些付出非常值得。
由圖可見(jiàn)線性模型的平緩期來(lái)得比深度網(wǎng)絡(luò)要早。這是由于深度模型復(fù)雜度更高,它們需要的訓(xùn)練時(shí)間更長(zhǎng)。
此時(shí),模型差不多滿足我們的要求了。我們可以將其導(dǎo)出,然后產(chǎn)生一個(gè)可伸縮的時(shí)尚版 MNIST classifier API。至于如何導(dǎo)出,可以參照第四期中給出的詳細(xì)步驟。
預(yù)測(cè)
我們快速回顧一下用評(píng)估器做預(yù)測(cè)的方法。很大程度上,它就像是我們訓(xùn)練和評(píng)估的方式;這也是評(píng)估器(框架)的極大優(yōu)勢(shì)——通用一致的函數(shù)接口。
X = DATA_SETS.test.images[5000:5005]predict_input_fn = tf.estimator.inputs.numpy_input_fn( x={'pixels': X}, batch_size=1, num_epochs=1, shuffle=False)predictions = classifier.predict( input_fn=predict_input_fn)
注意我們這次把 batch_size 指定為 1,num_epochs 指定為 1,shuffle 值為 false。這是因?yàn)槲覀兿胍粗樞蛞粋€(gè)一個(gè)的預(yù)測(cè),一次在所有數(shù)據(jù)上進(jìn)行預(yù)測(cè)。我從評(píng)估所用數(shù)據(jù)集中間挑選了 5 幅圖像用于預(yù)測(cè)。
我選擇這 5 幅的原因不僅僅是因?yàn)樗鼈冊(cè)谡虚g,還因?yàn)檫@些模型中有兩個(gè)是不正確的。兩個(gè)都應(yīng)該是襯衫,但卻被模型認(rèn)為第三個(gè)是包而第五個(gè)是大衣。由此,僅僅考慮圖像的紋理變化這個(gè)因素,你能看到這些樣本比起手寫數(shù)字來(lái)說(shuō)是多么有挑戰(zhàn)性。
后續(xù)步驟
你可以在這個(gè) Gist(鏈接在段后)上看到本次分享中所用來(lái)訓(xùn)練和生成圖像的代碼。你的模型表現(xiàn)如何?你所最終采用的參數(shù)又是什么樣的?在評(píng)論當(dāng)中分享一下吧!
https://gist.github.com/yufengg/2b2fd4b81b72f0f9c7b710fa87077145
精彩提要
后續(xù)的幾期將會(huì)著眼于機(jī)器學(xué)習(xí)生態(tài)的工具,從而幫助你創(chuàng)建自己的操作流程和工具鏈。與此同時(shí)也會(huì)展示更多可以用來(lái)解決機(jī)器學(xué)習(xí)問(wèn)題的模型體系結(jié)構(gòu)。我非常期待能在后面的分享中繼續(xù)為你分析解答!在那之前,不要忘了多使用機(jī)器學(xué)習(xí)!
-
線性
+關(guān)注
關(guān)注
0文章
199瀏覽量
25175 -
機(jī)器學(xué)習(xí)
+關(guān)注
關(guān)注
66文章
8428瀏覽量
132850 -
數(shù)據(jù)集
+關(guān)注
關(guān)注
4文章
1208瀏覽量
24754
原文標(biāo)題:AIA 系列實(shí)戰(zhàn)篇 | 機(jī)器學(xué)習(xí)的「時(shí)尚版」Hello World
文章出處:【微信號(hào):tensorflowers,微信公眾號(hào):Tensorflowers】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論