2017-08-31 217 views
0

這是一個如何保存和恢復訓練模型的例子。 希望這會對初學者有所幫助。tensorflow:如何保存/恢復訓練有素的模型

生成1個帶relu激活函數的隱層神經網絡。 (聽說relu已被證明比sigmoid好多了,特別是對於隱層數量衆多的神經網絡。)

訓練數據顯然是異或。

火車和保存 「tf_train_save.py」

import tensorflow as tf 
import numpy as np 

x = np.matrix([[0, 0], [0, 1], [1, 0], [1, 1]]) 
y = np.matrix([[0], [1], [1], [0]]) 

n_batch = x.shape[0] 
n_input = x.shape[1] 
n_hidden = 5 
n_classes = y.shape[1] 

X = tf.placeholder(tf.float32, [None, n_input], name="X") 
Y = tf.placeholder(tf.float32, [None, n_classes], name="Y") 

w_h = tf.Variable(tf.random_normal([n_input, n_hidden], stddev=0.01), tf.float32, name="w_h") 
w_o = tf.Variable(tf.random_normal([n_hidden, n_classes], stddev=0.01), tf.float32, name="w_o") 

l_h = tf.nn.relu(tf.matmul(X, w_h)) 
hypo = tf.nn.relu(tf.matmul(l_h, w_o), name="output") 

cost = tf.reduce_mean(tf.square(Y-hypo)) 
train = tf.train.GradientDescentOptimizer(0.1).minimize(cost) 

init = tf.global_variables_initializer() 

with tf.Session() as sess: 
    sess.run(init) 

    for epoch in range(1000): 
     for i in range(4): 
      sess.run(train, feed_dict = {X:x[i,:], Y:y[i,:]}) 

    result = sess.run([hypo, tf.floor(hypo+0.5)], feed_dict={X:x}) 

    print(*result[0]) 
    print(*result[1]) 

    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"]) 
    tf.train.write_graph(output_graph_def, "./logs/mp_logs", "test.pb", False) 

負荷 「tf_load.py」

import tensorflow as tf 
from tensorflow.python.platform import gfile 
import numpy as np 

x = np.matrix([[0, 0], [0, 1], [1, 0], [1, 1]]) 
y = np.matrix([[0], [1], [1], [0]]) 

with gfile.FastGFile("./logs/mp_logs/test.pb",'rb') as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 
    tf.import_graph_def(graph_def, name='') 

with tf.Session() as sess: 
    X = sess.graph.get_tensor_by_name("X:0") 
    print(X) 
    output = sess.graph.get_tensor_by_name("output:0") 
    print(output) 

    tf.global_variables_initializer().run() 

    result = sess.run([output, tf.floor(output+0.5)], feed_dict={X:x}) 

    print(*result[0]) 
    print(*result[1]) 

會有更簡單的方法?

+0

您的問題標題似乎不符合您的要求。假設題目問題,你的代碼是否符合你的期望?我想知道加載腳本中的初始化。 –

+0

你的力量保存你的權重變量,因爲你加載它們,所以你的代碼是不正確的。看看這個https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model?rq=1 –

+0

@EricPlaton它的工作原理。我只是好奇,如果有更簡單的方法。像...保存張量名稱一樣。 –

回答

0

您使用的是convert_variables_to_constants,所以您在訓練方面確實很棒。對於路人來說,該API出現在v1.0中(如果我在跟蹤API後沒有弄錯)。

在負載方面,我認爲最小代碼是一個命令更短。鑑於您已將所有變量轉換爲常量,因此在恢復時沒有變量可以初始化。所以行:

tf.global_variables_initializer().run() 

什麼都不做。從v1.3的docs開始:

但是,如果var_list爲空,該函數仍會返回可以運行的Op。該操作只是沒有效果。

加載腳本沒有全局變量,並且因爲tf.global_variables_initializer()等於tf.variables_initializer(tf.global_variables()),所以該操作是空操作。

+1

我期待恢復時不處理張量名稱,如'輸入'和'輸出'。找不到例子。 我認爲讀取VGGish源代碼是可能的。但我誤解了它。他們只是做了一個定義圖的函數,並在生成和恢復函數中使用它們。 猜猜我必須做同樣的事情,一起處理圖形文件和py文件 –

相關問題