keras/scikit-learn: cross-validation with fit_generator()

2016-11-28

Is it possible to use Keras's scikit-learn API together with the fit_generator() method, or some other way of producing batches for training? I am working with SciPy sparse matrices that have to be converted to dense NumPy arrays before they are fed to Keras, but I cannot convert them all at once because of the memory consumption. Here is my function for generating batches:

def batch_generator(X, y, batch_size):
    n_splits = len(X) // (batch_size - 1)
    X = np.array_split(X, n_splits)
    y = np.array_split(y, n_splits)

    while True:
        for i in range(len(X)):
            X_batch = []
            y_batch = []
            for ii in range(len(X[i])):
                X_batch.append(X[i][ii].toarray().astype(np.int8))  # convert sparse matrix -> np.array
                y_batch.append(y[i][ii])
            yield (np.array(X_batch), np.array(y_batch))
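For reference, a minimal sketch (my assumption, not from the post) of how such a generator would be fed directly to Keras 1.x, assuming model is an already compiled Keras model and X_train is an indexable collection of sparse rows, as the per-element .toarray() call above implies:

# hypothetical names: model, X_train, y_train are assumptions for illustration
model.fit_generator(
    batch_generator(X_train, y_train, batch_size=32),
    samples_per_epoch=len(X_train),  # Keras 1.x name; steps_per_epoch in Keras 2
    nb_epoch=10)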

And here is the example code for the cross-validation:

from sklearn.model_selection import StratifiedKFold, GridSearchCV 
from sklearn import datasets 

from keras.models import Sequential 
from keras.layers import Activation, Dense 
from keras.wrappers.scikit_learn import KerasClassifier 

import numpy as np 


def build_model(n_hidden=32):
    model = Sequential([
        Dense(n_hidden, input_dim=4),
        Activation("relu"),
        Dense(n_hidden),
        Activation("relu"),
        Dense(3),
        Activation("sigmoid")
    ])
    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model


iris = datasets.load_iris() 
X = iris["data"] 
y = iris["target"].flatten() 

param_grid = { 
    "n_hidden": np.array([4, 8, 16]), 
    "nb_epoch": np.array(range(50, 61, 5)) 
} 

model = KerasClassifier(build_fn=build_model, verbose=0) 
skf = StratifiedKFold(n_splits=5).split(X, y) # this yields (train_indices, test_indices) 

grid = GridSearchCV(model, param_grid, cv=skf, verbose=2, n_jobs=4) 
grid.fit(X, y) 

print(grid.best_score_) 
print(grid.cv_results_["params"][grid.best_index_]) 

To explain a bit more: it builds one model for every possible combination of the hyperparameters in param_grid. Each model is then trained and tested, one by one, on the train/test splits (folds) provided by StratifiedKFold, and the final score of a given model is the mean of its scores across all folds. With the grid above (3 values of n_hidden × 3 values of nb_epoch) and 5 folds, that amounts to 45 fits.
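Roughly, a sketch of what the search does under the hood (an illustration only, not GridSearchCV's actual implementation; it reuses X, y, model and param_grid from the example above):

from sklearn.base import clone
from sklearn.model_selection import ParameterGrid, StratifiedKFold

for params in ParameterGrid(param_grid):
    fold_scores = []
    for train_idx, test_idx in StratifiedKFold(n_splits=5).split(X, y):
        clf = clone(model).set_params(**params)   # fresh estimator per fold
        clf.fit(X[train_idx], y[train_idx])
        fold_scores.append(clf.score(X[test_idx], y[test_idx]))
    print(params, np.mean(fold_scores))           # mean score over the 5 folds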

So, is it possible to insert some preprocessing sub-step into the code above that converts the data (the sparse matrices) right before the actual fitting?

I know I could write my own cross-validation generator, but it has to yield indices, not the actual data!
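For illustration, a minimal sketch (my assumption, not code from the post) of what such a custom cross-validation generator looks like; it can only hand back index arrays, so there is no place in it to densify the sparse rows:

from sklearn.model_selection import StratifiedKFold

def custom_cv(X, y, n_splits=5):
    # hypothetical helper: like StratifiedKFold.split(), it may only yield
    # (train_indices, test_indices) pairs -- not converted batches of data
    for train_idx, test_idx in StratifiedKFold(n_splits=n_splits).split(X, y):
        yield train_idx, test_idx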

Answer

Actually, you can use sparse matrices as input to Keras with a generator. Here is a version from one of my previous projects:

import copy
import types

import numpy as np
from scipy.sparse import issparse

from keras.models import Sequential
from keras.utils.np_utils import to_categorical
from keras.wrappers.scikit_learn import KerasClassifier


class KerasClassifier(KerasClassifier):  # shadows the Keras wrapper so it works as a drop-in replacement
    """Adds sparse matrix handling using a batch generator."""

    def fit(self, x, y, **kwargs):
        """Adds sparse matrix handling."""
        if not issparse(x):
            return super().fit(x, y, **kwargs)

        ############ adapted from KerasClassifier.fit ######################
        if self.build_fn is None:
            self.model = self.__call__(**self.filter_sk_params(self.__call__))
        elif not isinstance(self.build_fn, types.FunctionType):
            self.model = self.build_fn(
                **self.filter_sk_params(self.build_fn.__call__))
        else:
            self.model = self.build_fn(**self.filter_sk_params(self.build_fn))

        loss_name = self.model.loss
        if hasattr(loss_name, '__name__'):
            loss_name = loss_name.__name__
        if loss_name == 'categorical_crossentropy' and len(y.shape) != 2:
            y = to_categorical(y)
        ### fit => fit_generator
        fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit_generator))
        fit_args.update(kwargs)
        ####################################################################
        self.model.fit_generator(
            self.get_batch(x, y, self.sk_params["batch_size"]),
            samples_per_epoch=x.shape[0],
            **fit_args)
        return self

    def get_batch(self, x, y=None, batch_size=32):
        """Batch generator to enable sparse input."""
        index = np.arange(x.shape[0])
        start = 0
        while True:
            if start == 0 and y is not None:
                np.random.shuffle(index)  # reshuffle at the start of each epoch
            batch = index[start:start + batch_size]
            if y is not None:
                yield x[batch].toarray(), y[batch]
            else:
                yield x[batch].toarray()
            start += batch_size
            if start >= x.shape[0]:
                start = 0

    def predict_proba(self, x):
        """Adds sparse matrix handling."""
        if not issparse(x):
            return super().predict_proba(x)

        preds = self.model.predict_generator(
            self.get_batch(x, None, self.sk_params["batch_size"]),
            val_samples=x.shape[0])
        return preds
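A minimal usage sketch (my assumption, not part of the answer): X_sparse stands for a SciPy CSR matrix and build_model for the function from the question. batch_size has to be passed to the constructor because fit() reads it from sk_params, and nb_epoch is forwarded on to fit_generator (Keras 1.x argument names):

from scipy.sparse import csr_matrix

X_sparse = csr_matrix(X)  # using the iris features from the question just for illustration

clf = KerasClassifier(build_fn=build_model, batch_size=32, nb_epoch=10, verbose=0)
clf.fit(X_sparse, y)                  # routed through fit_generator because the input is sparse
probs = clf.predict_proba(X_sparse)   # predictions are batched through the same generator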

Looks good - modifying Keras's source code had crossed my mind too, but I wanted to avoid that. Thanks, I'll give it a try :) – jirinovo


So I modified your code a bit and it works fine. Do you have any idea how to use Keras's [EarlyStopping](https://keras.io/callbacks/#earlystopping) callback here? – jirinovo


Sure. It doesn't fit in a comment, but here are my classifiers with early stopping for Keras and XGB. They worked at one point, but note they are not fully tested! https://github.com/simonm3/analysis/blob/master/analysis/classifiers.py – simon
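One plausible way to wire in EarlyStopping, sketched from the answer's code above rather than the linked classifiers: since fit_args is built from Sequential.fit_generator's signature and then updated with any keyword arguments passed to fit(), a callbacks list passed to fit() should reach fit_generator. The training loss is monitored here because no validation data is plumbed through:

from keras.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor="loss", patience=2)  # "loss": no validation set reaches fit_generator
clf = KerasClassifier(build_fn=build_model, batch_size=32, nb_epoch=50, verbose=0)
clf.fit(X_sparse, y, callbacks=[early_stop])  # kwargs end up in fit_args and are forwarded to fit_generator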