2017-04-07 63 views
2

我正在收集每批中張量流的彙總統計。如何累積張量流中的彙總統計

我想收集在測試集上計算出的相同彙總統計信息,但測試集太大而無法在一個批處理中進行處理。

當我遍歷測試集時,是否有一種方便的方式來計算相同的彙總統計信息?

+0

可能重複? https://stackoverflow.com/questions/40788785/how-to-average-summaries-over-multiple-batches/ – Maikefer

+0

這是一個重複的問題,但接受的答案沒有提到流平均包,它現在有轉移到'tf.metrics',這個問題有一個更新的答案,但確實提到了它。 –

回答

1

另一種可能性是累積在測試批次的摘要tensorflow的外側,具有在圖中的虛擬變量,然後可以分配積累的結果。舉個例子:假設你通過幾個批次計算驗證集上的損失,並希望得到平均值的總結。在評估時間

with tf.name_scope('valid_loss'): 
    v_loss = tf.Variable(tf.constant(0.0), trainable=False) 
    self.v_loss_pl = tf.placeholder(tf.float32, shape=[], name='v_loss_pl') 
    self.update_v_loss = tf.assign(v_loss, self.v_loss_pl, name='update_v_loss') 

with tf.name_scope('valid_summaries'): 
    v_loss_s = tf.summary.scalar('validation_loss', v_loss) 
    self.valid_summaries = tf.summary.merge([v_loss_s], name='valid_summaries') 

然後:你可以做下面的方式實現這一目標

total_loss = 0.0 
for batch in all_batches: 
    loss, _ = sess.run([get_loss, ...], feed_dict={...}) 
    total_loss += loss 
total_loss /= float(n_batches) 

[_, v_summary_str] = sess.run([self.update_v_loss, self.valid_summaries], 
           feed_dict={self.v_loss_pl: total_loss}) 
writer.add_summary(v_summary_str) 

雖然這能夠完成任務,但無可否認的感覺有點哈克。您發佈的contrib的流量指標評估可能會更加優雅 - 我從來沒有真正遇到它,所以很想查看它。