TensorFlow模型詳解與應(yīng)用
推薦 + 挑錯(cuò) + 收藏(0) + 用戶評(píng)論(0)
DNNLinearCombinedClassifier 類繼承于類 Estimator,Estimator 類繼承于類 BaseEstimator。BaseEstimator 是一個(gè)抽象類,定義了通用的模型訓(xùn)練以及評(píng)測(cè)的函數(shù)接口 (train_model, evaluate_model, infer_model),Estimator 類中用一個(gè)統(tǒng)一函數(shù) call_model_fn 來(lái)實(shí)現(xiàn) train_model, evaluate_model, infer_model。
圖 7 estimator 的類關(guān)系圖
為了更好了解整個(gè)過(guò)程,我們看看內(nèi)部函數(shù)的調(diào)用過(guò)程(代碼可以參見(jiàn) estimator/estimator.py):
圖 8 Estmiator 類的函數(shù)調(diào)用圖
模型訓(xùn)練通過(guò)調(diào)用 BaseEstimator 的 fit() 接口開(kāi)始,其調(diào)用棧是:fit -》 _train_model -》 _get_train_ops -》_call_model_fn(ModelKeys.TRAIN) -》 _model_fn,最終_model_fn() 產(chǎn)生模型并通過(guò) export 函數(shù)將模型輸出到 model_dir 對(duì)應(yīng)目錄中。
我們把訓(xùn)練模型的調(diào)用過(guò)程在代碼級(jí)別展開(kāi),標(biāo)出關(guān)鍵的幾個(gè)函數(shù)和數(shù)據(jù)結(jié)構(gòu),省略不關(guān)鍵的代碼,希望能讓讀者看到訓(xùn)練模型的大致過(guò)程:
圖 9 模型訓(xùn)練的調(diào)用棧
評(píng)測(cè)(evaluate)和預(yù)測(cè)(predict)的過(guò)程與訓(xùn)練(train)大致相同,讀者可以通過(guò)源代碼文件找到對(duì)應(yīng)函數(shù)了解??梢钥闯?,整個(gè)函數(shù)調(diào)用棧中最關(guān)鍵的 2 個(gè)函數(shù)是: input_fn 和 model_fn。input_fn 從輸入數(shù)據(jù)中生成 features 和 labels,features 是一個(gè) Tensor 或者是一個(gè)從特征名到 Tensor 的字典,如果 features 是一個(gè) Tensor,程序會(huì)給這個(gè) Tensor 一個(gè)空字符串的鍵值,轉(zhuǎn)換成特征名到 Tensor 的字典。labels 是樣本的 label 構(gòu)成的 tensor。input_fn 由應(yīng)用程序調(diào)用者提供實(shí)現(xiàn),返回(features, labels)二元組,要求 tf.get_shape(features)[0] == tf.get_shape(labels)[0],也就是兩個(gè) tensor 的行數(shù)目得保持一致。model_fn 定義訓(xùn)練和評(píng)測(cè)模型的具體邏輯,如模型訓(xùn)練產(chǎn)生的誤差 (model_fn_ops.loss) 以及訓(xùn)練算子(model_fn_ops.train_op)通過(guò)封裝在 EstmiatorSpec 的對(duì)象中由 training 的 Session 進(jìn)行調(diào)用。每個(gè)具體模型需要實(shí)現(xiàn)的是自定義的 model_fn。
DNNLinearCombinedClassifier 是如何實(shí)現(xiàn)自己的 model_fn 的呢?本文開(kāi)頭我們給出了它的初始化函數(shù)原型,進(jìn)入初始化函數(shù)的實(shí)現(xiàn)中我們定位到代碼行 model_fn=_dnn_linear_combined_model_fn。
這個(gè)就是 DNNLinearCombinedClassifier 的 model_fn。這個(gè)函數(shù)的定義如下:
def_dnn_linear_combined_model_fn(features, labels, mode, params, config= None)
features 和 labels 大家都已經(jīng)知道,mode 指定 model_fn 的操作模式,目前支持 3 個(gè)值:訓(xùn)練模型 (model_fn.ModeKeys.TRAIN),對(duì)模型進(jìn)行評(píng)測(cè) (model_fn.ModeKeys.EVAL),根據(jù)輸入特征進(jìn)行預(yù)測(cè) (model_fn.ModeKeys.PREDICT),mode 的定義可參見(jiàn)文件 estimator/model_fn.py。params 和 config 參數(shù)分別定義模型訓(xùn)練的參數(shù)以及模型運(yùn)行的配置。
非常好我支持^.^
(2) 40%
不好我反對(duì)
(3) 60%