Batching graph_cnn in TensorFlow

I want to use graph_cnn (Defferrard et al. 2016) for inputs with a varying number of nodes. The author provides example code (see graph_cnn). Below is what I believe to be the key part of the code:

def chebyshev5(self, x, L, Fout, K):
    N, M, Fin = x.get_shape()
    N, M, Fin = int(N), int(M), int(Fin)
    # Rescale Laplacian and store as a TF sparse tensor. Copy to not modify the shared L.
    L = scipy.sparse.csr_matrix(L)
    L = graph.rescale_L(L, lmax=2)
    L = L.tocoo()
    indices = np.column_stack((L.row, L.col))
    L = tf.SparseTensor(indices, L.data, L.shape)
    L = tf.sparse_reorder(L)
    # Transform to Chebyshev basis
    x0 = tf.transpose(x, perm=[1, 2, 0])  # M x Fin x N
    x0 = tf.reshape(x0, [M, Fin*N])  # M x Fin*N
    x = tf.expand_dims(x0, 0)  # 1 x M x Fin*N
    def concat(x, x_):
        x_ = tf.expand_dims(x_, 0)  # 1 x M x Fin*N
        return tf.concat([x, x_], axis=0)  # K x M x Fin*N
    if K > 1:
        x1 = tf.sparse_tensor_dense_matmul(L, x0)
        x = concat(x, x1)
    for k in range(2, K):
        x2 = 2 * tf.sparse_tensor_dense_matmul(L, x1) - x0  # M x Fin*N
        x = concat(x, x2)
        x0, x1 = x1, x2
    x = tf.reshape(x, [K, M, Fin, N])  # K x M x Fin x N
    x = tf.transpose(x, perm=[3, 1, 2, 0])  # N x M x Fin x K
    x = tf.reshape(x, [N*M, Fin*K])  # N*M x Fin*K
    # Filter: Fin*Fout filters of order K, i.e. one filterbank per feature pair.
    W = self._weight_variable([Fin*K, Fout], regularization=False)
    x = tf.matmul(x, W)  # N*M x Fout
    return tf.reshape(x, [N, M, Fout])  # N x M x Fout

Essentially, I think what this does can be reduced to something like

y = concat{L^k x, for k = 0 to K-1} * W

where:

x: the input, of size N x M x Fin (its size varies across batches);

L: a set of operators, one per sample of x, each of size M x M (variable in size across batches);

W: the neural-network parameters to be optimized, of size Fin x K x Fout;

N: the number of samples in a batch (fixed for any batch);

M: the number of nodes in the graph (variable across batches);

Fin: the number of input features (fixed for any batch);

Fout: the number of output features (fixed for any batch);

K: a constant representing the number of steps (hops) taken in the graph.

For a single example, the above code works. However, since both x and L have variable sizes for each sample in a batch, I do not know how to make it work for a batch of samples.
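For reference, the code above does not literally form powers L^k x; it builds the Chebyshev polynomial basis on the rescaled Laplacian, a sketch of which (matching the recurrence in the code, with L' denoting the Laplacian rescaled by rescale_L with lmax=2 so its eigenvalues lie in [-1, 1]) is:

    x_0 = x
    x_1 = L' x
    x_k = 2 L' x_{k-1} - x_{k-2},   for k = 2, ..., K-1
    y = concat[x_0, x_1, ..., x_{K-1}] W

so the concatenated terms are T_k(L') x, the Chebyshev basis, rather than plain matrix powers.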

Answer


tf.matmul currently (v1.4) supports batch matrix multiplication only on the lowest 2 dims, and only for dense tensors. If either of the input tensors is sparse, it raises a dimension-mismatch error. tf.sparse_tensor_dense_matmul cannot be applied to batch inputs either. Therefore, my current solution is to move all the L-preparation steps outside the function, pass L as a dense tensor (shape: [N, M, M]), and use tf.matmul to perform the batch matrix multiplication.
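For example, the preparation that was removed from the function (rescaling each sample's Laplacian and densifying it) could be done with numpy/scipy before feeding the batch. This is a minimal sketch, not from the original post: the helper name prepare_laplacian_batch is hypothetical, it assumes every sample is zero-padded to a common node count M, and graph.rescale_L comes from the author's cnn_graph repository.

import numpy as np
import scipy.sparse
import graph  # rescale_L is from the author's cnn_graph repository

# Hypothetical helper: build a dense [N, M, M] batch of rescaled Laplacians
# outside the TF graph, zero-padding each sample up to M nodes.
def prepare_laplacian_batch(L_list, M):
    batch = np.zeros((len(L_list), M, M), dtype=np.float32)
    for i, L in enumerate(L_list):
        L = scipy.sparse.csr_matrix(L)
        L = graph.rescale_L(L, lmax=2)  # rescale eigenvalues to [-1, 1]
        L = L.toarray()
        m = L.shape[0]
        batch[i, :m, :m] = L  # zero-pad up to M x M
    return batch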

Here is my modified code:

'''
chebyshev5_batch
Purpose:
    perform the graph filtering on the given layer
Args:
    x: the batch of inputs for the given layer,
       dense tensor, size: [N, M, Fin]
    L: the batch of rescaled Laplacians for the given layer (tf.Tensor),
       in dense format, size: [N, M, M]
    Fout: the number of output features on the given layer
    K: the filter size, i.e. the number of hops on the given layer
    lyr_num: the index of the layer's Laplacian (starting from 0),
       used to name the weight variable
Output:
    y: the filtered output from the given layer
'''
def chebyshev5_batch(x, L, Fout, K, lyr_num):
    N, M, Fin = x.get_shape()
    N, M, Fin = int(N), int(M), int(Fin)
    # The Laplacian rescaling and sparse preparation that used to happen here
    # is now done before calling this function; L arrives as a dense tensor
    # of shape [N, M, M].

    def expand_concat(orig, new):
        new = tf.expand_dims(new, 0)  # 1 x N x M x Fin
        return tf.concat([orig, new], axis=0)  # (shape(orig)[0] + 1) x N x M x Fin

    # L:    N x M x M
    # x0:   N x M x Fin
    # L*x0: N x M x Fin

    x0 = x  # N x M x Fin
    stk_x = tf.expand_dims(x0, axis=0)  # 1 x N x M x Fin (eventually K x N x M x Fin, if K > 1)

    if K > 1:
        x1 = tf.matmul(L, x0)  # batched matmul: N x M x Fin
        stk_x = expand_concat(stk_x, x1)
    for kk in range(2, K):
        x2 = 2 * tf.matmul(L, x1) - x0  # Chebyshev recurrence, N x M x Fin
        stk_x = expand_concat(stk_x, x2)
        x0 = x1
        x1 = x2

    # now stk_x has the shape of K x N x M x Fin;
    # transpose to the shape of N x M x Fin x K
    stk_x_transp = tf.transpose(stk_x, perm=[1, 2, 3, 0])
    stk_x_forMul = tf.reshape(stk_x_transp, [N*M, Fin*K])

    W_initial = tf.truncated_normal_initializer(0, 0.1)
    W = tf.get_variable('weights_L_'+str(lyr_num), [Fin*K, Fout], tf.float32, initializer=W_initial)
    tf.summary.histogram(W.op.name, W)

    y = tf.matmul(stk_x_forMul, W)
    y = tf.reshape(y, [N, M, Fout])
    return y
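A minimal usage sketch under the same assumptions (the shapes are hypothetical; the placeholders need static dimensions here because the function calls int() on the shape, so one fixed M applies per graph construction):

import tensorflow as tf

# Hypothetical shapes: 16 samples, 100 nodes, 8 input features.
x_ph = tf.placeholder(tf.float32, [16, 100, 8])    # N x M x Fin
L_ph = tf.placeholder(tf.float32, [16, 100, 100])  # N x M x M, dense and rescaled
y = chebyshev5_batch(x_ph, L_ph, Fout=32, K=5, lyr_num=0)  # -> [16, 100, 32]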