0
我正在關注來自_tensorflow.org的this教程。 我試圖正確處理input_fn_,作爲參數.fit()。 我創建了分類:如何創建一個輸入函數,input_fn()
classifier = tf.contrib.learn.SKCompat(tf.contrib.learn.DNNClassifier(
feature_columns=feature_cols,
hidden_units=[10, 10],
model_dir=("C:\\........\tmp"),
n_classes=2,
activation_fn=tf.sigmoid,
optimizer=tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001
)))
然後輸入功能:
def input_fn(data_set):
feature_cols = {k: tf.constant(data_set[k].values)
for k in FEATURES}
labels = tf.constant(data_set[LABEL].values)
return feature_cols, labels
最後我已經把input_fn()在配合()由:
classifier.fit(input_fn=lambda: input_fn(training_set), steps=10)
當我運行代碼時,出現此錯誤:
TypeError Traceback (most recent call last)
<ipython-input-6-938bcd2f929f> in <module>()
----> 1 classifier.fit(input_fn=lambda: input_fn(training_set), steps=10)
TypeError: fit() got an unexpected keyword argument 'input_fn'
我不知道這是否是對input_fn定義或適合參數