2016-10-04 88 views
2

我想更新二維張量中索引的值爲0.所以數據是二維張量,其第二行第二列索引值將被替換爲0.但是,我收到了類型錯誤。任何人都可以幫助我嗎?如何更新Tensorflow中2D張量的子集?

TypeError: Input 'ref' of 'ScatterUpdate' Op requires l-value input

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]]) 
data2 = tf.reshape(data, [-1]) 
sparse_update = tf.scatter_update(data2, tf.constant([7]), tf.constant([0])) 
#data = tf.reshape(data, [N,S]) 
init_op = tf.initialize_all_variables() 

sess = tf.Session() 
sess.run([init_op]) 
print "Values before:", sess.run([data]) 
#sess.run([updated_data_subset]) 
print "Values after:", sess.run([sparse_update]) 

回答

2

散佈更新只適用於變量。請嘗試這種模式。

Tensorflow版本< 1.0: a = tf.concat(0, [a[:i], [updated_value], a[i+1:]])

Tensorflow版本> = 1.0: a = tf.concat(axis=0, values=[a[:i], [updated_value], a[i+1:]])

2

tf.scatter_update只能應用於Variable類型。您的代碼中的dataVariable,而data2不是,因爲返回類型tf.reshapeTensor

解決方案:

爲tensorflow V1.0後

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]]) 
row = tf.gather(data, 2) 
new_row = tf.concat([row[:2], tf.constant([0]), row[3:]], axis=0) 
sparse_update = tf.scatter_update(data, tf.constant(2), new_row) 

爲tensorflow V1.0

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]]) 
row = tf.gather(data, 2) 
new_row = tf.concat(0, [row[:2], tf.constant([0]), row[3:]]) 
sparse_update = tf.scatter_update(data, tf.constant(2), new_row) 
+1

它不工作之前:ValueError異常:尺寸0兩形狀必須相同,但是是1和2 \t從合併形狀1 wi其他形狀。 '輸入形狀爲'concat/concat_dim'(op:'Pack'):[2],[1],[2]。 –

+1

@AntonioSesto TensorFlow v1.0改變了一些界面,以保持與numpy一致。 'concat'是其中的一個變化。我編輯了我的答案來使用v1.0。 –