I am trying to run the textsum model from the TensorFlow open-source models repository. In seq2seq_attention.py they use a Supervisor to manage saving the model. The problem: the Supervisor starts up, creates a checkpoint, builds the graph, and so on, but it does not save the model again 60 seconds later as the parameter specifies; instead it takes several hours before the next save happens. I tried removing the global_step variable, but the problem is the same. Every time I stop training, I have to resume almost from the beginning (avg_loss). Can anyone tell me what the solution is?
The code in question is:
def _Train(model, data_batcher):
  """Runs model training."""
  with tf.device('/cpu:0'):
    model.build_graph()
    saver = tf.train.Saver()
    # Train dir is different from log_root to avoid summary directory
    # conflict with Supervisor.
    summary_writer = tf.summary.FileWriter(FLAGS.train_dir)
    sv = tf.train.Supervisor(logdir=FLAGS.log_root,
                             is_chief=True,
                             saver=saver,
                             summary_op=None,
                             save_summaries_secs=60,
                             save_model_secs=60,
                             global_step=model.global_step)
    sess = sv.prepare_or_wait_for_session(config=tf.ConfigProto(
        allow_soft_placement=True))
    running_avg_loss = 0
    step = 0
    while not sv.should_stop() and step < FLAGS.max_run_steps:
      (article_batch, abstract_batch, targets, article_lens, abstract_lens,
       loss_weights, _, _) = data_batcher.NextBatch()
      (_, summaries, loss, train_step) = model.run_train_step(
          sess, article_batch, abstract_batch, targets, article_lens,
          abstract_lens, loss_weights)
      summary_writer.add_summary(summaries, train_step)
      running_avg_loss = _RunningAvgLoss(
          running_avg_loss, loss, summary_writer, train_step)
      step += 1
      if step % 100 == 0:
        summary_writer.flush()
    sv.Stop()
  return running_avg_loss
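For context on what `save_model_secs=60` is supposed to do: the Supervisor runs a background thread that calls `saver.save` whenever at least that many seconds have elapsed since the last save. The scheduling logic amounts to something like the following simplified, stand-alone sketch (not the actual TensorFlow source; `PeriodicSaver` and its parameters are hypothetical names used only for illustration):

```python
import time


class PeriodicSaver:
    """Mimics time-based checkpointing: trigger a save only when at
    least `interval_secs` have passed since the previous save."""

    def __init__(self, interval_secs, save_fn, clock=time.time):
        self.interval_secs = interval_secs
        self.save_fn = save_fn   # called to actually write a checkpoint
        self.clock = clock       # injectable clock, useful for testing
        self.last_save = clock()

    def maybe_save(self, step):
        """Save if the interval has elapsed; return True if a save ran."""
        now = self.clock()
        if now - self.last_save >= self.interval_secs:
            self.save_fn(step)
            self.last_save = now
            return True
        return False
```

If saves are hours apart even though this interval is 60 seconds, one blunt workaround is to call `saver.save(sess, checkpoint_path, global_step=...)` explicitly inside the training loop (e.g. every N steps), bypassing the Supervisor's background thread entirely.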