2016-05-29 167 views

回答

7

tf.train.Saver類目前只存在於Python中,(i)其是從,你可以從C++運行TensorFlow OPS內置,及(ii)其公開Saver.as_saver_def()方法,可以讓你得到一個SaverDef protocol buffer與名稱您必須運行以保存或恢復模型。

在Python中,你可以得到保存和恢復OPS的名稱如下:

saver = tf.train.Saver(...) 
saver_def = saver.as_saver_def() 

# The name of the tensor you must feed with a filename when saving/restoring. 
print saver_def.filename_tensor_name 

# The name of the target operation you must run when restoring. 
print saver_def.restore_op_name 

# The name of the target operation you must run when saving. 
print saver_def.save_tensor_name 

在C++中,從檢查點恢復,你叫Session::Run(),在檢查點文件爲saver_def.filename_tensor_name名餵養,目標操作爲saver_def.restore_op_name。要保存另一個檢查點,請撥打Session::Run(),再次輸入檢查點文件的名稱saver_def.filename_tensor_name,並獲取saver_def.save_tensor_name的值。

+2

偉大的建議!我必須從一個字符串的末尾刪除「:0」。另外,在恢復模型期間,相對路徑不起作用。 Tensorcreation:'tf :: Tensor string(tf :: DT_STRING,tf :: TensorShape({1,1}));' Feeding string:'string.matrix ()(0,0)= file_path_ + filename;' 執行:'TF_CHECK_OK(session_-> Run({{「save/Const:0」,string}},{},{「save/control_dependency」},nullptr));' – Trevir

+0

@Trevir, mrry:可以請你張貼摘錄嗎?我對tensorflow很陌生,文檔也沒有幫助..我會非常感激你! –

+0

@Surferonthefall:前評論有所有必要的代碼。使用python腳本獲取正確的操作名稱,例如「保存/常數:0」。之後,你可以通過session-> run方法在C++中使用操作名稱。 – Trevir

2

最近的TensorFlow版本包含一些輔助函數,可以在沒有Python的情況下在C++中執行相同的操作。這些是由ProtoBuf在pip-package(${HOME}/.local/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/protobuf/saver.pb.h)中生成的。

// save 
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape()); 
checkpointPathTensor.scalar<std::string>()() = "some/path"; 
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}}; 
status = sess->Run(feed_dict, {}, {graph_def.saver_def().save_tensor_name()}, nullptr); 

// restore 
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape()); 
checkpointPathTensor.scalar<std::string>()() = "some/path"; 
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}}; 
status = sess->Run(feed_dict, {}, {graph_def.saver_def().restore_op_name()}, nullptr); 

這是基於恢復模型

def restore(sess, metaGraph, fn): 
    restore_op_name = metaGraph.as_saver_def().restore_op_name # u'save/restore_all' 
    restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name) 
    filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name # u'save/Const' 
    sess.run(restore_op, {filename_tensor_name: fn}) 

對於工作和完整version see here的無證蟒蛇路(more details)。