我保存了一個訓練有素的LSTM模型,我想恢復預測以在測試中使用它。我試圖按照this post。但我收到錯誤。這裏是我的嘗試:如何在使用Saver的Tensorflow中保存和恢復訓練有素的模型?
x = tf.placeholder('tf.float32', [None, input_vec_size, 1])
y = tf.placeholder('tf.float32')
def recurrent_neural_network(x):
layer = {'weights': tf.Variable(tf.random_normal([n_hidden, n_classes])),
'biases': tf.Variable(tf.random_normal([n_classes]))}
x = tf.transpose(x, [1, 0, 2])
x = tf.reshape(x, [-1, 1])
x = tf.split(x, input_vec_size, 0)
lstm_cell = rnn.BasicLSTMCell(n_hidden, state_is_tuple=True)
outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
output = tf.add(tf.matmul(outputs[-1], layer['weights']), layer['biases'])
return output
def train_neural_network(x):
prediction = recurrent_neural_network(x)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
Training ...
saver.save(sess, os.path.join(os.getcwd(), 'my_test_model'))
之後,在訓練階段,我想
def test_neural_network(input_data):
with tf.Session() as sess:
#sess.run(tf.global_variables_initializer())
new_saver = tf.train.import_meta_graph('my_test_model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
prediction = tf.get_default_graph().get_tensor_by_name("prediction:0")
Calculate features from input_data ...
result = sess.run(tf.argmax(prediction.eval(feed_dict={x: features}), 1))
但是,這將引發以下錯誤:
KeyError: "The name 'prediction:0' refers to a Tensor which does not exist. The operation, 'prediction', does not exist in the graph."
然後我試圖加入: tf.add_to_collection('prediction', prediction)
保存之前,並在恢復後用prediction = tf.get_collection('prediction')[0]
替換。但是,這給了我以下錯誤:
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_2' with dtype float and shape [?,34,1] [[Node: Placeholder_2 = Placeholderdtype=DT_FLOAT, shape=[?,34,1], _device="/job:localhost/replica:0/task:0/cpu:0"]]
我知道的第一個錯誤,我應該以恢復,但prediction
不是tensorflow變量指定一個名稱。我經歷了幾篇以前的文章和文章,但無法提出一個可行的解決方案。所以,我的問題是:
- 我在做一些概念錯誤?如果是這樣,什麼?
- 如果沒有,是否有執行錯誤?我該如何解決它?
謝謝。