我正在研究使用LSTM跟蹤參數(時間序列數據迴歸問題)的Tensorflow NN。一批訓練數據包含批量大小爲的連續觀察值。我想使用LSTM狀態作爲下一個樣本的輸入。所以,如果我有一批數據觀測,我想將第一個觀測的狀態作爲第二個觀測的輸入,等等。下面我將lstm狀態定義爲size = batch_size的張量。我想重用內狀態的批次:Tensorflow - 批處理中的LSTM狀態重用
state = tf.Variable(cell.zero_states(batch_size, tf.float32), trainable=False)
cell = tf.nn.rnn_cell.BasicLSTMCell(100)
output, curr_state = tf.nn.rnn(cell, data, initial_state=state)
在API中有一個tf.nn.state_saving_rnn但文檔是有點含糊。 我的問題:如何在培訓批次中重用curr_state 。
爲了澄清,您希望將第一個批處理元素的結果狀態變爲下一個批處理元素的開始狀態,依此類推?在這種情況下,批量維度是否恰好是時間維度? –
@艾倫拉沃,是的,這是正確的。批次內的每個數據觀察都是(多維)時間序列窗口。該批次包含按順序排列的重疊窗口。批次維度是時間維度,具有重疊和跨度。 – Leeor
在這種情況下,您的批量維度實際上是1.除非您有多個序列可以批處理,否則這將相對較慢。目前正在努力支持允許批量生產更長時間序列的近似值,但尚未公開發布。 –