2017-03-03 103 views
0

我的代碼是如下:我不能保存,並與tensorflow佔位恢復一個變量,由於

import tensorflow as tf 
import numpy as np 
def add_layer(input): 
    v2 = tf.Variable(tf.random_normal([2, 2], dtype=tf.float32, name='v2')) 
    tf.add_to_collection('h0_v2',v2) 
    output=tf.matmul(input,v2) 
    return output 
x1=tf.placeholder(tf.float32) 
outputs=add_layer(x) 
tf.add_to_collection('outputs', outputs) 
saver = tf.train.Saver() 
with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    x1=np.random.random([2, 2]) 
    print(sess.run(outputs,feed_dict={x:x1})) 
    save_path = saver.save(sess, './model.ckpt') 
    print("model saved in file:", save_path) 

,然後另一碼是拼命地跑:

import tensorflow as tf 
import numpy as np 
sess = tf.Session() 
saver = tf.train.import_meta_graph('./model.ckpt.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
x2=np.random.random([2, 2]) 
print(sess.run(tf.get_collection('outputs',feed_dict={x:x2}))) 
print('model is loaded') 
sess.close() 

然後電腦告訴我'x'沒有定義,我不知道什麼是錯的。

+0

但第一代碼無故障運行? – CrisH

回答

0

我會說這樣做:

import tensorflow as tf 
import numpy as np 
sess = tf.Session() 
saver = tf.train.Saver() 
saver.restore(sess, './model.ckpt') 
x2=np.random.random([2, 2]) 
print(sess.run(tf.get_collection('outputs',feed_dict={x:x2}))) 
print('model is loaded') 
sess.close() 

我發現這對tensorflow的網站。希望能幫助到你。

+0

非常感謝! – quan

+0

它現在在工作嗎? :) – CrisH

+0

是的,但是當我運行它兩次,錯誤將被拋出,說:ValueError:至少有兩個變量具有相同的名稱:變量/ Adadelta_1。 – quan

0

我找到解決問題的方法:

import tensorflow as tf 
import numpy as np 
def add_layer(input): 
    #v1 = tf.Variable(np.random.random([2, 2]), dtype=tf.float32, name='v1') 
    v2 = tf.Variable(tf.random_normal([2, 2], dtype=tf.float32, name='v2')) 
    tf.add_to_collection('h0_v2',v2) 
    output=tf.matmul(input,v2) 
    return output 
x=tf.placeholder(tf.float32) 
outputs=add_layer(x) 
saver = tf.train.Saver() 
sess = tf.Session() 
saver = tf.train.import_meta_graph('./model.ckpt.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
x2=np.random.random([2, 2]) 
print(sess.run(outputs,feed_dict={x:x2})) 
print('model is loaded') 
sess.close() 
相關問題