I have a TensorFlow problem: when I use input dropout, the LSTM's performance on the test set drops dramatically (from 70% to below 10% accuracy). Disabling dropout at test time makes the LSTM fail.
As far as I understand, I should set input_keep_probability to (for example) 0.5 during training and then to 1 during testing. That makes perfect sense, but I cannot get it to work as expected. If I apply dropout during testing exactly as during training, my performance improves (that is not the case in the example below, but that setup requires less code and the point is the same).
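For context, tf.nn.dropout (which DropoutWrapper applies to the inputs) uses inverted dropout: surviving elements are scaled up by 1/keep_prob at train time, so keep_prob=1 at test time should leave the expected activation unchanged. Here is a minimal NumPy sketch of that mechanism (toy data, not from the question):

```python
import numpy as np

def inverted_dropout(x, keep_prob, rng):
    # Drop each element with probability 1 - keep_prob and scale the
    # survivors by 1 / keep_prob so the expected value stays the same.
    mask = rng.random(x.shape) < keep_prob
    return x * mask / keep_prob

rng = np.random.default_rng(0)
x = np.ones(100000)
train_out = inverted_dropout(x, 0.5, rng)  # noisy, but mean stays near 1.0
test_out = inverted_dropout(x, 1.0, rng)   # keep_prob = 1 is a no-op
```

Because of this scaling, train-time and test-time activations agree in expectation, which is why setting keep_prob to 1 at test time is the standard recipe.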
(Image: the accuracy and cost of 3 runs.) The 'best' line is the one with no dropout at all, the worst line is [keep_prob @ train: 0.5, keep_prob @ test: 1], and the middle line is [keep_prob @ train: 0.5, keep_prob @ test: 0.5]. These are the cost and accuracy on the test set; on the training set they all behave as expected.
Below is the code I believe is relevant. Sadly I cannot post the full code or a data sample because of its sensitivity, but please comment if more information is needed.
import tensorflow as tf
from tensorflow.python.ops import rnn  # TF 0.x-style static RNN API

lstm_size = 512
numLayers = 4
numSteps = 15

lstm_cell = tf.nn.rnn_cell.LSTMCell(lstm_size, state_is_tuple=True, forget_bias=1)
# input_keep_prob is a placeholder so it can differ between train and test
lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, input_keep_prob=input_keep_prob)
cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * numLayers, state_is_tuple=True)

# Split [batch, numSteps, ...] input into a list of numSteps tensors
_inputs = [tf.squeeze(s, [1]) for s in tf.split(1, numSteps, input_data)]
(outputs, state) = rnn.rnn(cell, _inputs, dtype=tf.float32)

outputs = tf.pack(outputs)
# Reshape so I can put all timesteps through the softmax at once
outputsTranspose = tf.reshape(outputs, [-1, lstm_size])
softmax_w = tf.get_variable("softmax_w", [lstm_size, nof_classes])
softmax_b = tf.get_variable("softmax_b", [nof_classes])
logits = tf.matmul(outputsTranspose, softmax_w) + softmax_b
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)
cost = tf.reduce_mean(loss)
targetPrediction = tf.argmax(logits, 1)  # returns int64; targets must match this dtype
accuracy = tf.reduce_mean(tf.cast(tf.equal(targetPrediction, targets), "float"))
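As an aside, the pack-then-reshape step above can be sketched in NumPy to show how all timesteps go through the softmax at once (toy sizes, hypothetical weight values, not the question's real data):

```python
import numpy as np

num_steps, batch, hidden, nof_classes = 3, 2, 4, 5  # toy sizes
# tf.pack stacks the per-step [batch, hidden] outputs into [num_steps, batch, hidden]
outputs = np.random.default_rng(0).standard_normal((num_steps, batch, hidden))
flat = outputs.reshape(-1, hidden)        # [num_steps * batch, hidden]
softmax_w = np.zeros((hidden, nof_classes))
softmax_b = np.zeros(nof_classes)
logits = flat @ softmax_w + softmax_b     # [num_steps * batch, nof_classes]
```

Row k of the flattened matrix corresponds to step k // batch, batch element k % batch, so the per-step structure is recoverable after the single matmul.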
"""Optimizer"""
with tf.name_scope("Optimizer") as scope:
tvars = tf.trainable_variables()
#We clip the gradients to prevent explosion
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),maxGradNorm)
optimizer = tf.train.AdamOptimizer(learning_rate)
gradients = zip(grads, tvars)
train_op = optimizer.apply_gradients(gradients)
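For readers unfamiliar with the clipping step: tf.clip_by_global_norm computes one norm over all gradient tensors and rescales them jointly. A minimal NumPy sketch of that behavior (assuming maxGradNorm plays the role of clip_norm):

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    # One global norm over all tensors; every gradient is scaled by the
    # same factor, so their relative directions are preserved.
    global_norm = np.sqrt(sum(np.sum(g * g) for g in grads))
    scale = min(1.0, max_norm / global_norm)
    return [g * scale for g in grads], global_norm

grads = [np.array([3.0, 4.0]), np.array([12.0])]  # global norm = 13
clipped, norm = clip_by_global_norm(grads, 5.0)   # rescaled to norm 5
```

This is unlikely to be the cause of the dropout issue, but it explains what maxGradNorm controls in the training loop.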
with tf.Session() as sess:
    sess.run(init_op)
    for i in range(nofBatches * nofEpochs):
        example_batch, label_batch = sess.run(readTrainDataOp)
        result = sess.run([train_op, accuracy, trainSummaries],
                          feed_dict={input_data: example_batch, targets: label_batch,
                                     input_keep_prob: trainInputKeepProbability,
                                     batch_size_ph: batch_size})
        # logging
        if i % 50 == 0:
            runTestSet()

# relevant part of runTestSet():
# result = sess.run([cost, accuracy],
#                   feed_dict={input_data: testData, targets: testLabels,
#                              input_keep_prob: testInputKeepProbability,
#                              batch_size_ph: testBatchSize})
# logging
What am I doing wrong that causes this unexpected behavior?
Edit: see the next link for an image of what an input sample looks like.
The problem also persists with only 1 layer. Edit: I made an example that reproduces the problem. Just run the Python script with the path to test_samples.npy as the first argument and the path to a checkpoint as the second.
Could you reproduce the problem with a small, working version of your LSTM code and some test data that you are able to share? That would help us debug the issue more effectively. –
It is already quite short – Julius
Sorry for the slow response, I had mistakenly disabled my notifications. I added an image of the input to the question. I will also write some code that reads a checkpoint and runs the test on a few samples with and without dropout. –