2016-09-25 90 views
5

如果A每一行的越來越要素是像這樣TensorFlow特定列

A = tf.Variable([[1, 2], [3, 4]]) 

index一個TensorFlow變量是另一個變量

index = tf.Variable([0, 1]) 

我想用這個指標來選擇每列行。在這種情況下,第一行中的項目0和第二行中的項目1。

如果A是一個numpy的數組,然後獲得在指數中提到相應的行的列,我們可以做

x = A[np.arange(A.shape[0]), index] 

,其結果將是

[1, 4] 

什麼是TensorFlow相當於操作/爲此操作?我知道TensorFlow不支持許多索引操作。如果不能直接完成,會有什麼工作呢?

回答

2

經過了相當長的一段時間。我發現了兩個可能有用的功能。

一個是tf.gather_nd()這可能是有用的,如果你能生產形式[[0, 0], [1, 1]]的張量 ,從而你可以做

index = tf.constant([[0, 0], [1, 1]])

tf.gather_nd(A, index)

如果你不能產生的矢量形式[[0, 0], [1, 1]](由於某些原因,我無法產生這種情況,因爲我的案例中的行數取決於佔位符),那麼我發現的解決方法是使用tf.py_func()。下面是關於如何可以做到這一點

import tensorflow as tf 
import numpy as np 

def index_along_every_row(array, index): 
    N, _ = array.shape 
    return array[np.arange(N), index] 

a = tf.Variable([[1, 2], [3, 4]], dtype=tf.int32) 
index = tf.Variable([0, 1], dtype=tf.int32) 
a_slice_op = tf.py_func(index_along_every_row, [a, index], [tf.int32])[0] 
session = tf.InteractiveSession() 

a.initializer.run() 
index.initializer.run() 
a_slice = a_slice_op.eval() 

a_slice將是一個numpy的陣列[1, 4]

0

我們可以使用map_fngather_nd這個組合做同樣的示例代碼。

def get_element(a, indices): 
    """ 
    Outputs (ith element of indices) from (ith row of a) 
    """ 
    return tf.map_fn(lambda x: tf.gather_nd(x[0], x[1]), 
            (a, indices), 
            dtype = tf.float32) 

下面是一個示例用法。

A = tf.constant(np.array([[1,2,3], 
          [4,5,6], 
          [7,8,9]], dtype = np.float32)) 

idx = tf.constant(np.array([[2],[1],[0]])) 
elems = get_element(A, idx) 

with tf.Session() as sess: 
    e = sess.run(elems) 

print(e) 

我不知道這是否會比其他答案慢得多。

它的優點是,您無需事先指定A的行數,只要aindices在運行時具有相同的行數。

注意的的輸出將是排名1.如果你喜歡它有2級,由gather

2

更換gather_nd您可以使用one hot方法來創建一個one_hot陣列,並把它作爲一個布爾掩碼來選擇你想要的索引。

A = tf.Variable([[1, 2], [3, 4]]) 
index = tf.Variable([0, 1]) 

one_hot_mask = tf.one_hot(index, A.shape[1], on_value = True, off_value = False, dtype = tf.bool) 
output = tf.boolean_mask(A, one_hot_mask) 
+0

雖然這可能與問題,當鏈路死了,也沒有發生什麼事的解釋幫助。你應該形成你自己的答案,解釋發生的事情,而不要依賴SO上不允許的鏈接作爲唯一的信息來源。 https://stackoverflow.com/help/how-to-answer – Rob

1

您可以通過行索引擴展您的列索引,然後使用gather_nd:

import tensorflow as tf 

A = tf.constant([[1, 2], [3, 4]]) 
indices = tf.constant([1, 0]) 

# prepare row indices 
row_indices = tf.range(tf.shape(indices)[0]) 

# zip row indices with column indices 
full_indices = tf.stack([row_indices, indices], axis=1) 

# retrieve values by indices 
S = tf.gather_nd(A, full_indices) 

session = tf.InteractiveSession() 
session.run(S)