2017-10-17

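I am following the CIFAR-10 tutorial at https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10. The project contains 6 files. After searching the internet I understand cifar10.py and cifar10_input.py, but I can't understand the train function in cifar10_train.py, shown below. Can someone explain the train function in cifar10_train.py from the TensorFlow CIFAR-10 tutorial?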

from datetime import datetime
import time

import tensorflow as tf

import cifar10

# train_dir, max_steps, log_frequency and log_device_placement are flags
# defined in cifar10_train.py; batch_size is a flag defined in cifar10.py.
FLAGS = tf.app.flags.FLAGS


def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force the input pipeline onto CPU:0 to avoid operations sometimes
        # ending up on the GPU and resulting in a slowdown.
        with tf.device('/cpu:0'):
            images, labels = cifar10.distorted_inputs()

        # Build the inference graph that computes the logits.
        logits = cifar10.inference(images)

        # Calculate the loss.
        loss = cifar10.loss(logits, labels)

        # Build the op that trains on one batch and updates the parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Ask for the loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                       tf.train.NanTensorHook(loss),
                       _LoggerHook()],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)

Can someone explain what is happening in the _LoggerHook class?

Answer


It uses a SessionRunHook attached to a MonitoredSession (here created by tf.train.MonitoredTrainingSession) to log the loss during training.

_LoggerHook is an implementation of SessionRunHook, and its methods are called in the order described below:

call hooks.begin()
sess = tf.Session()
call hooks.after_create_session()
while not stop is requested:
    call hooks.before_run()
    try:
        results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
        break
    call hooks.after_run()
call hooks.end()
sess.close()

This pseudocode is taken from the TensorFlow documentation.
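The pseudocode can be made concrete with a small pure-Python simulation that needs no TensorFlow. RunContext, Hook, run_monitored, StopAtStep, LossLogger and fake_run are hypothetical names invented for this sketch, only loosely modelled on tf.train.SessionRunContext, tf.train.SessionRunHook and tf.train.StopAtStepHook; fetches are plain strings instead of tensors, and the fake session simply halves a loss value on every step. It illustrates how the value a hook returns from before_run becomes an extra fetch of the merged run call and comes back to that same hook as run_values.results, which is how _LoggerHook gets the loss even though the training loop only runs train_op.

```python
from types import SimpleNamespace


class RunContext:
    """Hypothetical stand-in for tf.train.SessionRunContext."""
    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


class Hook:
    """Hypothetical no-op base class, loosely modelled on tf.train.SessionRunHook."""
    def begin(self): pass
    def after_create_session(self, session): pass
    def before_run(self, run_context): return None  # extra fetch name, or None
    def after_run(self, run_context, run_values): pass
    def end(self, session): pass


def run_monitored(session_run, main_fetch, hooks):
    """Call the hooks in the order given by the pseudocode.

    session_run(fetches) must return a dict mapping each fetch to its value.
    """
    for h in hooks:
        h.begin()
    session = "fake-session"  # placeholder for tf.Session()
    for h in hooks:
        h.after_create_session(session)
    ctx = RunContext()
    while not ctx.stop_requested:
        # Each hook may ask for one extra fetch (e.g. the loss tensor).
        extra = [h.before_run(ctx) for h in hooks]
        fetches = [main_fetch] + [f for f in extra if f is not None]
        results = session_run(fetches)  # one merged "sess.run" call
        # Each hook receives only the value of the fetch it asked for.
        for h, f in zip(hooks, extra):
            value = results[f] if f is not None else None
            h.after_run(ctx, SimpleNamespace(results=value))
    for h in hooks:
        h.end(session)


class StopAtStep(Hook):
    """Simplified StopAtStepHook: requests a stop after last_step runs."""
    def __init__(self, last_step):
        self.last_step = last_step
        self.step = 0

    def after_run(self, run_context, run_values):
        self.step += 1
        if self.step >= self.last_step:
            run_context.request_stop()


class LossLogger(Hook):
    """Asks for 'loss' in before_run and records what after_run receives."""
    def __init__(self):
        self.events = []

    def begin(self):
        self.events.append("begin")

    def before_run(self, run_context):
        self.events.append("before_run")
        return "loss"

    def after_run(self, run_context, run_values):
        self.events.append(("after_run", run_values.results))

    def end(self, session):
        self.events.append("end")


state = {"loss": 8.0}


def fake_run(fetches):
    """Fake session: each call is one training step that halves the loss."""
    state["loss"] /= 2
    return {f: state["loss"] if f == "loss" else None for f in fetches}


logger = LossLogger()
run_monitored(fake_run, "train_op", [StopAtStep(3), logger])
print(logger.events)
# -> ['begin', 'before_run', ('after_run', 4.0), 'before_run', ('after_run', 2.0), 'before_run', ('after_run', 1.0), 'end']
```

The logging hook never touches the training loop: it only declares what it wants in before_run and reads the result in after_run, while a separate hook decides when to stop.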

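In before_run it asks for the loss tensor to be fetched as part of the upcoming session.run call; in after_run it receives the computed loss value (run_values.results) and, every FLAGS.log_frequency steps, prints it in a predefined format together with the training throughput (examples/sec and sec/batch).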
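To see what the throughput numbers in the printed line mean, here is a TensorFlow-free sketch that reproduces _LoggerHook's bookkeeping. The class name LoggerHookSketch, the records list and the injectable clock are hypothetical; the constructor arguments log_frequency and batch_size stand in for FLAGS.log_frequency and FLAGS.batch_size, and the example values below are illustrative, not the tutorial's settings.

```python
import time
from datetime import datetime


class LoggerHookSketch:
    """TensorFlow-free mirror of _LoggerHook's bookkeeping (illustrative only).

    log_frequency and batch_size stand in for FLAGS.log_frequency and
    FLAGS.batch_size; clock stands in for time.time so it can be faked.
    """

    def __init__(self, log_frequency, batch_size, clock=time.time):
        self.log_frequency = log_frequency
        self.batch_size = batch_size
        self.clock = clock
        self.records = []  # (step, loss, examples_per_sec, sec_per_batch)

    def begin(self):
        self._step = -1
        self._start_time = self.clock()

    def before_run(self):
        self._step += 1
        return "loss"  # the real hook returns tf.train.SessionRunArgs(loss)

    def after_run(self, loss_value):
        if self._step % self.log_frequency == 0:
            current_time = self.clock()
            # Time since begin() or since the previous log line.
            duration = current_time - self._start_time
            self._start_time = current_time
            examples_per_sec = self.log_frequency * self.batch_size / duration
            sec_per_batch = float(duration / self.log_frequency)
            self.records.append((self._step, loss_value, examples_per_sec, sec_per_batch))
            print('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)'
                  % (datetime.now(), self._step, loss_value,
                     examples_per_sec, sec_per_batch))


# Fake timestamps: begin() at 0.0 s, step 0 logged at 0.5 s, step 10 at 2.5 s.
ticks = iter([0.0, 0.5, 2.5])
hook = LoggerHookSketch(log_frequency=10, batch_size=128, clock=lambda: next(ticks))
hook.begin()
for _ in range(11):  # steps 0..10
    hook.before_run()
    hook.after_run(loss_value=2.0)
```

With these fake timestamps the step-10 line reports 640.0 examples/sec and 0.200 sec/batch (10 × 128 examples in 2.0 s). A line is also printed at step 0 because _step starts at -1; its duration covers everything since begin(), which per the pseudocode runs before the session is even created, yet it is divided as if log_frequency steps had run, so the first throughput figures of a real run are not meaningful.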

A tutorial: https://www.tensorflow.org/tutorials/layers

Hope this helps.
