2017-02-13 344 views
1

我正在嘗試使用TensorFlow中的scatter_nd函數對矩陣行內的元素重新排序。例如,假設我的代碼:在矩陣行和列中交換元素 - TensorFlow scatter_nd

indices = tf.constant([[1],[0]]) 
updates = tf.constant([ [5, 6, 7, 8], 
         [1, 2, 3, 4] ]) 
shape = tf.constant([2, 4]) 
scatter1 = tf.scatter_nd(indices, updates, shape) 
$ print(scatter1) = [[1,2,3,4] 
        [5,6,7,8]] 

此重新排序updates矩陣的行。

而不是隻能夠重新排序行,我想重新排列每行內的單個元素。如果我只是一個載體(張量等級1),那麼這個例子的工作原理:

indices = tf.constant([[1],[0],[2],[3]]) 
updates = tf.constant([5, 6, 7, 8]) 
shape = tf.constant([4]) 
scatter2 = tf.scatter_nd(indices, updates, shape) 
$ print(scatter2) = [6,5,7,8] 

我真正關心的是能夠爲每行內交換元件scatter1,正如我在scatter2做了,但是對於scatter1的每一行都這樣做。我試過indices的各種組合,但不斷收到大小不一致的錯誤scatter_nd函數拋出。

回答

1

以下使用scatter_nd

indices = tf.constant([[[0, 1], [0, 0], [0, 2], [0, 3]], 
         [[1, 1], [1, 0], [1, 2], [1, 3]]]) 
updates = tf.constant([ [5, 6, 7, 8], 
         [1, 2, 3, 4] ]) 
shape = tf.constant([2, 4]) 
scatter1 = tf.scatter_nd(indices, updates, shape) 
with tf.Session() as sess: 
    print(sess.run(scatter1)) 

給予的輸出交換每一行的每一行中的元素:
[[6 5 7 8] [2 1 3 4]]

的位置在indices座標定義正在採取的值,其中從updates和實際座標定義值將被放置在scatter1

這個答案是晚了幾個月,但希望仍然有幫助。

0

假設您想交換第二維中的元素,保留第一維的順序與否。

import tensorflow as tf 
sess = tf.InteractiveSession() 


def prepare_fd(fd_indices, sd_dims): 
    fd_indices = tf.expand_dims(fd_indices, 1) 
    fd_indices = tf.tile(fd_indices, [1, sd_dims]) 
    return fd_indices 

# define the updates 
updates = tf.constant([[11, 12, 13, 14], 
         [21, 22, 23, 24], 
         [31, 32, 33, 34]]) 
sd_dims = tf.shape(updates)[1] 

sd_indices = tf.constant([[1, 0, 2, 3], [0, 2, 1, 3], [0, 1, 3, 2]]) 
fd_indices_range = tf.range(0, limit=tf.shape(updates)[0]) 
fd_indices_custom = tf.constant([2, 0, 1]) 

# define the indices 
indices1 = tf.stack((prepare_fd(fd_indices_range, sd_dims), sd_indices), axis=2) 
indices2 = tf.stack((prepare_fd(fd_indices_custom, sd_dims), sd_indices), axis=2) 

# define the shape 
shape = tf.shape(updates) 

scatter1 = tf.scatter_nd(indices1, updates, shape) 
scatter2 = tf.scatter_nd(indices2, updates, shape) 

print(scatter1.eval()) 

# array([[12, 11, 13, 14], 
#  [21, 23, 22, 24], 
#  [31, 32, 34, 33]], dtype=int32) 

print(scatter2.eval()) 

# array([[21, 23, 22, 24], 
#  [31, 32, 34, 33], 
#  [12, 11, 13, 14]], dtype=int32) 

可能這個例子有幫助。