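Trouble restoring a checkpointed TensorFlow net

I've built an autoencoder that "translates" VGG19.relu4_1 activations into pixels. I'm using the new convenience functions in tensorflow.contrib.layers (as of TF 0.10rc0). The code follows the layout of TensorFlow's CIFAR10 tutorial: train.py trains the model and writes checkpoints to disk, and a separate eval.py polls for new checkpoint files and runs inference on them.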
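My problem is that evaluation never performs as well as training. This holds both for the value of the loss function and for the output images, even when eval runs on the same images used for training. That makes me think the problem lies in the restore process.

When I look at the training output in TensorBoard, it looks good (eventually), so I don't think there's anything wrong with the net itself.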
My net looks like this:
    import tensorflow.contrib.layers as contrib

    # is_training (True in train.py, False in eval.py) and vgg_output
    # (the VGG19 relu4_1 activations) are defined elsewhere.
    bn_params = {
        "is_training": is_training,
        "center": True,
        "scale": True
    }

    tensor = contrib.convolution2d_transpose(vgg_output, 64*4, 4,
                                             stride=2,
                                             normalizer_fn=contrib.batch_norm,
                                             normalizer_params=bn_params,
                                             scope="deconv1")
    tensor = contrib.convolution2d_transpose(tensor, 64*2, 4,
                                             stride=2,
                                             normalizer_fn=contrib.batch_norm,
                                             normalizer_params=bn_params,
                                             scope="deconv2")
    # ... (further layers omitted)
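For reference, the `is_training` flag makes `batch_norm` normalize in two different ways. Below is a minimal pure-Python sketch of that behavior for a single feature, with no TensorFlow involved. The helper `batch_norm_1d` is hypothetical. In training mode it uses the current batch's mean and variance. In inference mode it uses the stored moving mean and variance. It assumes the contrib defaults (epsilon = 0.001, moving mean initialized to 0, moving variance to 1). Under those assumptions, if the moving statistics are never updated, inference effectively skips normalization.

```python
import math


def batch_norm_1d(x, moving_mean, moving_var, is_training,
                  gamma=1.0, beta=0.0, eps=1e-3):
    """Normalize a list of scalars the way batch norm does for one feature.

    Hypothetical helper for illustration; eps=1e-3 assumes the contrib default.
    """
    if is_training:
        # Training mode: use the statistics of the current batch.
        mean = sum(x) / len(x)
        var = sum((v - mean) ** 2 for v in x) / len(x)
    else:
        # Inference mode: use the stored moving statistics.
        mean, var = moving_mean, moving_var
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in x]


batch = [10.0, 12.0, 14.0]
# Training: output is centered and scaled (roughly [-1.22, 0.0, 1.22]).
print(batch_norm_1d(batch, 0.0, 1.0, is_training=True))
# Inference with never-updated moving stats (mean 0, variance 1):
# the output is almost the raw input, so no real normalization happens.
print(batch_norm_1d(batch, 0.0, 1.0, is_training=False))
```

The same batch therefore produces very different activations in the two modes unless the moving statistics have converged to the real data statistics.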
And in train.py I do this to save checkpoints:
    import tensorflow as tf
    import mynet  # the model module (defines MOVING_AVERAGE_DECAY)

    # ... build the net, loss, apply_gradient_op, saver and sess ...
    variable_averages = tf.train.ExponentialMovingAverage(mynet.MOVING_AVERAGE_DECAY)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
        train_op = tf.no_op(name='train')

    while training:
        # train (with batch normalization's is_training = True)
        sess.run(train_op)
        if time_to_checkpoint:
            saver.save(sess, checkpoint_path, global_step=step)
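`ExponentialMovingAverage` keeps a shadow copy of each trainable variable. Every time `variables_averages_op` runs, it moves the shadow toward the variable: `shadow = decay * shadow + (1 - decay) * variable`. Because eval.py restores through `variables_to_restore()`, eval loads these shadows rather than the raw trained weights. With a decay close to 1, the shadows can lag well behind the weights. The plain-Python sketch below uses hypothetical helper names and assumes a decay of 0.9999, the value used in the CIFAR10 tutorial. The question does not show the actual `mynet.MOVING_AVERAGE_DECAY`. The sketch also includes the optional `num_updates` rule, `min(decay, (1 + num_updates) / (10 + num_updates))`, which the code above does not use.

```python
def ema_update(shadow, value, decay):
    """One ExponentialMovingAverage step: move the shadow toward the value."""
    return decay * shadow + (1.0 - decay) * value


def effective_decay(decay, num_updates=None):
    """Decay actually used when a num_updates counter is supplied (optional)."""
    if num_updates is None:
        return decay
    return min(decay, (1.0 + num_updates) / (10.0 + num_updates))


def shadow_after(values, decay, shadow_init):
    """Apply the EMA once per training step to a sequence of variable values."""
    shadow = shadow_init
    for value in values:
        shadow = ema_update(shadow, value, decay)
    return shadow


# Hypothetical scenario: a weight that starts at 0.0 jumps to 1.0 and stays there.
decay = 0.9999  # assumption: CIFAR10 tutorial value
print(shadow_after([1.0] * 100, decay, shadow_init=0.0))    # ~0.00995
print(shadow_after([1.0] * 10000, decay, shadow_init=0.0))  # ~0.632
print(effective_decay(decay, num_updates=0))                 # 0.1
```

After 10,000 steps at this decay, the averaged weight has covered only about 63% of the distance to the trained value. A `num_updates` counter lowers the decay early in training so that the shadows can catch up.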
In eval.py I do:
    import tensorflow as tf
    import mynet  # the model module (defines MOVING_AVERAGE_DECAY)

    # run code that creates the net
    variable_averages = tf.train.ExponentialMovingAverage(
        mynet.MOVING_AVERAGE_DECAY)
    saver = tf.train.Saver(variable_averages.variables_to_restore())

    while polling:
        # sleep and check for new checkpoint files
        with tf.Session() as sess:
            init = tf.initialize_all_variables()
            init_local = tf.initialize_local_variables()
            sess.run([init, init_local])
            saver.restore(sess, checkpoint_path)
            # run inference (with batch normalization's is_training = False)
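The next sketch shows which checkpoint entries the eval `Saver` reads. It is a rough plain-Python model of the name mapping that `variables_to_restore()` builds. Averaged variables are loaded from the checkpoint entry `<name>/ExponentialMovingAverage`. All other variables, such as batch norm's non-trainable `moving_mean` and `moving_variance`, are loaded under their own names. The helper and the variable names are hypothetical. The `ExponentialMovingAverage` suffix assumes the default EMA name, and the real method works on variable objects rather than strings. For this question, the consequence is that the batch-norm moving statistics are restored exactly as they were saved. If training never updated them, eval receives their initial values.

```python
def sketch_variables_to_restore(all_names, averaged_names,
                                ema_name="ExponentialMovingAverage"):
    """Map checkpoint entry name -> graph variable name (illustration only)."""
    mapping = {}
    for name in all_names:
        if name in averaged_names:
            # Averaged variable: restore the shadow value into the plain variable.
            mapping[name + "/" + ema_name] = name
        else:
            # Non-averaged variable: restore the variable's own saved value.
            mapping[name] = name
    return mapping


# Hypothetical names for the first deconvolution layer.
all_names = ["deconv1/weights", "deconv1/BatchNorm/beta", "deconv1/BatchNorm/gamma",
             "deconv1/BatchNorm/moving_mean", "deconv1/BatchNorm/moving_variance"]
trainable = ["deconv1/weights", "deconv1/BatchNorm/beta", "deconv1/BatchNorm/gamma"]
for ckpt_name, var_name in sketch_variables_to_restore(all_names, trainable).items():
    print(ckpt_name, "->", var_name)
```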
(Loss plot: blue is the training loss, orange is the eval loss.)
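Thanks for solving this. Am I the only one who thinks this should be properly documented/fixed? I thought 'optimize_loss()' was just a shortcut for optimizer.minimize(loss, step), not something the other contrib.layers need in order to work as advertised. – DomJack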
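The point about 'optimize_loss()' comes down to batch norm's moving statistics. By default, `contrib.layers.batch_norm` does not update `moving_mean` and `moving_variance` itself. It creates update ops and places them in the `tf.GraphKeys.UPDATE_OPS` collection, so something else has to run them. `optimize_loss()` runs them. A hand-built `train_op` like the one in train.py above does not, unless it is wrapped in `tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS))`. The `contrib.layers` docs also describe passing `updates_collections=None` to `batch_norm` so that the updates happen in place. The plain-Python simulation below involves no TensorFlow, uses hypothetical names, and assumes the contrib default decay of 0.999. It shows what happens to the moving mean with and without those update ops.

```python
def simulate_training(batches, run_update_ops, decay=0.999):
    """Track batch norm's moving mean for one feature over several steps.

    Illustration only: decay=0.999 assumes the contrib.layers default.
    """
    moving_mean = 0.0  # batch_norm initializes the moving mean to zero
    for batch in batches:
        batch_mean = sum(batch) / len(batch)
        # (the gradient update for the weights would happen here)
        if run_update_ops:
            # What the op in UPDATE_OPS does when it is actually run.
            moving_mean = decay * moving_mean + (1.0 - decay) * batch_mean
    return moving_mean


batches = [[1.0, 2.0, 3.0]] * 1000  # every batch has mean 2.0
print(simulate_training(batches, run_update_ops=False))  # 0.0: never updated
print(simulate_training(batches, run_update_ops=True))   # ~1.26, approaching 2.0
```

Training looks fine in both cases, because training mode always normalizes with batch statistics. Only inference (is_training = False) reads the moving mean, so eval alone exposes the missing updates.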