我最近將我的tesnorflow從Rev8升級到Rev12。在Rev8中,rnn_cell.LSTMCell中的默認「state_is_tuple」標誌被設置爲False,所以我用列表初始化了我的LSTM Cell,請參閱下面的代碼。如何使用元組初始化LSTMCell
#model definition
lstm_cell = rnn_cell.LSTMCell(self.config.hidden_dim)
outputs, states = tf.nn.rnn(lstm_cell, data, initial_state=self.init_state)
#init_state place holder and feed_dict
def add_placeholders(self):
self.init_state = tf.placeholder("float", [None, self.cell_size])
def get_feed_dict(self, data, label):
feed_dict = {self.input_data: data,
self.input_label: reg_label,
self.init_state: np.zeros((self.config.batch_size, self.cell_size))}
return feed_dict
在Rev12,默認的「state_is_tuple」標誌被設置爲True,以使我的舊代碼的工作,我不得不把標誌明確轉向爲False。不過,現在我從tensorflow的警告說:「使用級聯狀態較慢,很快就會被棄用 使用state_is_tuple =真」
我試圖初始化一個LSTM細胞元組通過改變佔位符定義self.init_state以下內容:
self.init_state = tf.placeholder("float", (None, self.cell_size))
,但現在我得到了一個錯誤信息說:
「‘張量’的對象是不是可迭代」
有誰知道如何使這項工作?
不幸的是,元組是一個複雜的結構。你是否必須*明確地使'init_state'成爲一個佔位符?使用'cell.zero_state'代替它會好得多。別擔心,您可以跨'run_dict'傳遞狀態 – martianwars