2017-01-23 125 views
0

我使用tensorflow有關python 我的形狀的數據張量[?,5,37],和形狀的IDX張量[?,5]提取從張量的特定元素在tensorflow

我倒要提取的數據元素,並獲得的形狀的輸出,使得[,5']:

output[i][j] = data[i][j][idx[i, j]] for all i in range(?) and j in range(5) 

它看起來洛克的tf.gather_nd()函數是最接近我的需要,但我不t看看我的情況如何使用它...

謝謝!

編輯:我設法做到了與gather_nd如下所示,但有沒有更好的選擇? (似乎有點重手)

nRows = tf.shape(length_label)[0] ==> ? 
    nCols = tf.constant(MAX_LENGTH_INPUT + 1, dtype=tf.int32) ==> 5 
    m1 = tf.reshape(tf.tile(tf.range(nCols), [nRows]), 
              shape=[nRows, nCols]) 
    m2 = tf.transpose(tf.reshape(tf.tile(tf.range(nRows), [nCols]), 
              shape=[nCols, nRows])) 
    indices = tf.pack([m2, m1, idx], axis=-1) 
    # indices should be of shape [?, 5, 3] with indices[i,j]==[i,j,idx[i,j]] 
    output = tf.gather_nd(data, indices=indices) 
+0

您的解決方案對我來說很好。 – user1454804

回答

0

我設法與gather_nd做如下圖所示

nRows = tf.shape(length_label)[0] # ==> ? 
nCols = tf.constant(MAX_LENGTH_INPUT + 1, dtype=tf.int32) # ==> 5 
m1 = tf.reshape(tf.tile(tf.range(nCols), [nRows]), 
             shape=[nRows, nCols]) 
m2 = tf.transpose(tf.reshape(tf.tile(tf.range(nRows), [nCols]), 
             shape=[nCols, nRows])) 
indices = tf.pack([m2, m1, idx], axis=-1) 
# indices should be of shape [?, 5, 3] with indices[i,j]==[i,j,idx[i,j]] 
output = tf.gather_nd(data, indices=indices)