2017-02-09 59 views
0

我正在研究使用LSTM跟蹤參數(時間序列數據迴歸問題)的Tensorflow NN。一批訓練數據包含批量大小爲的連續觀察值。我想使用LSTM狀態作爲下一個樣本的輸入。所以,如果我有一批數據觀測,我想將第一個觀測的狀態作爲第二個觀測的輸入,等等。下面我將lstm狀態定義爲size = batch_size的張量。我想重用內狀態的批次:Tensorflow - 批處理中的LSTM狀態重用

state = tf.Variable(cell.zero_states(batch_size, tf.float32), trainable=False) 
cell = tf.nn.rnn_cell.BasicLSTMCell(100) 
output, curr_state = tf.nn.rnn(cell, data, initial_state=state) 

在API中有一個tf.nn.state_saving_rnn但文檔是有點含糊。 我的問題:如何在培訓批次中重用curr_state

+0

爲了澄清,您希望將第一個批處理元素的結果狀態變爲下一個批處理元素的開始狀態,依此類推?在這種情況下,批量維度是否恰好是時間維度? –

+0

@艾倫拉沃,是的,這是正確的。批次內的每個數據觀察都是(多維)時間序列窗口。該批次包含按順序排列的重疊窗口。批次維度是時間維度,具有重疊和跨度。 – Leeor

+1

在這種情況下,您的批量維度實際上是1.除非您有多個序列可以批處理,否則這將相對較慢。目前正在努力支持允許批量生產更長時間序列的近似值,但尚未公開發布。 –

回答

1

你是基本上沒有,只需要更新statecurr_state

state_update = tf.assign(state, curr_state) 

然後,確保你要麼調用runstate_update本身或具有state_update作爲依賴的操作,或分配將不實際上發生。例如:

with tf.control_dependencies([state_update]): 
    model_output = ... 

作爲評價建議的,爲RNNs的典型的情況是,你有一個批次,其中所述第一尺寸(0)是序列的數目與所述第二維度(1)是最大長度(如果在構建RNN時通過time_major=True這兩個交換)。理想情況下,爲了獲得良好的性能,您可以將多個序列堆疊到一個批處理中,然後按時間分割該批處理。但這真的是一個不同的話題。