2017-02-16 73 views
0

我正試圖在使用python的GPU上訓練圖形,以從C++進程加載圖形。C++等價於train.import_meta_graph clear_devices參數?

status = ReadBinaryProto(Env::Default(), "model.pb", &graph_def); 
session->Create(graph_def); 

然後,我得到錯誤信息

「,因爲沒有設備匹配規範在此過程中註冊,不能將某個設備指定節點...;可用設備:/職務:本地主機/副本:0 /任務:0/cpu:0「

對於python train.import_meta_graph API具有clear_devices參數,但它在C++ API上的等效參數是什麼?

對於加載圖形,我使用Tensorflor在使用CMake構建的Windows上使用-Dtensorflow_ENABLE_GPU = ON,因此我的vcxproj具有GOOGLE_CUDA定義。

我讀過Tensorflow, restore variables in a specific device但它只適用於python API。

回答

0

鑑於您是從Python導出圖表,也許您可​​以導出一個清除了設備的圖表?喜歡的東西:

meta_graph = tf.train.export_meta_graph() 
with tf.Graph().as_default(): 
    tf.train.import_meta_graph(meta_graph, clear_devices=True) 
    # Export the GraphDef now 
    with open('/tmp/model.pb', 'w') as f: 
     f.write(tf.get_default_graph().as_graph_def().SerializeToString()) 

或者,你可以通過在圖中清理出的每一個節點的device領域的重複使用C++ clear_devices=True的行爲。喜歡的東西:

status = ReadBinaryProto(Env::Default(), "model.pb", &graph_def); 
for (int n = 0; n < graph_def.node_size(); ++n) { 
    graph_def.mutable_node(n)->clear_device(); 
} 
session->Create(graph_def); 

但我建議不要認爲既然是依託GraphDef s的如何的框架,這可能是脆弱的消耗內部細節。

+0

因爲我使用train.Saver(),而不是python的export_meta_graph(),我清除了C++中的設備,並且Session :: Create()成功使用該圖!謝謝! – Jay