2017-03-08 165 views
2

我正在測試mxnet的RNN模型。教程here不起作用,錯誤消息表示許多功能已被棄用。我沒有找到最新的RNN教程。 mxnet項目中還有一些示例。但是對於RNN,examples僅顯示如何使用訓練集來訓練模型。他們沒有展示如何使用訓練好的模型進行進一步的預測。訓練代碼如下:mxnet:如何使用訓練有素的RNN模型進行預測

model.fit(
    train_data   = data_train, 
    eval_data   = data_val, 
    eval_metric   = mx.metric.Perplexity(invalid_label), 
    kvstore    = args.kv_store, 
    optimizer   = args.optimizer, 
    optimizer_params = { 'learning_rate': args.lr, 
          'momentum': args.mom, 
          'wd': args.wd }, 
    initializer   = mx.init.Xavier(factor_type="in", magnitude=2.34), 
    num_epoch   = args.num_epochs, 
    batch_end_callback = mx.callback.Speedometer(args.batch_size, args.disp_batches)) 

是否有人知道如何使用的培訓RNN模型作出推斷或預測?

我必須明白,我正在尋找如何使用RNN模型作出預測,而不是CNN或其他模型。

非常感謝您的幫助!

+0

https://github.com/dmlc/mxnet/blob/master/example/rnn/cudnn_lstm_bucketing.py既有列車和測試代碼。這有幫助嗎? –

+2

不可以。但https://github.com/dmlc/mxnet/tree/master/python/mxnet/module中的示例確實有幫助。 – pfc

+1

@pfc如果你找到了答案,你會回答你自己的問題,可能需要相同的幫助嗎? – lynguyen

回答

1

通常模型是擴展BaseModel類。而BaseModel有the method predict。該方法可以使用與fit方法使用的相同類型:DataIter只有一個區別,它不需要train_data,只有eval_data。所以實際的預測過程可以以簡單的方式來實現這樣的:

result = mod.predict(dataiter.next) 
相關問題