我有使用tf.train.MonitoredTrainingSession
來培訓CNN的代碼。Tensorflow恢復`tf.Session`使用`tf.train.MonitoredTrainingSession`保存檢查點
當我創建新的tf.train.MonitoredTrainingSession
時,我可以將checkpoint
目錄作爲輸入參數傳遞給會話,它會自動恢復它能找到的最新保存的checkpoint
。我可以設置hooks
來訓練,直到有一步。例如,如果checkpoint
的步驟是150,000
,我想培訓到200,000
我會把last_step
設置爲200,000
。
只要使用tf.train.MonitoredTrainingSession
保存了最新的checkpoint
,上述過程就可以正常工作。但是,如果我嘗試恢復使用正常的tf.Session
保存的checkpoint
,那麼所有地獄都會崩潰。它無法在圖表和全部中找到某些鍵。
培訓與完成這件事:
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.retrain_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
如果checkpoint_dir
屬性有一個文件夾,沒有檢查站,這將從頭開始。如果它有一個checkpoint
從以前的培訓課程中保存,它將恢復最新的checkpoint
並繼續培訓。現在
,我恢復了最新checkpoint
和修改一些變量,並將其保存:
saver = tf.train.Saver(variables_to_restore)
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
with tf.Session() as sess:
if ckpt and ckpt.model_checkpoint_path:
# Restores from checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
print(ckpt.model_checkpoint_path)
restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps
else:
print('No checkpoint file found')
return
prune_convs(sess)
saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step)
正如你所看到的,只是saver.save...
之前我修剪網絡中的所有卷積層。無需描述如何以及爲什麼這樣做。問題在於網絡實際上是經過修改的。然後我將網絡保存到checkpoint
。
現在,如果我在保存的修改後的網絡上部署測試,測試就可以正常工作。然而,當我嘗試運行所保存的checkpoint
的tf.train.MonitoredTrainingSession
,它說:
主要CONV 1/weight_loss /平均在檢查站
也沒有發現,我已經注意到,checkpoint
那被保存爲tf.Session
已保存的checkpoint
的大小的一半tf.train.MonitoredTrainingSession
我知道我做錯了,任何建議如何使這項工作?