在下面的示例代碼中,輸入形狀爲[2,3,4,5]
,生成的形狀爲[2,3,4]
。
的主要思路是:
- 可以很容易地得到一個排,而不是使用
gather_nd
列的,所以我換過去的兩個維度與tf.transpose
。
- 我們需要將
tf.gather_nd
中從tf.argmax
(indices
)得到的指數轉換爲真正可用的指標(請參見下面的final_idx
)。該轉換是通過堆疊完成的三個組成部分:
[0 0 0 1 1 1]
[0 1 2 0 1 2]
[3 3 3 0 0 0]
因此,我們可以從[3, 0]
去
[[[0 0 3]
[0 1 3]
[0 2 3]]
[[1 0 0]
[1 1 0]
[1 2 0]]].
Batch,Y,X = 2, 3, 4
tf.reset_default_graph()
data = np.arange(Batch*Y*X*5)
np.random.shuffle(data)
Params = tf.constant(np.reshape(data, [Batch, Y, X, 5]), dtype=tf.int32)
indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1)
indices = tf.cast(tf.reshape(tf.tile(tf.reshape(indices, [-1,1]),
[1,Y]), [-1]), tf.int32)
idx = tf.reshape(tf.range(batch_size), [-1,1])
idx = tf.reshape(tf.tile(idx, [1, y]), [-1])
inc = tf.reshape(tf.tile(tf.range(Y), [Batch]), [-1])
final_idx = tf.reshape(tf.stack([idx, inc, indices], 1), [Batch, Y, -1])
transposed = tf.transpose(Params, [0, 1, 3, 2])
slice = tf.gather_nd(transposed, final_idx)
with tf.Session() as sess:
print sess.run(Params)
print sess.run(idx)
print sess.run(inc)
print sess.run(indices)
print sess.run(final_idx)
print sess.run(slice)
[[[[ 22 38 68 49 119]
[ 47 74 111 117 90]
[ 14 32 31 12 75]
[ 93 34 57 3 56]]
[[ 69 21 4 94 39]
[ 83 96 62 102 80]
[ 55 113 48 98 29]
[107 81 67 76 28]]
[[ 53 51 77 66 63]
[ 92 115 118 116 13]
[ 43 78 15 1 0]
[ 99 50 27 60 73]]]
[[[ 97 88 91 64 86]
[ 72 110 26 87 33]
[ 70 30 41 114 5]
[ 95 82 46 16 61]]
[[109 71 45 8 40]
[101 9 23 59 10]
[ 37 65 44 11 19]
[ 42 104 106 105 18]]
[[112 58 7 17 89]
[ 25 79 103 85 20]
[ 35 6 108 100 36]
[ 24 52 2 54 84]]]]
[0 0 0 1 1 1]
[0 1 2 0 1 2]
[3 3 3 0 0 0]
[[[0 0 3]
[0 1 3]
[0 2 3]]
[[1 0 0]
[1 1 0]
[1 2 0]]]
[[[ 49 117 12 3]
[ 94 102 98 76]
[ 66 116 1 60]]
[[ 97 72 70 95]
[109 101 37 42]
[112 25 35 24]]]
也許我要澄清:索引張量的值,指定要使用的切片。結果必須在會話仍在運行時獲得,因爲進一步的計算取決於它。我編輯了我原來的帖子,希望現在更清楚。更新了 – Benjamin
。不知道這是你想要的。我稍後會添加一些解釋。 – greeness
太棒了。現在我也明白爲什麼其他例子不起作用。嗯,轉置已知是一個緩慢的操作,我想也可以通過爲第三維生成索引來解決它,並將它們連接起來?可能同樣昂貴。 另一方面 - 如果我們立即將它轉置爲[批處理,N,Y,X],我們可以擺脫其中一種平鋪操作嗎?然後只枚舉批處理,並將其堆疊到索引? – Benjamin