2017-01-09 78 views
0

我正在關注來自_tensorflow.org的this教程。 我試圖正確處理input_fn_,作爲參數.fit()。 我創建了分類:如何創建一個輸入函數,input_fn()

classifier = tf.contrib.learn.SKCompat(tf.contrib.learn.DNNClassifier(
feature_columns=feature_cols, 
hidden_units=[10, 10], 
model_dir=("C:\\........\tmp"), 
n_classes=2, 
activation_fn=tf.sigmoid, 
optimizer=tf.train.ProximalAdagradOptimizer(
    learning_rate=0.1, 
    l1_regularization_strength=0.001 
    ))) 

然後輸入功能:

def input_fn(data_set): 
    feature_cols = {k: tf.constant(data_set[k].values) 
        for k in FEATURES} 
    labels = tf.constant(data_set[LABEL].values) 
    return feature_cols, labels 

最後我已經把input_fn()配合()由:

classifier.fit(input_fn=lambda: input_fn(training_set), steps=10) 

當我運行代碼時,出現此錯誤:

TypeError         Traceback (most recent call last) 
<ipython-input-6-938bcd2f929f> in <module>() 
----> 1 classifier.fit(input_fn=lambda: input_fn(training_set), steps=10) 

TypeError: fit() got an unexpected keyword argument 'input_fn' 

我不知道這是否是對input_fn定義或適合參數

回答

0

,如果你想使用input_fn,取代第一線與不使用SKCompat:

classifier = tf.contrib.learn.DNNClassifier(

並根據需要調整括號。