2017-10-16 65 views
-1

我有一個4D-張量尺寸[B,Y,X,N]的PARAMS,並希望從它選擇一個特定的切片n ∈ N,使得我的所得張量是大小爲[B,Y,X,1](或[B,Y,X])。Tensorflow收集特定元素在圖4d張量

特定的切片應該是包含最高數字的平均切片;我得到的指數,像這樣:
indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1)(形狀[B])

我嘗試了不同的解決方案,使用gathergather_nd,但不能讓它開始工作。有很多帖子與此非常相似,但我無法應用其中提供的解決方案之一。

我正在運行Tensorflow 1.3,因此gather的花式新軸參數可用。

回答

0

在下面的示例代碼中,輸入形狀爲[2,3,4,5],生成的形狀爲[2,3,4]

的主要思路是:

  • 可以很容易地得到一個排,而不是使用gather_nd列的,所以我換過去的兩個維度與tf.transpose
  • 我們需要將tf.gather_nd中從tf.argmaxindices)得到的指數轉換爲真正可用的指標(請參見下面的final_idx)。該轉換是通過堆疊完成的三個組成部分:
    • [0 0 0 1 1 1]
    • [0 1 2 0 1 2]
    • [3 3 3 0 0 0]

因此,我們可以從[3, 0]

[[[0 0 3] 
    [0 1 3] 
    [0 2 3]] 
    [[1 0 0] 
    [1 1 0] 
    [1 2 0]]]. 
Batch,Y,X = 2, 3, 4 
tf.reset_default_graph() 

data = np.arange(Batch*Y*X*5) 
np.random.shuffle(data) 

Params = tf.constant(np.reshape(data, [Batch, Y, X, 5]), dtype=tf.int32) 
indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1) 
indices = tf.cast(tf.reshape(tf.tile(tf.reshape(indices, [-1,1]), 
            [1,Y]), [-1]), tf.int32) 

idx = tf.reshape(tf.range(batch_size), [-1,1]) 
idx = tf.reshape(tf.tile(idx, [1, y]), [-1]) 

inc = tf.reshape(tf.tile(tf.range(Y), [Batch]), [-1]) 
final_idx = tf.reshape(tf.stack([idx, inc, indices], 1), [Batch, Y, -1]) 
transposed = tf.transpose(Params, [0, 1, 3, 2]) 
slice = tf.gather_nd(transposed, final_idx) 

with tf.Session() as sess: 
    print sess.run(Params) 
    print sess.run(idx)  
    print sess.run(inc) 
    print sess.run(indices) 
    print sess.run(final_idx) 
    print sess.run(slice) 
[[[[ 22 38 68 49 119] 
    [ 47 74 111 117 90] 
    [ 14 32 31 12 75] 
    [ 93 34 57 3 56]] 

    [[ 69 21 4 94 39] 
    [ 83 96 62 102 80] 
    [ 55 113 48 98 29] 
    [107 81 67 76 28]] 

    [[ 53 51 77 66 63] 
    [ 92 115 118 116 13] 
    [ 43 78 15 1 0] 
    [ 99 50 27 60 73]]] 


[[[ 97 88 91 64 86] 
    [ 72 110 26 87 33] 
    [ 70 30 41 114 5] 
    [ 95 82 46 16 61]] 

    [[109 71 45 8 40] 
    [101 9 23 59 10] 
    [ 37 65 44 11 19] 
    [ 42 104 106 105 18]] 

    [[112 58 7 17 89] 
    [ 25 79 103 85 20] 
    [ 35 6 108 100 36] 
    [ 24 52 2 54 84]]]] 

[0 0 0 1 1 1] 
[0 1 2 0 1 2] 
[3 3 3 0 0 0] 

[[[0 0 3] 
    [0 1 3] 
    [0 2 3]] 

[[1 0 0] 
    [1 1 0] 
    [1 2 0]]] 

[[[ 49 117 12 3] 
    [ 94 102 98 76] 
    [ 66 116 1 60]] 

[[ 97 72 70 95] 
    [109 101 37 42] 
    [112 25 35 24]]] 
+0

也許我要澄清:索引張量的值,指定要使用的切片。結果必須在會話仍在運行時獲得,因爲進一步的計算取決於它。我編輯了我原來的帖子,希望現在更清楚。更新了 – Benjamin

+0

。不知道這是你想要的。我稍後會添加一些解釋。 – greeness

+0

太棒了。現在我也明白爲什麼其他例子不起作用。嗯,轉置已知是一個緩慢的操作,我想也可以通過爲第三維生成索引來解決它,並將它們連接起來?可能同樣昂貴。 另一方面 - 如果我們立即將它轉置爲[批處理,N,Y,X],我們可以擺脫其中一種平鋪操作嗎?然後只枚舉批處理,並將其堆疊到索引? – Benjamin