2017-04-24 133 views

I'm trying to run the textsum open-source model from TensorFlow. Inside seq2seq_attention.py they use a Supervisor to manage saving the model. The problem is that after starting, the Supervisor creates an initial checkpoint of the graph etc., but after 60 seconds it does not save the model as the parameter specifies; it takes hours before the next save happens. I tried removing the global_step variable and still hit the same problem. Every time I stop training, I have to resume almost from the beginning (avg_loss). Can someone tell me what the solution is?

The code in question is:

def _Train(model, data_batcher):
  """Runs model training."""
  with tf.device('/cpu:0'):
    model.build_graph()
    saver = tf.train.Saver()
    # Train dir is different from log_root to avoid summary directory
    # conflict with Supervisor.
    summary_writer = tf.summary.FileWriter(FLAGS.train_dir)
    sv = tf.train.Supervisor(logdir=FLAGS.log_root,
                             is_chief=True,
                             saver=saver,
                             summary_op=None,
                             save_summaries_secs=60,
                             save_model_secs=60,
                             global_step=model.global_step)
    sess = sv.prepare_or_wait_for_session(config=tf.ConfigProto(
        allow_soft_placement=True))
    running_avg_loss = 0
    step = 0
    while not sv.should_stop() and step < FLAGS.max_run_steps:
      (article_batch, abstract_batch, targets, article_lens, abstract_lens,
       loss_weights, _, _) = data_batcher.NextBatch()
      (_, summaries, loss, train_step) = model.run_train_step(
          sess, article_batch, abstract_batch, targets, article_lens,
          abstract_lens, loss_weights)
      summary_writer.add_summary(summaries, train_step)
      running_avg_loss = _RunningAvgLoss(
          running_avg_loss, loss, summary_writer, train_step)
      step += 1
      if step % 100 == 0:
        summary_writer.flush()
    sv.Stop()
    return running_avg_loss
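For context on what `save_model_secs=60` is supposed to do: the Supervisor runs a background looper thread that wakes up at that interval and calls `saver.save`. The following is a minimal plain-Python sketch of that cadence, not TensorFlow API; the `PeriodicCheckpointer` class and the fake save callback are illustrative stand-ins:

```python
import threading
import time

class PeriodicCheckpointer:
    """Illustrative stand-in for the Supervisor's checkpoint looper thread:
    it calls save_fn every interval_secs seconds until stopped."""

    def __init__(self, save_fn, interval_secs):
        self._save_fn = save_fn
        self._interval = interval_secs
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self):
        self._thread.start()

    def _run(self):
        # Event.wait returns False on timeout, True once stop() is called,
        # so this loop saves once per interval and exits promptly on stop.
        while not self._stop.wait(self._interval):
            self._save_fn()

    def stop(self):
        self._stop.set()
        self._thread.join()

# Demonstration with a short interval and a list standing in for saver.save.
saves = []
ckpt = PeriodicCheckpointer(lambda: saves.append(time.time()),
                            interval_secs=0.05)
ckpt.start()
time.sleep(0.3)
ckpt.stop()
```

If this timer thread never fires on your setup, the symptom is exactly what the question describes: one checkpoint at startup, then nothing for a long time.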

Answer


Have you tried specifying the save interval when you instantiate your Saver? I mean (to keep a model every 15 minutes):

saver = tf.train.Saver(keep_checkpoint_every_n_hours=0.25)
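Note that in the TF 1.x API, `keep_checkpoint_every_n_hours` controls how long old checkpoints are retained (in addition to `max_to_keep`), not how often new ones are written. If the Supervisor's timer-based saving misbehaves, another workaround is to checkpoint explicitly from the training loop whenever the desired interval has elapsed. A minimal sketch, where `run_step` and `save_ckpt` are hypothetical placeholders for `model.run_train_step(...)` and `saver.save(sess, checkpoint_path, global_step=step)`:

```python
import time

def train_with_manual_saves(run_step, save_ckpt, num_steps,
                            save_every_secs=60):
    """Run num_steps training steps, calling save_ckpt(step) whenever
    at least save_every_secs have elapsed since the last save."""
    last_save = time.monotonic()
    for step in range(num_steps):
        run_step(step)
        now = time.monotonic()
        if now - last_save >= save_every_secs:
            save_ckpt(step)
            last_save = now

# Quick demonstration with fast fake steps and a short save interval.
saved_at = []
train_with_manual_saves(
    run_step=lambda s: time.sleep(0.01),
    save_ckpt=saved_at.append,
    num_steps=30,
    save_every_secs=0.05)
```

This makes the save cadence independent of the Supervisor's background thread, at the cost of a little extra code in the loop.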