2017-04-05 166 views
4

我正在使用Tensorflow Java Api將已創建的Tensorflow模型加載到JVM中。 我用這作爲一個例子:tensorflow/examples/LabelImage.java什麼是Python中的Tensorflow Java Api toGraphDef`相當於什麼?

這裏是我的簡單的Scala代碼:

import java.nio.file.{Files, Path, Paths} 
import org.tensorflow.{Graph, Session, Tensor} 

def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path) 
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb")) 
val g = new Graph() 
g.importGraphDef(graphDef) 
val session = new Session(g) 
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0)) 

我如何保存我的模型來獲得Session和存儲在同一個文件中的圖表。如上面「PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb」中所述。

描述here它提到:

曲線圖的序列化表示,通常被稱爲一個 GraphDef,可以通過toGraphDef生成()並且在其它 語言API當量。

其他語言API中的等價物是什麼?我沒有發現它明顯

注意:我已經看過tensorflow_serving下的mnist_saved_model.py,但通過該過程保存給我一個.pb文件和一個variables文件夾。當我試圖加載.pb文件時,我得到:java.lang.IllegalArgumentException: Invalid GraphDef

+0

我試圖使用https://www.tensorflow.org/api_docs/python/tf/GraphDef#SerializeToString,這有意義將圖加載到會話中,但運行會話時變量不在那裏。 –

回答

1

當前使用tensorflow的Java API,我只找到了如何將graph保存爲graphDef(即沒有它的變量和元數據)。這可以通過剛寫入陣列[字節]對文件進行:

Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef) 

這裏myGraphGraph class的Java對象。

我建議使用此處定義的SavedModel api從Python API保存模型。它會將你的模型保存在一個文件夾中,其中序列化圖形在.pb文件中,而變量在文件夾中。請注意您使用的tag_constants,因爲您需要在scala/java代碼中使用它來使用變量加載模型。然後,可以使用java api中的SavedModelBundle Java類輕鬆加載帶變量的圖表和會話。它會返回一個包裝用圖形和包含的變量值的會話都:

val model = SavedModelBundle.load(modelDir, modelTag) 

如果你已經嘗試過這一點,也許你可以分享你的代碼,看看它爲什麼返回無效GraphDef。

另一種方法是凍結圖形,即將變量節點變爲常量節點,以便所有內容都包含在.pb文件中。詳細信息here凍結部分

+0

'SavedModelBundle'尚未發佈,但我可以嘗試重新編譯並使用它。 [api_docs](https://www.tensorflow.org/api_docs/) 我會嘗試凍結我的模型,看看是否有效。謝謝! –

+1

你是對的@DanielHasegan,你只需要從源代碼構建java api依賴[here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java#building-from-source) – nicodri

+0

我設法服務於凍結的模型,並將我的本地庫重新編譯爲最新的1.1版本。謝謝您的幫助 –