2017-03-15 61 views
1

我正嘗試使用新的Java API從磁盤讀取模型。如何在Tensorflow的Java API中使用`saver.save`加載模型保存

The one example要使用Tensorflow的Java API顯示如何讀取具有圖形定義和參數權重的.pb模型文件。

在Python方面,Tensorflow建議使用Saver對象將模型保存到磁盤。它會創建一個.meta文件,該文件具有該定義並且具有.data文件的權重。在Python中,我使用new_saver=tf.train.import_meta_graph(var_filename) new_saver.restore(sess, model_filename)從磁盤讀取模型。

如何在Java API中執行此操作?

回答

0

SavedModelBundle類可能是你正在尋找的。特別是,SavedModelBundle.load()將返回一個Session,您可以使用它來執行保存的模型。

請注意,此功能最近添加到Java API中,因此它不存在於二進制發行版中,因此您必須在發佈TensorFlow 1.1之前build the Java API from source

+0

很好,謝謝。我目前的解決方案是使用'freeze_graph'保存圖形def和權重,然後用Java讀取。這個班看起來很有前途我會等到官方發佈的代碼嘗試它, –

0

我正在做類似的事情,使用python界面在hadoop集羣上訓練模型,並使用模型和學習參數在java中進行預測。

用法是在Java端很簡單:

SavedModelBundle load = SavedModelBundle.load(modelDir, "serve"); 
     float[][] resultArray; 
     try (Graph g = load.graph()) { 
      try (Session s = load.session(); 
       Tensor result = s.runner().feed("data", data).fetch("prediction").run().get(0)) { 
       resultArray = result.copyTo(new float[10][1]); 
      } 
     } 
     load.close(); 
     return resultArray; 

要獲取的名稱Feed和獲取的操作可以打印簽名,並使用輸入和輸出值名稱。

print(prediction_signature) 

https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_saved_model.py#L119

相關問題