0
我想爲我的模型創建一個新的評估指標(平均互惠等級)。
假設我有:形狀(None, n_class)
的在未知等級的張量上應用函數(平均互惠等級)
logits
張量和含有int
值從0
到n_class-1
形狀(None,)
的y_target
張量。None
將是批量大小。
我想我的輸出爲形狀(None,)
的張量,相應的排名爲y_target
。 首先,我需要對logits
中的元素進行排名,然後獲得索引y_target
中元素的排名,最後得到它的倒數(或者x + 1的倒數,取決於排名過程)。
一個簡單的例子(用於單個觀察):
如果我的y_target=1
和logits=[0.5, -2.0, 1.1, 3.5]
,
然後排名是logits_rank=[3, 4, 2, 1]
和倒數將是1.0/logits_rank[y_target] = 0.25
。
這裏的挑戰是在軸上應用一個函數,因爲排名是未知的(在圖表級別)。 我已經設法使用tf.nn.top_k(logits, k=n_class, sorted=True).indices
獲得一些結果,但只能在session.run(sess, feed_dict)
之內。
任何幫助將不勝感激!