2017-04-05 47 views
0

我想爲我的模型創建一個新的評估指標(平均互惠等級)。
假設我有:形狀(None, n_class)在未知等級的張量上應用函數(平均互惠等級)

  • logits張量和含有int值從0n_class-1形狀(None,)
  • y_target張量。
  • None將是批量大小。

我想我的輸出爲形狀(None,)的張量,相應的排名爲y_target。 首先,我需要對logits中的元素進行排名,然後獲得索引y_target中元素的排名,最後得到它的倒數(或者x + 1的倒數,取決於排名過程)。

一個簡單的例子(用於單個觀察):
如果我的y_target=1logits=[0.5, -2.0, 1.1, 3.5]
然後排名是logits_rank=[3, 4, 2, 1]
和倒數將是1.0/logits_rank[y_target] = 0.25

這裏的挑戰是在軸上應用一個函數,因爲排名是未知的(在圖表級別)。 我已經設法使用tf.nn.top_k(logits, k=n_class, sorted=True).indices獲得一些結果,但只能在session.run(sess, feed_dict)之內。

任何幫助將不勝感激!

回答

0

解決!

def tf_get_rank_order(input, reciprocal): 
    """ 
    Returns a tensor of the rank of the input tensor's elements. 
    rank(highest element) = 1. 
    """ 
    assert isinstance(reciprocal, bool), 'reciprocal has to be bool' 
    size = tf.size(input) 
    indices_of_ranks = tf.nn.top_k(-input, k=size)[1] 
    indices_of_ranks = size - tf.nn.top_k(-indices_of_ranks, k=size)[1] 
    if reciprocal: 
     indices_of_ranks = tf.cast(indices_of_ranks, tf.float32) 
     indices_of_ranks = tf.map_fn(
      lambda x: tf.reciprocal(x), indices_of_ranks, 
      dtype=tf.float32) 
     return indices_of_ranks 
    else: 
     return indices_of_ranks 


def get_reciprocal_rank(logits, targets, reciprocal=True): 
    """ 
    Returns a tensor containing the (reciprocal) ranks 
    of the logits tensor (wrt the targets tensor). 
    The targets tensor should be a 'one hot' vector 
    (otherwise apply one_hot on targets, such that index_mask is a one_hot). 
    """ 
    function_to_map = lambda x: tf_get_rank_order(x, reciprocal=reciprocal) 
    ordered_array_dtype = tf.float32 if reciprocal is not None else tf.int32 
    ordered_array = tf.map_fn(function_to_map, logits, 
           dtype=ordered_array_dtype) 

    size = int(logits.shape[1]) 
    index_mask = tf.reshape(
      targets, [-1,size]) 
    if reciprocal: 
     index_mask = tf.cast(index_mask, tf.float32) 

    return tf.reduce_sum(ordered_array * index_mask,1) 

# use: 
recip_rank = tf.reduce_mean(
       get_reciprocal_rank(logits[-1], 
            y_, 
            True)