2016-08-03 223 views
0

我已經張量的定義如下:如何在Tensorflow中從張量中獲取特定行?

idx = tf.constant([0, 2]) 

現在我想利用temp_var一個子集在那些:

temp_var = tf.Variable(initial_value=np.asarray([[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12]])) 

我也有行索引的陣列,以從張量中獲取指標即idx

我知道,要採取單一索引或切片,我們可以做這樣的事情

temp_var[single_row_index, :] 

temp_var[start:end, :] 

但如何讀取行由idx陣列表示? 類似於temp_var[idx, :]

回答

2

tf.gather() op正好滿足您的需求:它從矩陣(或從N維張量中選擇一般(N-1)維片)中選擇行。以下是它如何在你的情況下工作:

temp_var = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])) 
idx = tf.constant([0, 2]) 

rows = tf.gather(temp_var, idx) 

init = tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init) 

print(sess.run(rows)) # ==> [[1, 2, 3], [7, 8, 9]]