2017-04-24 82 views
1

這裏的代碼:Tensorflow tf.split()列表索引超出範圍?

a = tf.constant([1,2,3,4]) 
b = tf.constant([4]) 
c = tf.split(a, tf.squeeze(b)) 

那麼,事實證明是錯誤的:

Traceback (most recent call last): 
    File "<stdin>", line 1, in <module> 
    File "/home/jeff/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 1203, in split 
    num = size_splits_shape.dims[0] 
IndexError: list index out of range 

但是,爲什麼?

回答

2

The docs狀態,

如果num_or_size_splits是張量,size_splits,然後分裂成值LEN(size_splits)片。除尺寸爲size_splits [i]的尺寸軸外,第i個尺寸的形狀與尺寸相同。

請注意,size_splits需要是可切片的。

但是當你的squeeze(b),因爲它在你的例子中只有一個元素,它會返回一個沒有尺寸的標量。一個標量不能被切片:

b_ = tf.squeeze(b) 
b_[0] # error 

因此你的錯誤。

+0

感謝您的回答,但是如何將其分成4個部分? –

+0

'tf.split(a,4)',不包含張量中的數字。 – user1735003