2017-10-17 90 views
0

我想用tf.metrics.accuracy跟蹤我的預測的準確性,但我不確定如何使用update_op(以下acc_update_op),該函數返回:如何使用tf.metrics.accuracy?

accuracy, acc_update_op = tf.metrics.accuracy(labels, predictions) 

我的想法是將其添加到tf.GraphKeys.UPDATE_OPS會有道理,但我不知道如何做到這一點。

回答

4

tf.metrics.accuracy是許多流量度量TensorFlow操作之一(另一個是tf.metrics.recall)。創建後,會創建兩個變量(counttotal),以累積所有最終結果的傳入結果。第一個返回值是用於計算count/total的張量。返回的第二個op是一個更新這些變量的有狀態函數。當評估多個批次數據的分類器性能時,流播的度量函數很有用。用一個簡單的例子:

# building phase 
with tf.name_scope("streaming"): 
    accuracy, acc_update_op = tf.metrics.accuracy(labels, predictions) 

test_fetches = { 
    'accuracy': accuracy, 
    'acc_op': acc_update_op 
} 

# when testing the classifier  
with tf.name_scope("streaming"): 
    # clear counters for a fresh evaluation 
    sess.run(tf.local_variables_initializer()) 

for _i in range(n_batches_in_test): 
    fd = get_test_batch() 
    outputs = sess.run(test_fetches, feed_dict=fd) 

print("Accuracy:", outputs['accuracy']) 

我的想法是將其添加到tf.GraphKeys.UPDATE_OPS將是有意義的,但我不知道如何做到這一點。

這不是一個好主意,除非您僅使用UPDATE_OPS集合進行測試。通常情況下,集合已經對訓練階段有一定的控制操作(如移動批處理標準化參數),這些操作並不意味着與驗證階段一起運行。最好將它們保存在新的集合中,或手動將這些操作添加到獲取字典中。

+0

好的,我發現使用'tf.GraphKeys.UPDATE_OPS'沒有任何意義,但使用'tf.GraphKeys.SUMMARIES'怎麼辦?我會怎麼做呢? –

+1

'tf.GraphKeys.SUMMARIES',你說?這一個是爲了保留彙總張量,所以你不應該使用它。沒有一個標準的集合鍵,但你可以選擇你自己的一個(例如'metric_ops')。將這些操作集中在一個集合中意味着您將負責確保在評估階段的每次迭代中執行這些操作。 –

+0

我發現了另一個解決方案,你會不會評論? '與tf.control_dependencies([accuracy_op]):tf.summary.scalar('準確性',準確性)' –