2017-04-13 80 views
Generating text with a trained character-level LSTM model

I trained a character-level model to generate demonstration text. I feed the training examples as two sequences: X is a sequence of characters, and y is the same sequence shifted by one. The model is based on an LSTM and built with TensorFlow.

My question: since the model takes input sequences of a fixed size (50 in my case), how can I make a prediction when the seed is only a single character? I have seen examples where, after training, sentences are generated by simply feeding in one character at a time.

Here is my code:

with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [batch_size, truncated_backprop], name='x')
    y = tf.placeholder(tf.int32, [batch_size, truncated_backprop], name='y')

with tf.name_scope('weights'):
    W = tf.Variable(np.random.rand(n_hidden, num_classes), dtype=tf.float32)
    b = tf.Variable(np.random.rand(1, num_classes), dtype=tf.float32)

inputs_series = tf.split(x, truncated_backprop, 1)
labels_series = tf.unstack(y, axis=1)

with tf.name_scope('LSTM'):
    cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, state_is_tuple=True)
    cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=dropout)
    cell = tf.contrib.rnn.MultiRNNCell([cell] * n_layers)

states_series, current_state = tf.contrib.rnn.static_rnn(cell, inputs_series,
                                                         dtype=tf.float32)

logits_series = [tf.matmul(state, W) + b for state in states_series]
prediction_series = [tf.nn.softmax(logits) for logits in logits_series]

losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
          for logits, labels in zip(logits_series, labels_series)]
total_loss = tf.reduce_mean(losses)

train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

Answer

I suggest you use dynamic_rnn instead of static_rnn. It builds the recurrence at execution time, which lets you feed inputs of any length. Your input placeholder would be:

x = tf.placeholder(tf.float32, [batch_size, None, features], name='x') 
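To see what the `None` time dimension buys you, here is a plain-NumPy sketch (a hypothetical stand-in, not TensorFlow): the same function can accept both a length-50 training batch and a length-1 generation seed.

```python
import numpy as np

def forward(batch):
    """Stand-in for a graph built on a [batch, None, features] placeholder:
    the time dimension (axis 1) may be any length."""
    batch = np.asarray(batch, dtype=np.float32)
    assert batch.ndim == 3, "expected (batch, time, features)"
    return batch.shape[1]  # number of time steps actually fed

steps_train = forward(np.zeros((32, 50, 27)))  # full training sequences
steps_seed = forward(np.zeros((1, 1, 27)))     # a single-character seed
```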

Next, you need a way to feed your own initial state into the network. You can do that by passing the initial_state argument to dynamic_rnn, for example:

initialstate = cell.zero_state(batch_size, tf.float32) 
outputs, current_state = tf.nn.dynamic_rnn(cell, 
              inputs, 
              initial_state=initialstate) 
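Why passing the state back in works: feeding a sequence in chunks, threading the final state of one chunk in as the initial state of the next, gives the same result as feeding the whole sequence at once. A minimal pure-NumPy vanilla-RNN sketch of that recurrence (an illustration only, not the LSTM graph above):

```python
import numpy as np

def rnn_step(x, h, Wx, Wh, b):
    """One vanilla-RNN cell step: combine input x with previous state h."""
    return np.tanh(x @ Wx + h @ Wh + b)

def run(seq, h, Wx, Wh, b):
    """Run the cell over a sequence from initial state h; return final state."""
    for x in seq:
        h = rnn_step(x, h, Wx, Wh, b)
    return h

rng = np.random.default_rng(0)
features, hidden = 4, 8
Wx = rng.standard_normal((features, hidden)) * 0.1
Wh = rng.standard_normal((hidden, hidden)) * 0.1
b = np.zeros(hidden)
seq = rng.standard_normal((5, features))
zero = np.zeros(hidden)

whole = run(seq, zero, Wx, Wh, b)       # all 5 steps in one call
mid = run(seq[:3], zero, Wx, Wh, b)     # first 3 steps
chunked = run(seq[3:], mid, Wx, Wh, b)  # last 2 steps, resuming from mid
assert np.allclose(whole, chunked)      # identical final state
```

This equivalence is exactly what lets generation proceed one character at a time: each call is a one-step "chunk" whose initial state is the previous call's final state.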

With that, to generate text from a single character, you feed one character at a time, each time passing the previous state back in, for example:

prompt = 's' # beginning character, whatever 
inp = one_hot(prompt) # preprocessing, as you probably want to feed one-hot vectors 
state = None 
while True: 
    if state is None: 
     feed = {x: [[inp]]} 
    else: 
     feed = {x: [[inp]], initialstate: state} 

    out, state = sess.run([outputs, current_state], feed_dict=feed) 

    inp = process(out) # extract the predicted character from out and one-hot it 
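The `one_hot` and `process` helpers are not defined in the answer; here is a hedged sketch of what they might look like (the vocabulary and the sampling strategy are assumptions):

```python
import numpy as np

vocab = list("abcdefghijklmnopqrstuvwxyz .,")  # assumed character set
char_to_ix = {c: i for i, c in enumerate(vocab)}

def one_hot(char):
    """Encode one character as a one-hot vector over the vocabulary."""
    v = np.zeros(len(vocab), dtype=np.float32)
    v[char_to_ix[char]] = 1.0
    return v

def process(out):
    """Take the softmax output for the last time step, sample the next
    character from it, and re-encode it as one-hot for the next feed."""
    probs = np.asarray(out, dtype=np.float64).reshape(-1)[-len(vocab):]
    probs = probs / probs.sum()  # renormalize against float error
    ix = np.random.choice(len(probs), p=probs)
    return one_hot(vocab[ix])
```

Sampling from the predicted distribution (rather than always taking the argmax) usually gives less repetitive generated text.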

Thanks a lot. The dynamic RNN trick is very neat. It is much clearer now. – JimZer