2

我打算在tflearn模型的hyperparams上執行網格搜索。看來,由tflearn.DNN產生的模式是不與sklearn的GridSearchCV預期兼容:無法使用sklearn的GridSearchCV運行tflearn

from sklearn.grid_search import GridSearchCV 
import tflearn 
import tflearn.datasets.mnist as mnist 
import numpy as np 

X, Y, testX, testY = mnist.load_data(one_hot=True) 

encoder = tflearn.input_data(shape=[None, 784]) 
encoder = tflearn.fully_connected(encoder, 256) 
encoder = tflearn.fully_connected(encoder, 64) 

# Building the decoder 
decoder = tflearn.fully_connected(encoder, 256) 
decoder = tflearn.fully_connected(decoder, 784) 

# Regression, with mean square error 
net = tflearn.regression(decoder, optimizer='adam', learning_rate=0.01, 
         loss='mean_square', metric=None) 

model = tflearn.DNN(net, tensorboard_verbose=0) 

grid_hyperparams = {'optimizer': ['adam', 'sgd', 'rmsprop'], 'learning_rate': np.logspace(-4, -1, 4)} 
grid = GridSearchCV(model, param_grid=grid_hyperparams, scoring='mean_squared_error', cv=2) 
grid.fit(X, X) 

我得到的錯誤:

TypeError         Traceback (most recent call last) 
<ipython-input-3-fd63245cd0a3> in <module>() 
    22 grid_hyperparams = {'optimizer': ['adam', 'sgd', 'rmsprop'], 'learning_rate': np.logspace(-4, -1, 4)} 
    23 grid = GridSearchCV(model, param_grid=grid_hyperparams, scoring='mean_squared_error', cv=2) 
---> 24 grid.fit(X, X) 
    25 
    26 

/home/deeplearning/anaconda3/lib/python3.5/site-packages/sklearn/grid_search.py in fit(self, X, y) 
    802 
    803   """ 
--> 804   return self._fit(X, y, ParameterGrid(self.param_grid)) 
    805 
    806 

/home/deeplearning/anaconda3/lib/python3.5/site-packages/sklearn/grid_search.py in _fit(self, X, y, parameter_iterable) 
    539           n_candidates * len(cv))) 
    540 
--> 541   base_estimator = clone(self.estimator) 
    542 
    543   pre_dispatch = self.pre_dispatch 

/home/deeplearning/anaconda3/lib/python3.5/site-packages/sklearn/base.py in clone(estimator, safe) 
    45        "it does not seem to be a scikit-learn estimator " 
    46        "as it does not implement a 'get_params' methods." 
---> 47        % (repr(estimator), type(estimator))) 
    48  klass = estimator.__class__ 
    49  new_object_params = estimator.get_params(deep=False) 

TypeError: Cannot clone object '<tflearn.models.dnn.DNN object at 0x7fead09948d0>' (type <class 'tflearn.models.dnn.DNN'>): it does not seem to be a scikit-learn estimator as it does not implement a 'get_params' methods. 

任何想法如何,我能得到適合GridSearchCV對象?

回答

0

我對tflearn沒有經驗,但是我確實有一些Python和sklearn的基本背景。從您的StackOverflow屏幕截圖中的錯誤判斷,tflearn ** models **與scikit-learn估計器沒有相同的方法或屬性。這是可以理解的,因爲他們不是,好吧,scikit學習估計。

Sklearn的網格搜索簡歷只適用於與scikit-learn估算器具有相同方法和屬性的對象(例如fit()和predict()方法)。如果你打算使用sklearn的網格搜索,你將不得不圍繞tflearn模型編寫你自己的包裝,以使它作爲sklearn估算器的替代品,這意味着你必須編寫自己的類,它具有相同的方法與其他scikit-learn評估器一樣,但是使用tflearn庫來實際實現這些方法。爲了做到這一點,理解一個基本的scikit-learn估算器(最好你知道的一個)的代碼,並且看看這些方法適用於什麼(),predict(),get_params()等實際上對該對象做了什麼以及它的內部。然後使用tflearn庫編寫你自己的類。

要開始,快速谷歌搜索顯示此存儲庫是「tensorflow框架的薄scikit學習風格包裝」:DSLituiev/tflearn(https://github.com/DSLituiev/tflearn)。我不知道這是否可以替代Grid Search,但值得一看。