2017-12-03

Below is part of my project's code. How do I initialize (reset) the variables behind tf.metrics in TensorFlow?

with tf.name_scope("test_accuracy"): 
    test_mean_abs_err, test_mean_abs_err_op = tf.metrics.mean_absolute_error(labels=label_pl, predictions=test_eval_predict) 
    test_accuracy, test_accuracy_op   = tf.metrics.accuracy(labels=label_pl, predictions=test_eval_predict) 
    test_precision, test_precision_op  = tf.metrics.precision(labels=label_pl, predictions=test_eval_predict) 
    test_recall, test_recall_op    = tf.metrics.recall(labels=label_pl, predictions=test_eval_predict) 
    test_f1_measure = 2 * test_precision * test_recall/(test_precision + test_recall) 
tf.summary.scalar('test_mean_abs_err', test_mean_abs_err) 
tf.summary.scalar('test_accuracy', test_accuracy) 
tf.summary.scalar('test_precision', test_precision) 
tf.summary.scalar('test_recall', test_recall) 
tf.summary.scalar('test_f1_measure', test_f1_measure) 
# validation metric init op
# (this call raises the error below: the *_op values are update Tensors, not Variables)
validation_metrics_init_op = tf.variables_initializer(
     var_list=[test_mean_abs_err_op, test_accuracy_op, test_precision_op, test_recall_op],
     name='validation_metrics_init')
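The test_f1_measure line applies F1 = 2PR / (P + R) to the running precision and recall values. The arithmetic can be sketched in plain Python (this is not TensorFlow code). The helper name precision_recall_f1 and the counts tp, fp and fn are hypothetical, and the zero guard is an addition of this sketch: the graph expression above yields NaN when precision and recall are both 0.

```python
def precision_recall_f1(tp, fp, fn):
    """Compute precision, recall and F1 from accumulated counts.

    Hypothetical helper: tp, fp and fn stand for the true-positive,
    false-positive and false-negative totals a streaming metric keeps.
    """
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    denom = precision + recall
    # Same formula as test_f1_measure, but guarded against 0 / 0.
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1


p, r, f1 = precision_recall_f1(tp=6, fp=2, fn=4)
print(p, r, f1)
```

With 6 true positives, 2 false positives and 4 false negatives, precision is 0.75, recall is 0.6 and F1 is about 0.667.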

However, when I run it, I get this error:

Traceback (most recent call last):
  File "./run_dnn.py", line 285, in <module>
    train(wnd_conf)
  File "./run_dnn.py", line 89, in train
    name='validation_metrics_init')
  File "/export/local/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 1176, in variables_initializer
    return control_flow_ops.group(*[v.initializer for v in var_list], name=name)
AttributeError: 'Tensor' object has no attribute 'initializer'

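I realize I can't create a validation initializer this way. Whenever I save a new checkpoint and start a new round of validation, I want to recompute the metrics from scratch, so I have to reset them to zero.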
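Why the reset is needed can be modeled in plain Python. This is a hypothetical sketch, not TensorFlow code: much as tf.metrics.mean_absolute_error keeps a running total and count in local variables, the class below keeps accumulator state that every update adds to, so its value covers all batches seen since the last reset. Resetting means zeroing that state, which is why an initializer has to target the metric's variables rather than the value or update tensors it returns. The class name and method names are assumptions made for illustration.

```python
class StreamingMeanAbsError:
    """Plain-Python model of a streaming metric (not TensorFlow code).

    Keeps two accumulators, `total` and `count`; the value is total / count.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # Analogue of running an initializer on the metric's local variables.
        self.total = 0.0
        self.count = 0

    def update(self, labels, predictions):
        # Analogue of running the update op once per batch.
        for y, p in zip(labels, predictions):
            self.total += abs(y - p)
            self.count += 1
        return self.value()

    def value(self):
        # Returns 0.0 before any update to avoid dividing by zero.
        return self.total / self.count if self.count else 0.0


metric = StreamingMeanAbsError()
# Validation round 1: two batches accumulate into the same counters.
metric.update([1.0, 2.0], [1.5, 2.0])
metric.update([3.0], [2.0])
round1 = metric.value()
# Without this reset, round 2 would be averaged together with round 1.
metric.reset()
metric.update([0.0, 0.0], [0.0, 2.0])
round2 = metric.value()
print(round1, round2)
```

Skipping the reset before round 2 would report (1.5 + 2.0) / 5 = 0.7 instead of the round-2 value 1.0.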

But how do I reset all of these metrics to zero? Thanks a lot for any help!

Answer

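After reading the blog post "Avoiding headaches with tf.metrics", I solved the problem as follows: give every metric the same name, collect the local variables under that name, and build the initializer from those variables.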

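# validation metrics
# Give every metric the same name so that the local accumulator variables
# they create all get names starting with "validation_metrics".
validation_metrics_var_scope = "validation_metrics"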
test_mean_abs_err, test_mean_abs_err_op = tf.metrics.mean_absolute_error(labels=label_pl, predictions=test_eval_predict, name=validation_metrics_var_scope) 
test_accuracy, test_accuracy_op   = tf.metrics.accuracy(labels=label_pl, predictions=test_eval_predict, name=validation_metrics_var_scope) 
test_precision, test_precision_op  = tf.metrics.precision(labels=label_pl, predictions=test_eval_predict, name=validation_metrics_var_scope) 
test_recall, test_recall_op    = tf.metrics.recall(labels=label_pl, predictions=test_eval_predict, name=validation_metrics_var_scope) 
test_f1_measure = 2 * test_precision * test_recall/(test_precision + test_recall) 
tf.summary.scalar('test_mean_abs_err', test_mean_abs_err) 
tf.summary.scalar('test_accuracy', test_accuracy) 
tf.summary.scalar('test_precision', test_precision) 
tf.summary.scalar('test_recall', test_recall) 
tf.summary.scalar('test_f1_measure', test_f1_measure) 
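# validation metric init op
# tf.metrics keeps its accumulators in the LOCAL_VARIABLES collection;
# the scope argument keeps only those whose names re.match the prefix.
validation_metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=validation_metrics_var_scope)
# Run this op before each validation round to reset the metrics to zero.
validation_metrics_init_op = tf.variables_initializer(var_list=validation_metrics_vars, name='validation_metrics_init')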
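This works because tf.get_collection filters a collection by matching each item's name against scope with re.match, which anchors only at the start of the name. A plain-Python imitation of that filter follows; the helper filter_by_scope and the variable names are hypothetical, chosen to resemble what repeated metric scopes might produce rather than copied from a real graph. It shows that uniquified scopes such as validation_metrics_1 still match, while a variable created under an outer name scope such as the question's test_accuracy likely would not.

```python
import re


def filter_by_scope(variable_names, scope):
    """Imitate the `scope` filter of tf.get_collection: keep names that re.match scope."""
    pattern = re.compile(scope)
    return [name for name in variable_names if pattern.match(name)]


# Hypothetical variable names (assumptions for illustration only).
local_vars = [
    "validation_metrics/total:0",
    "validation_metrics/count:0",
    "validation_metrics_1/total:0",
    "validation_metrics_2/true_positives/count:0",
    "test_accuracy/validation_metrics/total:0",  # outer scope: prefix no longer matches
    "train_metrics/total:0",                     # unrelated metric: left untouched
]
print(filter_by_scope(local_vars, "validation_metrics"))
```

Because the match is a prefix match, creating the metrics outside any extra name scope (as the answer's code does) keeps the filter simple.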