1

我最近將我的tesnorflow從Rev8升級到Rev12。在Rev8中,rnn_cell.LSTMCell中的默認「state_is_tuple」標誌被設置爲False,所以我用列表初始化了我的LSTM Cell,請參閱下面的代碼。如何使用元組初始化LSTMCell

#model definition 
lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim) 
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state) 


#init_state place holder and feed_dict 
def add_placeholders(self): 
    self.init_state = tf.placeholder("float", [None, self.cell_size]) 

def get_feed_dict(self, data, label): 
    feed_dict = {self.input_data: data, 
      self.input_label: reg_label, 
      self.init_state: np.zeros((self.config.batch_size, self.cell_size))} 
    return feed_dict 

在Rev12,默認的「state_is_tuple」標誌被設置爲True,以使我的舊代碼的工作,我不得不把標誌明確轉向爲False。不過,現在我從tensorflow的警告說:「使用級聯狀態較慢,很快就會被棄用 使用state_is_tuple =真」

我試圖初始化一個LSTM細胞元組通過改變佔位符定義self.init_state以下內容:

self.init_state = tf.placeholder("float", (None, self.cell_size)) 

,但現在我得到了一個錯誤信息說:

「‘張量’的對象是不是可迭代」

有誰知道如何使這項工作?

+1

不幸的是,元組是一個複雜的結構。你是否必須*明確地使'init_state'成爲一個佔位符?使用'cell.zero_state'代替它會好得多。別擔心,您可以跨'run_dict'傳遞狀態 – martianwars

回答

1

現在使用cell.zero_state爲LSTM提供「零狀態」要簡單得多。您不需要明確地將初始狀態定義爲佔位符。將其定義爲張量,並根據需要進行填充。這是如何工作的,

lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim) 
self.initial_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32) 
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state) 

如果要喂一些其他的價值作爲初始狀態,假設next_state = states[-1]例如,在您的會話計算,並通過它在feed_dict像 -

feed_dict[self.initial_state] = next_state 

在你的問題中,lstm_cell.zero_state()就足夠了。


不相關,但請記住,您可以在Feed字典中傳遞張量和佔位符!這就是self.initial_state在上面的例子中的工作原理。查看PTB Tutorial的實例。