2017-08-08 91 views
-1

給定一個類實例列表,我需要使用tf.tensor將其索引。例如:如何使用TensorFlow張量索引類實例列表

Class Something(): 
    def __init__(self): 
     self.a = 1 
     self.b = 2 

list = [Something() for a in range(0, 10)] 
index_queue = tf.train.range_input_producer(len(list)) 
index = index_queue.dequeue() 
result = list[index] 
tensor = function_that_returns_tensor(result) 
with tf.Session() as sess: 
    sess.run(tensor) 

上面的代碼給出以下錯誤:TypeError: list indices must be integers, not Tensor

並採用tf.gather(list, index)提供了以下錯誤:

TypeError: Expected binary or unicode string, got <__main__.Something object at 0x7f4529fae2b0> 

任何幫助,將不勝感激。謝謝!

+0

爲什麼你使用'tf.constant(..)'? 'list [2]'會正常工作... –

+0

我修改了這個問題。所以index是一個tf.tensor,它在執行圖時會有一些價值。 –

回答

0

問題出在TensorFlow工作原理的核心機制上。當您調用tf.train.range_input_producer(len(list))tf.constant等TensorFlow方法時,您實際上並不是運行這些操作。您只需將這些操作添加到TensorFlow計算圖。然後您必須使用tf.Session實例的run方法來運行這些操作並從中獲取結果。 TypeError: list indices must be integers, not Tensor告訴您,您將計算圖上的張量引用作爲索引傳遞,而不是運行產生張量的操作返回的結果。請參閱this TensorFlow documentation

+0

非常感謝您的回覆。是的,我瞭解Tensorflow的整體機制。我報告的錯誤是我在tf.Session中運行這些操作時得到的。我相應地修改了這個問題。 –

+0

@UmarIqbal,在您更新的代碼中,您仍然將張量的引用作爲索引傳遞給列表,而不是從運行'tf.Session'返回的內容。在你的代碼中,'index'是一個張量的引用,而不是一個整數。要從它得到一個整數,你需要運行'index_value = sess.run(index)'。然後'list [index_value]'將起作用。 – golmschenk

+0

@UmarIqbal,但是請注意,您的代碼存在另一個問題,它來自使用隊列。如果你做出我說過的改變,你的代碼似乎會掛起。這是因爲隊列需要隊列運行器才能工作。關於[這裏]的更多信息(https://www.tensorflow.org/programmers_guide/threading_and_queues)。但是你最初使用常量(或者其他非隊列生成張量)而不是隊列的例子應該可以工作。 – golmschenk