2017-03-07

tf.train.NanTensorHook(loss, fail_on_nan_loss=False) still raises an exception in TF 1.0

When I define a custom model_fn in TF 1.0, I want to stop training when the loss is NaN. I tried the following code in model_fn:

return model_fn_lib.ModelFnOps(
     mode=mode, 
     predictions=predictions_dict, 
     loss=loss, 
     train_op=train_op, 
     eval_metric_ops=eval_metric_ops, 
     training_hooks=[tf.train.NanTensorHook(loss, fail_on_nan_loss=False)]) 

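Even with fail_on_nan_loss=False, an exception is still thrown. I expected it to log a warning message and stop that particular training run without raising an exception.

Any suggestions on how to use tf.train.NanTensorHook correctly?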

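This is needed when searching over hyperparameters to find which combination of settings works better, as suggested in [this link](http://www.michael-remington.com/machine/learning/tensorflow/neural/networks/2016/06/25/tflearn-tutorial.html). You don't want a raised exception to break out of the big loop. – xiyulangzi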
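To illustrate the concern in this comment, here is a minimal sketch of a hyperparameter sweep that records a diverged run and moves on instead of letting the exception abort the loop. It runs without TensorFlow: `train_once`, `DivergedError`, and the toy divergence rule are hypothetical stand-ins, not part of any library.

```python
import math


class DivergedError(RuntimeError):
    """Hypothetical stand-in for an error such as NanLossDuringTrainingError."""


def train_once(learning_rate):
    """Hypothetical training run: returns a final loss or raises DivergedError."""
    # Toy rule for illustration only: pretend large learning rates diverge.
    if learning_rate > 0.5:
        raise DivergedError("loss became NaN at learning_rate=%g" % learning_rate)
    return abs(math.log10(learning_rate) + 2.0)  # smallest near learning_rate = 0.01


def sweep(learning_rates):
    """Try every setting; record diverged runs instead of aborting the loop."""
    results = {}
    for lr in learning_rates:
        try:
            results[lr] = train_once(lr)
        except DivergedError:
            results[lr] = None  # mark as diverged and keep going
    return results


results = sweep([1.0, 0.1, 0.01])
finished = {lr: loss for lr, loss in results.items() if loss is not None}
best_lr = min(finished, key=finished.get)
```

Catching the exception in the outer loop is an alternative to changing the hook itself; which one fits better depends on whether you control the training code.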

Answer

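While exploring solutions, I found one possible workaround that might help: I copied the NanTensorHook class from basic_session_run_hooks.py and call my own version of it inside my model_fn, as follows: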

import numpy as np
import tensorflow as tf


class NanTensorHook2(tf.train.SessionRunHook):
    """NaN Loss monitor by Lei.

    Monitors loss and stops training if loss is NaN or Inf.
    Can either fail with exception or just stop training.
    """

    def __init__(self, loss_tensor, fail_on_nan_loss=True):
        """Initializes NanLoss monitor.

        Args:
          loss_tensor: `Tensor`, the loss tensor.
          fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
        """
        self._loss_tensor = loss_tensor
        self._fail_on_nan_loss = fail_on_nan_loss

    def before_run(self, run_context):  # pylint: disable=unused-argument
        # Ask the session to also fetch the loss tensor on every step.
        return tf.train.SessionRunArgs(self._loss_tensor)

    def after_run(self, run_context, run_values):
        if np.isnan(run_values.results) or np.isinf(run_values.results):
            failure_message = "Model diverged with loss = NaN or Inf."
            if self._fail_on_nan_loss:
                tf.logging.error(failure_message)
                # The copied source uses the module-local name; TF 1.x also
                # exports it as tf.train.NanLossDuringTrainingError.
                raise tf.train.NanLossDuringTrainingError
            else:
                tf.logging.warning(failure_message)
                # We don't raise an error but we request stop without an exception.
                run_context.request_stop()

I then used NanTensorHook2 in place of the original hook, and it started working.
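The difference between raising and requesting a stop can be sketched without TensorFlow. In the toy below, `FakeRunContext`, `check_loss`, and `run_training` are hypothetical stand-ins that only mimic the control flow of the hook's after_run(); they are not the real session machinery.

```python
import math


class FakeRunContext:
    """Hypothetical stand-in for the run context passed to after_run()."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


def check_loss(run_context, loss_value, fail_on_nan_loss):
    """Mirror the hook's after_run() decision for a single scalar loss."""
    if math.isnan(loss_value) or math.isinf(loss_value):
        if fail_on_nan_loss:
            raise RuntimeError("Model diverged with loss = NaN or Inf.")
        # Ask the loop to end cleanly instead of raising.
        run_context.request_stop()


def run_training(losses, fail_on_nan_loss):
    """Toy loop: consume one loss per step until a stop is requested."""
    context = FakeRunContext()
    steps = 0
    for loss_value in losses:
        if context.stop_requested:
            break
        steps += 1
        check_loss(context, loss_value, fail_on_nan_loss)
    return steps, context.stop_requested


steps, stopped = run_training([0.9, 0.5, float("nan"), 0.3], fail_on_nan_loss=False)
```

With fail_on_nan_loss=False the toy loop ends after the step that produced NaN and returns normally; with True the same input raises instead.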

Note that I added "np.isinf(run_values.results)" because I believe loss = inf should also be checked here.

Does any expert have a better solution?
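As a side note on that check, the NaN and Inf tests can be folded into a single finiteness test. The sketch below uses the standard library's `math.isfinite` for a scalar loss; `loss_diverged` is a hypothetical helper name, and NumPy's `np.isfinite` should behave the same way on a scalar.

```python
import math


def loss_diverged(loss_value):
    """Return True when a scalar loss is NaN, +Inf, or -Inf."""
    return not math.isfinite(loss_value)


checks = [loss_diverged(v) for v in (0.25, float("nan"), float("inf"), float("-inf"))]
```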