2015-07-13 101 views
3

我有這樣的numpy的代碼:`uniq`二維Theano張

def uniq(seq): 
    """ 
    Like Unix tool uniq. Removes repeated entries. 
    :param seq: numpy.array. (time,) -> label 
    :return: seq 
    """ 
    diffs = np.ones_like(seq) 
    diffs[1:] = seq[1:] - seq[:-1] 
    idx = diffs.nonzero() 
    return seq[idx] 

現在,我要擴展這個支持二維數組,並使其使用Theano。它應該在GPU上很快。

我將以格式(時間,批次)的形式獲取具有多個序列的多個批次的數組,以及間接指定每個序列的長度的time_mask

我目前的嘗試:

def uniq_with_lengths(seq, time_mask): 
    # seq is (time,batch) -> label 
    # time_mask is (time,batch) -> 0 or 1 
    num_batches = seq.shape[1] 
    diffs = T.ones_like(seq) 
    diffs = T.set_subtensor(diffs[1:], seq[1:] - seq[:-1]) 
    time_range = T.arange(seq.shape[0]).dimshuffle([0] + ['x'] * (seq.ndim - 1)) 
    idx = T.switch(T.neq(diffs, 0) * time_mask, time_range, -1) 
    seq_lens = T.sum(T.ge(idx, 0), axis=0) # (batch,) -> len 
    max_seq_len = T.max(seq_lens) 

    # I don't know any better way without scan. 
    def step(batch_idx, out_seq_b1): 
    out_seq = seq[T.ge(idx[:, batch_idx], 0).nonzero(), batch_idx][0] 
    return T.concatenate((out_seq, T.zeros((max_seq_len - out_seq.shape[0],), dtype=seq.dtype))) 

out_seqs, _ = theano.scan(
    step, 
    sequences=[T.arange(num_batches)], 
    outputs_info=[T.zeros((max_seq_len,), dtype=seq.dtype)] 
) 
    # out_seqs is (batch,max_seq_len) 
    return out_seqs.T, seq_lens 

如何直接構造out_seqs

我會做一些像out_seqs = seq[idx]但我不完全確定如何表達。

回答

0

這裏有一個快速的答案,只有解決了你的任務的一部分:

def compile_theano_uniq(x): 
    diffs = x[1:] - x[:-1] 
    diffs = tt.concatenate([tt.ones_like([x[0]], dtype=diffs.dtype), diffs]) 
    y = diffs.nonzero_values() 
    return theano.function(inputs=[x], outputs=y) 

theano_uniq = compile_theano_uniq(tt.vector(dtype='int32')) 

的關鍵是nonzero_values()

更新:我無法想象任何方式做到這一點,而不使用theano.scan。需要明確的是,使用0作爲填充,我假設給定的輸入

1 1 2 3 3 4 0 
1 2 2 2 3 3 4 
1 2 3 4 5 0 0 

你想輸出爲

1 2 3 4 0 0 0 
1 2 3 4 0 0 0 
1 2 3 4 5 0 0 

甚至

1 2 3 4 0 
1 2 3 4 0 
1 2 3 4 5 

你可以在不使用掃描的情況下確定要保留的項目的索引。然後,無論是從頭開始構建一個新的張量,還是需要保留一些如何移動以使序列連續的值。沒有theano.scan,這兩種方法都不可行。

+0

這似乎是一個扁平化的數組。我不知道如何從那裏得到我想要的。 – Albert

+0

已更新的答案。我認爲你的新掃描方法可能是最好的。一種非掃描的方法可能是可能的,但我懷疑這將需要很長時間才能弄清楚,並且將來很難理解/維護。 –