2016-12-25

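I have already coded ANN classifiers with keras, and now I am teaching myself to write RNNs in keras for text and time-series prediction. After searching the web for a while, I found a tutorial by Jason Brownlee that is a good fit for RNN beginners. The original article uses the IMDb dataset for LSTM text classification, but because that dataset is large, I switched to a small SMS spam detection dataset. How can I use a keras RNN for text classification on this dataset?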

# LSTM with dropout for sequence classification in the IMDB dataset 
import numpy 
from keras.datasets import imdb 
from keras.models import Sequential 
from keras.layers import Dense 
from keras.layers import LSTM 
from keras.layers.embeddings import Embedding 
from keras.preprocessing import sequence 
import pandas as pd 
from sklearn.cross_validation import train_test_split 

# fix random seed for reproducibility 
numpy.random.seed(7) 

url = 'https://raw.githubusercontent.com/justmarkham/pydata-dc-2016-tutorial/master/sms.tsv' 
sms = pd.read_table(url, header=None, names=['label', 'message']) 

# convert label to a numerical variable 
sms['label_num'] = sms.label.map({'ham':0, 'spam':1}) 
X = sms.message 
y = sms.label_num 
print(X.shape) 
print(y.shape) 

# load the dataset 
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1) 
top_words = 5000 

# truncate and pad input sequences 
max_review_length = 500 
X_train = sequence.pad_sequences(X_train, maxlen=max_review_length) 
X_test = sequence.pad_sequences(X_test, maxlen=max_review_length) 

# create the model 
embedding_vecor_length = 32 
model = Sequential() 
model.add(Embedding(top_words, embedding_vecor_length, input_length=max_review_length, dropout=0.2)) 
model.add(LSTM(100, dropout_W=0.2, dropout_U=0.2)) 
model.add(Dense(1, activation='sigmoid')) 
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) 
print(model.summary()) 
model.fit(X_train, y_train, nb_epoch=3, batch_size=64) 

# Final evaluation of the model 
scores = model.evaluate(X_test, y_test, verbose=0) 
print("Accuracy: %.2f%%" % (scores[1]*100)) 

I have successfully split the dataset into training and test sets, but how should I now build my RNN model for this dataset?

Answer

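Before training a neural network model, you need to represent the raw text data as numeric vectors. To do this, you can use the CountVectorizer or TfidfVectorizer provided by scikit-learn. Once the raw text has been converted into a numeric vector representation, you can train an RNN/LSTM/CNN on the text classification problem.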
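To make the "raw text to numeric vector" step concrete, here is a minimal pure-Python sketch. It does not use the scikit-learn vectorizers named above. Instead it illustrates a related encoding, integer word-index sequences, which is the input form an Embedding layer like the one in the question expects. The helper names `build_vocab` and `encode`, the whitespace tokenization, the reserved ids 0 (padding) and 1 (unknown word), and the three sample messages are all assumptions made for illustration.

```python
from collections import Counter


def build_vocab(texts, top_words=5000):
    """Map the most frequent words to integer ids.

    Ids 0 and 1 are reserved (0 = padding, 1 = unknown word), so at most
    top_words - 2 real words are kept and every id stays below top_words.
    """
    counts = Counter(word for text in texts for word in text.lower().split())
    most_common = counts.most_common(top_words - 2)
    return {word: index + 2 for index, (word, _) in enumerate(most_common)}


def encode(texts, vocab, maxlen):
    """Turn each text into a list of exactly maxlen integer ids (maxlen >= 1).

    Short texts are left-padded with 0 and long texts keep only their last
    maxlen words, which mirrors the default 'pre' padding and truncation
    of keras pad_sequences.
    """
    encoded = []
    for text in texts:
        ids = [vocab.get(word, 1) for word in text.lower().split()]
        ids = ids[-maxlen:]
        encoded.append([0] * (maxlen - len(ids)) + ids)
    return encoded


# Hypothetical sample messages standing in for the SMS data.
messages = ["Free prize call now", "are you free now", "call me later"]
vocab = build_vocab(messages, top_words=10)
X = encode(messages, vocab, maxlen=5)
print(X)
```

In the question's script, the vocabulary would presumably be built from the training messages only and then applied to both `X_train` and `X_test`, so that words seen only in the test set map to the unknown id.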