2017-05-05 70 views
1

讓我們考慮一個numpy的矩陣,o越來越多指數值從張量一次,在tensorflow

如果我們要使用numpy的使用下面的功能:

o[np.arange(x), column_array] 

我能得到來自一個numpy數組的多個索引。

我試圖用tensorflow做同樣的事情,但它不像我所做的那樣工作。當o是張量流張量時;

o[tf.range(0, x, 1), column_array] 

我得到以下錯誤:

TypeError: can only concatenate list (not "int") to list 

我能做些什麼?

回答

1

您可能希望看到tf.gather_ndhttps://www.tensorflow.org/api_docs/python/tf/gather_nd

import tensorflow as tf 
import numpy as np 

tensor = tf.placeholder(tf.float32, [2,2]) 
indices = tf.placeholder(tf.int32, [2,2]) 
selected = tf.gather_nd(tensor, indices=indices) 

with tf.Session() as session: 
    data = np.array([[0.1,0.2],[0.3,0.4]]) 
    idx = np.array([[0,0],[1,1]]) 
    result = session.run(selected, feed_dict={indices:idx, tensor:data}) 
    print(result) 

,其結果將是[ 0.1 0.40000001]

3

你可以試試tf.gather_nd(),爲How to select rows from a 3-D Tensor in TensorFlow?這篇文章建議。 以下是從矩陣o獲取多個索引的示例。

o = tf.constant([[1, 2, 3, 4], 
       [5, 6, 7, 8], 
       [9, 10, 11, 12], 
       [13, 14, 15, 16]]) 
# [row_index, column_index], I don’t figure out how to 
# combine row vector and column vector into this form. 
indices = tf.constant([[0, 0], [0, 1], [2, 1], [2, 3]]) 

result = tf.gather_nd(o, indices) 

with tf.Session() as sess: 
    print(sess.run(result)) #[ 1 2 10 12]