我正在嘗試使用TensorFlow中的scatter_nd函數對矩陣行內的元素重新排序。例如,假設我的代碼:在矩陣行和列中交換元素 - TensorFlow scatter_nd
indices = tf.constant([[1],[0]])
updates = tf.constant([ [5, 6, 7, 8],
[1, 2, 3, 4] ])
shape = tf.constant([2, 4])
scatter1 = tf.scatter_nd(indices, updates, shape)
$ print(scatter1) = [[1,2,3,4]
[5,6,7,8]]
此重新排序updates
矩陣的行。
而不是隻能夠重新排序行,我想重新排列每行內的單個元素。如果我只是一個載體(張量等級1),那麼這個例子的工作原理:
indices = tf.constant([[1],[0],[2],[3]])
updates = tf.constant([5, 6, 7, 8])
shape = tf.constant([4])
scatter2 = tf.scatter_nd(indices, updates, shape)
$ print(scatter2) = [6,5,7,8]
我真正關心的是能夠爲每行內交換元件scatter1
,正如我在scatter2
做了,但是對於scatter1
的每一行都這樣做。我試過indices
的各種組合,但不斷收到大小不一致的錯誤scatter_nd
函數拋出。