2017-10-20 339 views
0

我有使用tf.train.MonitoredTrainingSession來培訓CNN的代碼。Tensorflow恢復`tf.Session`使用`tf.train.MonitoredTrainingSession`保存檢查點

當我創建新的tf.train.MonitoredTrainingSession時,我可以將checkpoint目錄作爲輸入參數傳遞給會話,它會自動恢復它能找到的最新保存的checkpoint。我可以設置hooks來訓練,直到有一步。例如,如果checkpoint的步驟是150,000,我想培訓到200,000我會把last_step設置爲200,000

只要使用tf.train.MonitoredTrainingSession保存了最新的checkpoint,上述過程就可以正常工作。但是,如果我嘗試恢復使用正常的tf.Session保存的checkpoint,那麼所有地獄都會崩潰。它無法在圖表和全部中找到某些鍵。

培訓與完成這件事:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.retrain_dir, 
    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps), 
      tf.train.NanTensorHook(loss), 
      _LoggerHook()], 
    config=tf.ConfigProto(
     log_device_placement=FLAGS.log_device_placement)) as mon_sess: 
    while not mon_sess.should_stop(): 
    mon_sess.run(train_op) 

如果checkpoint_dir屬性有一個文件夾,沒有檢查站,這將從頭開始。如果它有一個checkpoint從以前的培訓課程中保存,它將恢復最新的checkpoint並繼續培訓。現在

,我恢復了最新checkpoint和修改一些變量,並將其保存:

saver = tf.train.Saver(variables_to_restore) 

ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir) 

with tf.Session() as sess: 
    if ckpt and ckpt.model_checkpoint_path: 
    # Restores from checkpoint 
    saver.restore(sess, ckpt.model_checkpoint_path) 
    print(ckpt.model_checkpoint_path) 
    restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 
    FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps 
    else: 
    print('No checkpoint file found') 
    return 

    prune_convs(sess) 
    saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step) 

正如你所看到的,只是saver.save...之前我修剪網絡中的所有卷積層。無需描述如何以及爲什麼這樣做。問題在於網絡實際上是經過修改的。然後我將網絡保存到checkpoint

現在,如果我在保存的修改後的網絡上部署測試,測試就可以正常工作。然而,當我嘗試運行所保存的checkpointtf.train.MonitoredTrainingSession,它說:

主要CONV 1/weight_loss /平均在檢查站

也沒有發現,我已經注意到,checkpoint那被保存爲tf.Session已保存的checkpoint的大小的一半tf.train.MonitoredTrainingSession

我知道我做錯了,任何建議如何使這項工作?

回答

0

我想通了。顯然,tf.Saver不會從checkpoint恢復所有變量。我試着立即恢復並保存,輸出只有一半的大小。

我用tf.train.list_variables從最新的checkpoint獲得所有變量,然後將它們轉換成tf.Variable並從它們創建了一個dict。然後我通過dicttf.Saver,它恢復了我所有的變量。

接下來的事情是initialize所有的變量,然後修改權重。

現在它工作。

相關問題