2017-05-08 152 views
2

我試圖生產用於TensorArray的組合while_loop一個很簡單的例子:TensorArray和while_loop如何在tensorflow中一起工作?

# 1000 sequence in the length of 100 
matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix") 
matrix_rows = tf.shape(matrix)[0] 
ta = tf.TensorArray(tf.float32, size=matrix_rows) 
ta = ta.unstack(matrix) 

init_state = (0, ta) 
condition = lambda i, _: i < n 
body = lambda i, ta: (i + 1, ta.write(i,ta.read(i)*2)) 

# run the graph 
with tf.Session() as sess: 
    (n, ta_final) = sess.run(tf.while_loop(condition, body, init_state),feed_dict={matrix: tf.ones(tf.float32, shape=(100,1000))}) 
    print (ta_final.stack()) 

但我收到以下錯誤:

ValueError: Tensor("while/LoopCond:0", shape=(), dtype=bool) must be from the same graph as Tensor("Merge:0", shape=(), dtype=float32). 

任何人有想法是什麼問題?

+0

得到最終的'TensorArray'你需要爲'session.run(ta.stack())',而不是運行循環直接哪個失敗,因爲你不能'session.run(TensorArray)'。 – sirfz

+0

對不起,我沒有明白你的意思。你能寫出正確的表格嗎? –

回答

3

代碼中有幾點需要指出。首先,您不需要將矩陣拆散到TensorArray以在循環內部使用它,則可以安全地在主體內引用矩陣Tensor,並使用matrix[i]表示法對其進行索引。另一個問題是矩陣(tf.int32)和TensorArraytf.float32)之間的數據類型不同,根據您將矩陣乘以2的代碼並將結果寫入數組,因此它也應該是int32。最後,當你想讀取循環的最終結果時,正確的操作是TensorArray.stack(),這是你需要在你的session.run調用中運行。

這裏有一個工作示例:

import numpy as np 
import tensorflow as tf  

# 1000 sequence in the length of 100 
matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix") 
matrix_rows = tf.shape(matrix)[0] 
ta = tf.TensorArray(dtype=tf.int32, size=matrix_rows) 

init_state = (0, ta) 
condition = lambda i, _: i < matrix_rows 
body = lambda i, ta: (i + 1, ta.write(i, matrix[i] * 2)) 
n, ta_final = tf.while_loop(condition, body, init_state) 
# get the final result 
ta_final_result = ta_final.stack() 

# run the graph 
with tf.Session() as sess: 
    # print the output of ta_final_result 
    print sess.run(ta_final_result, feed_dict={matrix: np.ones(shape=(100,1000), dtype=np.int32)}) 
+0

太棒了!非常感謝你!它現在有效。 –

+0

很高興有效@ E.Asgari,請將答案標記爲已接受。 – sirfz

+0

在這可以指定輸入,而不使用飼料字典,因爲如果我在計算圖之間使用它,我將如何指定張量數組依賴於張量? – Rahul