2017-04-07 93 views
1

我在玩map_fn函數,注意到它輸出一個TensorArray,這意味着它能夠輸出「鋸齒」張量(其中內部的張量具有不同的第一維度。tensorflow map_fn TensorArray的形狀不一致

我試圖看到這個動作與此代碼:

import tensorflow as tf 
import numpy as np 

NUM_ARRAYS = 1000 
MAX_LENGTH = 1000 

lengths = tf.placeholder(tf.int32) 
tArray = tf.map_fn(lambda x: tf.random_normal((x,), 0, 1), 
        lengths, 
        dtype=tf.float32) # Should return a TensorArray. 

# startTensor = tf.random_normal((tf.reduce_sum(lengths),), 0, 1) 
# tArray = tf.TensorArray(tf.float32, NUM_ARRAYS) 
# tArray = tArray.split(startTensor, lengths) 
# outArray = tArray.concat() 


with tf.Session() as sess: 
    outputArray, l = sess.run(
     [tArray, lengths], 
     feed_dict={lengths: np.random.randint(MAX_LENGTH, size=NUM_ARRAYS)}) 
    print outputArray.shape, l 

然而得到了錯誤:

「TensorArray具有不一致的形狀索引0具有形狀:[259],但指數1具有形狀:[773]「

這對我來說當然是一個驚喜,因爲我的印象是TensorArrays應該能夠處理它。我錯了嗎?

回答

4

雖然tf.map_fn()確實使用tf.TensorArray對象內部tf.TensorArray可容納不同大小的物體,這個程序將無法正常工作,是因爲tf.map_fn()通過堆疊的元素一起轉換其tf.TensorArray結果回tf.Tensor,這是失敗的操作。

但是,您可以實現tf.TensorArray使用較低槓桿tf.while_loop()運算,而不是基礎的:

lengths = tf.placeholder(tf.int32) 
num_elems = tf.shape(lengths)[0] 
init_array = tf.TensorArray(tf.float32, size=num_elems) 

def loop_body(i, ta): 
    return i + 1, ta.write(i, tf.random_normal((lengths[i],), 0, 1)) 

_, result_array = tf.while_loop(
    lambda i, ta: i < num_elems, loop_body, [0, init_array])