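Or you could just use the updated Estimator API of TensorFlow r1.1 to define the model. The API and parameters are quite similar; only a few return types and function names have changed. Here is an example I use: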
```python
import tensorflow as tf


def model_fn():
    def _build_model(features, labels, mode, params):
        # 1. Configure the model via TensorFlow operations
        # Connect the first hidden layer to the input layer (features) with ReLU activation
        y = tf.contrib.layers.fully_connected(features, num_outputs=64, activation_fn=tf.nn.relu,
                                              weights_initializer=tf.contrib.layers.xavier_initializer())
        # Second hidden layer, also ReLU
        y = tf.contrib.layers.fully_connected(y, num_outputs=64, activation_fn=tf.nn.relu,
                                              weights_initializer=tf.contrib.layers.xavier_initializer())
        # Output layer: a single sigmoid unit, i.e. a probability in [0, 1]
        y = tf.contrib.layers.fully_connected(y, num_outputs=1, activation_fn=tf.nn.sigmoid,
                                              weights_initializer=tf.contrib.layers.xavier_initializer())
        predictions = y

        # 2. Define the loss function for training/evaluation
        if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
            loss = tf.reduce_mean((predictions - tf.cast(labels, tf.float32)) ** 2)
        else:
            loss = None

        if mode != tf.estimator.ModeKeys.PREDICT:
            float_labels = tf.cast(labels, tf.float32)
            # accuracy and precision compare classes, so threshold the probabilities first
            predicted_classes = tf.cast(predictions > 0.5, tf.float32)
            eval_metric_ops = {
                "rmse": tf.metrics.root_mean_squared_error(float_labels, predictions),
                "accuracy": tf.metrics.accuracy(float_labels, predicted_classes),
                "precision": tf.metrics.precision(float_labels, predicted_classes)
            }
        else:
            eval_metric_ops = None

        # 3. Define the training operation/optimizer
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = tf.contrib.layers.optimize_loss(
                loss=loss,
                global_step=tf.contrib.framework.get_global_step(),
                learning_rate=0.001,
                optimizer="Adam")
        else:
            train_op = None

        # 4. Expose the predictions under a named key in PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            predictions_dict = {"pred": predictions}
        else:
            predictions_dict = None

        # 5. Return predictions/loss/train_op/eval_metric_ops in an EstimatorSpec object
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions_dict,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)
    return _build_model
```
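A note on the evaluation metrics: the output layer produces probabilities, but `tf.metrics.accuracy` counts exact matches between labels and predictions, so raw sigmoid outputs almost never score as correct. As far as I know, `tf.metrics.precision` in TF 1.x casts its predictions to bool, so any nonzero probability would count as a positive. RMSE, by contrast, is meaningful on the probabilities themselves. The plain-Python sketch below (no TensorFlow; `eval_metrics` and the 0.5 threshold are hypothetical choices) computes the three values by hand on one toy batch. The real `tf.metrics` ops return `(value, update_op)` pairs and accumulate over all evaluation batches, which this single-batch sketch does not model.

```python
import math


def eval_metrics(labels, probs, threshold=0.5):
    """Hypothetical helper: hand-computed stand-ins for the rmse/accuracy/precision
    metrics, for 0/1 labels and sigmoid probabilities in [0, 1]."""
    n = len(labels)
    # RMSE is computed directly on the probabilities, like the squared-error loss.
    rmse = math.sqrt(sum((p - y) ** 2 for p, y in zip(probs, labels)) / n)
    # Exact-match accuracy on raw probabilities: they are rarely exactly 0 or 1.
    raw_accuracy = sum(p == y for p, y in zip(probs, labels)) / n
    # Threshold the probabilities into hard 0/1 class predictions.
    classes = [1 if p > threshold else 0 for p in probs]
    accuracy = sum(c == y for c, y in zip(classes, labels)) / n
    true_pos = sum(1 for c, y in zip(classes, labels) if c == 1 and y == 1)
    pred_pos = sum(classes)
    precision = true_pos / pred_pos if pred_pos else 0.0
    return {"rmse": rmse, "raw_accuracy": raw_accuracy,
            "accuracy": accuracy, "precision": precision}


print(eval_metrics([1, 0, 1, 0], [0.9, 0.2, 0.6, 0.7]))
```

On this batch the raw-probability accuracy is 0 even though three of the four thresholded predictions are correct.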
You can then use this model like this:

```python
e = tf.estimator.Estimator(model_fn=model_fn(), params=None)
e.train(input_fn=input_fn(), steps=1000)
```
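Note that the Estimator is given `model_fn()`, not `model_fn`. The outer function is a factory, and what the Estimator actually calls is the inner `_build_model` with its `(features, labels, mode, params)` signature. Passing `model_fn` itself would hand it a zero-argument function, which it would presumably reject. The sketch below shows the same closure pattern in plain Python. `make_model_fn` and its hyperparameters are hypothetical, and the stub returns a dict where the real function would build the graph and return a `tf.estimator.EstimatorSpec`.

```python
import inspect


def make_model_fn(hidden_units=64, learning_rate=0.001):
    """Hypothetical factory mirroring model_fn() above.

    The outer call captures hyperparameters in a closure; the inner function
    has the (features, labels, mode, params) signature the Estimator calls.
    """
    def _build_model(features, labels, mode, params):
        # A real implementation would build the TensorFlow graph here and return
        # a tf.estimator.EstimatorSpec; this stub only echoes the captured values.
        return {"mode": mode, "hidden_units": hidden_units,
                "learning_rate": learning_rate}
    return _build_model


fn = make_model_fn(hidden_units=128)                   # call the factory once
print(list(inspect.signature(fn).parameters))          # ['features', 'labels', 'mode', 'params']
print(fn(None, None, "train", None)["hidden_units"])   # 128
```

One benefit of this design is that hyperparameters can live in the closure instead of being threaded through `params`, which is why `params=None` is enough here.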
An example input function for TensorFlow r1.1 can be found in my answer here.
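The linked answer is not reproduced here, but the call above implies a contract. In r1.1, `train` calls `input_fn` with no arguments and expects a `(features, labels)` pair. Because the answer writes `input_fn()`, that function presumably returns such a callable, just as `model_fn()` does. One shape detail is worth checking in whatever input function you use. The last layer has `num_outputs=1`, so predictions have shape `[batch, 1]`. If labels have shape `[batch]`, the loss term `predictions - labels` silently broadcasts to `[batch, batch]` instead of raising an error. TensorFlow follows NumPy's broadcasting rules, so the sketch below uses NumPy arrays as stand-ins for tensors. `make_input_fn` is a hypothetical name, and a real r1.1 input function would return tensors.

```python
import numpy as np


def make_input_fn(x, y):
    """Hypothetical factory: returns a zero-argument callable producing
    (features, labels). NumPy arrays stand in for TensorFlow tensors."""
    def _input_fn():
        features = np.asarray(x, dtype=np.float32)               # [batch, n_features]
        labels = np.asarray(y, dtype=np.float32).reshape(-1, 1)  # [batch, 1]
        return features, labels
    return _input_fn


features, labels = make_input_fn([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], [0, 1, 1])()
predictions = np.full((3, 1), 0.5, dtype=np.float32)  # shape of the 1-unit sigmoid output
print((predictions - labels).shape)               # (3, 1): elementwise, as intended
print((predictions - labels.reshape(-1)).shape)   # (3, 3): silent broadcasting
```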
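I'm following the [cnn_mnist tutorial](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/layers/cnn_mnist.py) and have a similar problem. Based on the error message, I tried something like 'tensorflow.contrib.learn.SKCompat import SKCompat' and wrapped the Estimator in 'SKCompat()'. But it doesn't work... the error is: "No module named SKCompat". I need some help too! – user3768495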