2017-06-21 42 views

How do I extract all the weights from an LSTM cell in vanilla TensorFlow?

I train an LSTM network:

cell_fw = tf.contrib.rnn.BasicLSTMCell(HIDDEN_SIZE) 
cell_bw = tf.contrib.rnn.BasicLSTMCell(HIDDEN_SIZE) 

rnn_outputs, final_state_fw, final_state_bw = tf.contrib.rnn.static_bidirectional_rnn(
    cell_fw=cell_fw, 
    cell_bw=cell_bw, 
    inputs=rnn_inputs, 
    dtype=tf.float32 
) 

Then I try to save its coefficients:

d = {}
with tf.Session() as sess:
    # train code ...
    variables_names = [v.name for v in tf.global_variables()]
    values = sess.run(variables_names)
    for k, v in zip(variables_names, values):
        d[k] = v

The dictionary d ends up with only 2 objects for each LSTM cell:

[(k,v.shape) for (k,v) in sorted(d.items(), key=lambda x:x[0])] 
[('bidirectional_rnn/bw/basic_lstm_cell/biases:0', (1024,)), 
('bidirectional_rnn/bw/basic_lstm_cell/weights:0', (272, 1024)), 
('bidirectional_rnn/fw/basic_lstm_cell/biases:0', (1024,)), 
('bidirectional_rnn/fw/basic_lstm_cell/weights:0', (272, 1024)), 
('char_embedding:0', (70, 16)), 
('softmax_biases:0', (5068,)), 
('softmax_weights:0', (5068, 512))] 

I am confused. Each LSTM cell should contain 4 trainable layers, shouldn't it? If so, how do I get all the weights from the LSTM cell?

Answers


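The 4 weight matrices (and biases) of an LSTM cell are stored as a single tensor, where slices along the second axis correspond to the different kinds of weights (input gate, forget gate, etc.).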

For example, I guess that in your case the value of HIDDEN_SIZE is 256, since 1024 = 4 × 256.
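That guess can be checked against the shapes printed in the question. This is only a small sketch: it assumes the cell multiplies the concatenation [inputs, h] by one fused matrix, so that the number of rows is input size plus hidden size. The variable names are made up for illustration.

```python
# Shapes copied from the dictionary listing in the question.
weights_shape = (272, 1024)          # .../basic_lstm_cell/weights:0
biases_shape = (1024,)               # .../basic_lstm_cell/biases:0
embedding_shape = (70, 16)           # char_embedding:0
softmax_weights_shape = (5068, 512)  # softmax_weights:0

# Four blocks of equal width share the second axis (assumption stated above).
hidden_size = biases_shape[0] // 4

# The remaining rows of the fused matrix act on the cell input.
input_size = weights_shape[0] - hidden_size

print("hidden size:", hidden_size)
print("input size:", input_size)
print("matches embedding width:", input_size == embedding_shape[1])
print("softmax input is fw + bw:", softmax_weights_shape[1] == 2 * hidden_size)
```

If the assumption holds, the input size recovered here should equal the width of char_embedding, and the softmax input should be twice the hidden size because the forward and backward outputs are concatenated.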

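To access the different parts, you should slice the tensor along the axis of length 1024 (but I don't know in which order the different kinds of weights are stored...).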
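Here is a rough sketch of that slicing with NumPy. The block order is an assumption: as far as I can tell from the TensorFlow 1.x source of BasicLSTMCell, the fused output is split into four equal parts in the order i, j, f, o (input gate, new input, forget gate, output gate), and the rows follow the concatenation [inputs, h]. Please verify this against the version you have installed. The helper name split_lstm_kernel and the dictionary keys are hypothetical, and the arrays below are stand-ins with the shapes from the question.

```python
import numpy as np


def split_lstm_kernel(weights, biases, input_size):
    """Split a fused LSTM kernel and bias into per-gate pieces.

    Assumed layout (check against your TensorFlow version):
      rows    = [inputs, h]
      columns = four equal blocks in the order i, j, f, o
    """
    num_units = biases.shape[0] // 4
    assert weights.shape == (input_size + num_units, 4 * num_units)

    names = ("input_gate", "new_input", "forget_gate", "output_gate")
    w_blocks = np.split(weights, 4, axis=1)  # slice along the 1024 axis
    b_blocks = np.split(biases, 4)

    gates = {}
    for name, w, b in zip(names, w_blocks, b_blocks):
        gates[name] = {
            "W_x": w[:input_size],  # acts on the input, (input_size, num_units)
            "W_h": w[input_size:],  # acts on the previous h, (num_units, num_units)
            "b": b,                 # (num_units,)
        }
    return gates


# Stand-in arrays; in the question these would be the values stored in d
# under '.../basic_lstm_cell/weights:0' and '.../basic_lstm_cell/biases:0'.
weights = np.arange(272 * 1024, dtype=np.float32).reshape(272, 1024)
biases = np.arange(1024, dtype=np.float32)

gates = split_lstm_kernel(weights, biases, input_size=16)
for name, parts in gates.items():
    print(name, parts["W_x"].shape, parts["W_h"].shape, parts["b"].shape)
```

One more thing to keep in mind: I believe BasicLSTMCell adds its forget_bias argument (1.0 by default) when the gate is computed, so the stored forget-gate bias would not include it.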


Oh, that's true. Thanks, now I can relax. – Roosh
