2017-10-13 220 views
1

我有一個Tensor X whith shape [B,L,E](比方說,B批次的長度爲E的L個向量)。從這張張X中,我想要在每批中隨機選取N個矢量,然後用形狀[B,N,E]創建Y.從Tensorflow中從另一箇中挑選隨機張量

我試圖tf.random_uniform和tf.gather結合,但我真的與尺寸鬥爭並不能得到Y.

回答

2

您可以使用這樣的事情:

import tensorflow as tf 
import numpy as np 

B = 3 
L = 5 
E = 2 
N = 3 

input = np.array(range(B * L * E)).reshape([B, L, E]) 
print(input) 
print("#################################") 

X = tf.constant(input) 
batch_range = tf.tile(tf.reshape(tf.range(B, dtype=tf.int32), shape=[B, 1, 1]), [1, N, 1]) 
random = tf.random_uniform([B, N, 1], minval = 0, maxval = L - 1, dtype = tf.int32) 

indices = tf.concat([batch_range, random], axis = 2) 

output = tf.gather_nd(X, indices) 
with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print(sess.run(indices)) 
    print("#################################") 
    print(sess.run(output))