0
我對Tensorflow非常非常全新,需要編寫一個腳本來測試從檢查點文件恢復的模型上的單個示例。測試恢復的張量流模型的一般方法
我想知道是否有一種通用的方法來爲恢復的模型構建測試函數,而無需知道模型的所有細節。
此外,在下面的代碼的最後一部分,這看起來像我朝着正確的方向?如果是這樣,那麼如何在不瞭解模型細節的情況下構建「y」?
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import numpy as np
from fuel.datasets.hdf5 import H5PYDataset
ckpt_path='ckt/mnist/mnist_2017_02_23_17_22_50/mnist_2017_02_23_17_22_50_5000.ckpt'
##############################
#### Initialize Variables ####
##############################
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
var_to_shape_map = reader.get_variable_to_shape_map()
var=[0]*len(var_to_shape_map)
i=0
for key in var_to_shape_map:
var[i] = tf.Variable(reader.get_tensor(key), name=key)
#print("tensor_name: ", key)
#print(reader.get_tensor(key))
i=i+1
initialize=tf.global_variables_initializer()
###############################
####### Restore Model #########
###############################
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, ckpt_path)
###############################
##### Get Example to Test #####
###############################
test_set = H5PYDataset('../CNN3D/data/bmnist.hdf5', which_sets=('test',))
handle = test_set.open()
for i in range(0,100):
test_data = test_set.get_data(handle, slice(i, i+1))
if test_data[1][0][0]==8:
model_idx=i
test_data = test_set.get_data(handle, slice(model_idx,model_idx+1))
data = tf.Variable(np.asarray(test_data[0][0][0]), name='data')
###############################
######## Test Example #########
###############################
x = tf.placeholder(tf.float32,shape=[28,28])
y = ???
sess.run(initialize)
result=sess.run(y, feed_dict={x: data})
print result