0
是否可以使用Keras's scikit-learn API和fit_generator()
方法?或者用另一種方式產生批次進行培訓?我使用SciPy的稀疏矩陣,在輸入到Keras之前必須將其轉換爲NumPy數組,但由於高內存消耗,我無法同時轉換它們。這裏是我的功能以便產生批次:keras/scikit-learn:使用fit_generator()進行交叉驗證
def batch_generator(X, y, batch_size):
n_splits = len(X) // (batch_size - 1)
X = np.array_split(X, n_splits)
y = np.array_split(y, n_splits)
while True:
for i in range(len(X)):
X_batch = []
y_batch = []
for ii in range(len(X[i])):
X_batch.append(X[i][ii].toarray().astype(np.int8)) # conversion sparse matrix -> np.array
y_batch.append(y[i][ii])
yield (np.array(X_batch), np.array(y_batch))
和示例代碼交叉驗證:
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn import datasets
from keras.models import Sequential
from keras.layers import Activation, Dense
from keras.wrappers.scikit_learn import KerasClassifier
import numpy as np
def build_model(n_hidden=32):
model = Sequential([
Dense(n_hidden, input_dim=4),
Activation("relu"),
Dense(n_hidden),
Activation("relu"),
Dense(3),
Activation("sigmoid")
])
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
return model
iris = datasets.load_iris()
X = iris["data"]
y = iris["target"].flatten()
param_grid = {
"n_hidden": np.array([4, 8, 16]),
"nb_epoch": np.array(range(50, 61, 5))
}
model = KerasClassifier(build_fn=build_model, verbose=0)
skf = StratifiedKFold(n_splits=5).split(X, y) # this yields (train_indices, test_indices)
grid = GridSearchCV(model, param_grid, cv=skf, verbose=2, n_jobs=4)
grid.fit(X, y)
print(grid.best_score_)
print(grid.cv_results_["params"][grid.best_index_])
要更多解釋,它使用的超參數所有可能的組合在param_grid
建立一個模型。然後在StratifiedKFold
提供的列車測試數據拆分(摺疊)上對每個模型進行逐一的訓練和測試。然後給定模型的最終得分是所有摺疊的平均得分。
因此,在實際擬合之前,是否可以在上面的代碼中插入一些預處理子步驟來轉換數據(稀疏矩陣)?
我知道我可以編寫自己的交叉驗證生成器,但它必須產生索引,而不是真正的數據!
看起來不錯 - 修改Keras的源代碼也出現在我的腦海中,但我想避免這種情況。謝謝,我會試試:) – jirinovo
所以我修改了一下你的代碼,它工作正常。你有一些想法如何在這裏使用Keras的回調函數[EarlyStopping](https://keras.io/callbacks/#earlystopping)嗎? – jirinovo
當然。不適合評論,但這裏是我的分類器爲Keras和XGB提早停止。曾經工作過但注意不全面測試! https://github.com/simonm3/analysis/blob/master/analysis/classifiers.py – simon