2017-07-24 86 views
1

我試圖通過遵循運行時統計指令here來獲得我的張量流代碼配置文件(網絡中每層的運行和內存消耗)。據我瞭解,我需要創建運行選項像這樣通過tensorflow監測訓練獲得運行時統計信息

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 
run_metadata = tf.RunMetadata() 

運行的元數據,並將其傳遞給sess.run

然而,正如我也嘗試使用tf.train.MonitoredTrainingSession我不知道如果我可以將同樣的事情傳遞給這個班級。一個合理的方法可以使用鉤子,但我不知道如何去做。我對他們還很陌生

回答

2

你可以簡單地創建一個自定義鉤子並將它傳遞給MonitoredTrainingSession。無需將您自己的tf.RunMetadata()實例傳遞給運行調用。

下面是一個例子鉤,其存儲每N個步驟ckptdir元數據:

import tensorflow as tf 

class TraceHook(tf.train.SessionRunHook): 
    """Hook to perform Traces every N steps.""" 

    def __init__(self, ckptdir, every_step=50, trace_level=tf.RunOptions.FULL_TRACE): 
     self._trace = every_step == 1 
     self.writer = tf.summary.FileWriter(ckptdir) 
     self.trace_level = trace_level 
     self.every_step = every_step 

    def begin(self): 
     self._global_step_tensor = tf.train.get_global_step() 
     if self._global_step_tensor is None: 
      raise RuntimeError("Global step should be created to use _TraceHook.") 

    def before_run(self, run_context): 
     if self._trace: 
      options = tf.RunOptions(trace_level=self.trace_level) 
     else: 
      options = None 
     return tf.train.SessionRunArgs(fetches=self._global_step_tensor, 
             options=options) 

    def after_run(self, run_context, run_values): 
     global_step = run_values.results - 1 
     if self._trace: 
      self._trace = False 
      self.writer.add_run_metadata(run_values.run_metadata, 
             f'{global_step}', global_step) 
     if not (global_step + 1) % self.every_step: 
      self._trace = True 

它檢查在before_run它是否有跟蹤與否,如果是,增加了RunOptions。在after_run它檢查是否需要跟蹤下一個運行調用,如果是,它會再次將_trace設置爲True。此外,它在元數據可用時存儲元數據。