在TensorFlow中讀數(shù)據(jù)一般有三種方法:
使用placeholder讀內(nèi)存中的數(shù)據(jù)
使用queue讀硬盤中的數(shù)據(jù)
使用Dataset讀內(nèi)存?zhèn)€硬盤中的數(shù)據(jù)
基本概率
由于第三種方法在語(yǔ)法上更簡(jiǎn)潔,因此本文主要介紹第三種方法。官方給出的Dataset API類圖:
image.png
其中終于重要的兩個(gè)基礎(chǔ)類:Dateset和Iterator。Dateset是具有相同類型的“元素”的有序表,元素可以是向量、字符串、圖片等。
從內(nèi)存中創(chuàng)建Dataset
以數(shù)字元素為例:
例1
從Dataset中實(shí)例化一個(gè)Iterator,然后對(duì)Iterator進(jìn)行迭代。
iterator = dataset.make_one_shot_iterator()
從dataset中實(shí)例化一個(gè)iterator,是“one shot iterator”,即只能從頭到尾讀取一次。
one_element = iterator.get_next()
從iterator中取出一個(gè)元素, one_element是一個(gè)tensor,因此需要調(diào)用sess.run(one_element)取出值。
如果元素被讀取完了,再sess.run(one_element)會(huì)拋出tf.errors.OutOfRangeError異常。解決方法:使用 dataset.repeat()
更復(fù)雜的輸入形式,例如,在圖像識(shí)別的應(yīng)用中,一個(gè)元素可以使{“image”:image_tensor, “l(fā)abel”:lable_tensor}
dataset = tf.data.Dataset.from_tensor_slices( { "a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]), "b": np.random.uniform(size=(5, 2)) } )
最終dataset中的一個(gè)元素為{"a": 1.0, "b": [0.9, 0.1]}的形式?;蛘?/p>
dataset = tf.data.Dataset.from_tensor_slices( (np.array([1.0, 2.0, 3.0, 4.0, 5.0]), np.random.uniform(size=(5, 2))) )
對(duì)Dataset中的元素做變換:Transformation
一個(gè)Dataset通過(guò)Transformation變成一個(gè)新的Dataset。常用的操作有:
map
batch
shuffle
repeat
下面分別來(lái)介紹以上幾個(gè)操作。(1)mapmap接收一個(gè)函數(shù),dataset中的每個(gè)元素都可以作為這個(gè)函數(shù)的輸入,并將函數(shù)的返回值作為新的dataset,例如:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0
(2)batch將多個(gè)元素組合成batch,例如:
dataset = dataset.batch(32)
(3)shuffle打亂dataset中的元素,參數(shù)buffersize表示打亂時(shí)buffer的大小。
dataset = dataset.shuffle(buffer_size=10000)
(4)repeat將整個(gè)序列重復(fù)多次,只用用來(lái)處理epoch。如果直接調(diào)用repeat()的話,生成的序列就會(huì)無(wú)限重復(fù)下去,沒(méi)有結(jié)束,因此也不會(huì)拋出。tf.errors.OutOfRangeError異常:
dataset = dataset.repeat(5)
例子:讀磁盤圖片與對(duì)應(yīng)的label
讀入磁盤中的圖片和圖片相應(yīng)的label,并將其打亂,組成batch_size=32的訓(xùn)練樣本。在訓(xùn)練時(shí)重復(fù)10個(gè)epoch。
# 函數(shù)的功能時(shí)將filename對(duì)應(yīng)的圖片文件讀進(jìn)來(lái),并縮放到統(tǒng)一的大小def _parse_function(filename, label): image_string = tf.read_file(filename) image_decoded = tf.image.decode_image(image_string) image_resized = tf.image.resize_images(image_decoded, [28, 28]) return image_resized, label# 圖片文件的列表filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])# label[i]就是圖片filenames[i]的labellabels = tf.constant([0, 37, ...])# 此時(shí)dataset中的一個(gè)元素是(filename, label)dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))# 此時(shí)dataset中的一個(gè)元素是(image_resized, label)dataset = dataset.map(_parse_function)# 此時(shí)dataset中的一個(gè)元素是(image_resized_batch, label_batch)dataset = dataset.shuffle(buffersize=1000).batch(32).repeat(10)# 此時(shí)dataset中的一個(gè)元素是(image_resized_batch, label_batch)# image_resized_batch的形狀為(32, 28, 28, 3), label_batch的形狀為(32, )
-
函數(shù)
+關(guān)注
關(guān)注
3文章
4344瀏覽量
62864 -
tensorflow
+關(guān)注
關(guān)注
13文章
329瀏覽量
60584 -
DataSet
+關(guān)注
關(guān)注
0文章
5瀏覽量
2209
原文標(biāo)題:TensorFlow讀數(shù)據(jù)
文章出處:【微信號(hào):C_Expert,微信公眾號(hào):C語(yǔ)言專家集中營(yíng)】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論