2017-03-07 106 views
2

我想從SparseTensor中的一行中獲取所有非零值,因此「m」是我擁有的稀疏張量對象,行是我想要從中獲取所有非零值和索引的行。所以我想返回一個[(index,values)]對的數組]。我希望我能在這個問題上得到一些幫助。在終端獲取SparseTensor的非零排行

def nonzeros(m, row): 
    res = [] 
    indices = m.indices 
    values = m.values 
    userindices = tf.where(tf.equal(indices[:,0], tf.constant(0, dtype=tf.int64))) 
    res = tf.map_fn(lambda index:(indices[index][1], values[index]), userindices) 
    return res 

錯誤消息

TypeError: Input 'strides' of 'StridedSlice' Op has type int32 that does not match type int64 of argument 'begin'. 

編輯: 輸入爲非零 釐米與值

m = tf.SparseTensor(indices=np.array([row,col]).T, 
         values=cm.data, 
         dense_shape=[10, 10]) 
nonzeros(m, 1) 

一個coo_matrix如果數據是

[[ 0. 1. 0. 0. 0. 0. 0. 0. 0. 1.] 
[ 0. 0. 0. 0. 1. 0. 0. 0. 0. 2.] 
[ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.] 
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] 
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]] 

結果應該是

[index, value] 
[4,1] 
[9,2] 
+0

能否請您舉報的輸入輸出的例子嗎?這樣我們可以更好地理解如何獲得你想要的。 –

回答

2

的問題是,index拉姆達內部是一個張量與不能使用直接索引到例如indices。您可以改用tf.gather。另外,您沒有在發佈的代碼中使用row參數。

試試這個:

import tensorflow as tf 
import numpy as np 

def nonzeros(m, row): 
    indices = m.indices 
    values = m.values 
    userindices = tf.where(tf.equal(indices[:, 0], row)) 
    found_idx = tf.gather(indices, userindices)[:, 0, 1] 
    found_vals = tf.gather(values, userindices)[:, 0:1] 
    res = tf.concat(1, [tf.expand_dims(tf.cast(found_idx, tf.float64), -1), found_vals]) 
    return res 

data = np.array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 1.], 
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 2.]]) 

m = tf.SparseTensor(indices=np.array([[0, 1], [0, 9], [1, 4], [1, 9]]), 
        values=np.array([1.0, 1.0, 1.0, 2.0]), 
        shape=[2, 10]) 

with tf.Session() as sess: 
    result = nonzeros(m, 1) 
    print(sess.run(result)) 

它打印:

[[ 4. 1.] 
[ 9. 2.]]