
Tensorflow: how to save the model at the steps I need during training

Here is my problem: I want my code to save the model every 100 steps. TRAIN_STEPS is 3000, so there should be close to 30 saved models, but in the end only 5 models were kept in the checkpoint. The details are:

model_checkpoint_path: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2900" 
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2500" 
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2600" 
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2700" 
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2800" 
all_model_checkpoint_paths: "/home/vrview/tensorflow/example/char/tfrecords/color/model.ckpt-2900" 

Only those 5 models were saved, and I don't know why. Can someone tell me? Here is my code:

# coding=utf-8 
from color_1 import read_and_decode, get_batch, get_test_batch 
import color_inference 
import cv2 
import os 
import time 
import numpy as np 
import tensorflow as tf 

batch_size=128 
TRAIN_STEPS=3000 
crop_size=56 
MOVING_AVERAGE_DECAY=0.99 
num_examples=50000 
LEARNING_RATE_BASE=0.8 
LEARNING_RATE_DECAY=0.99 
MODEL_SAVE_PATH="/home/vrview/tensorflow/example/char/tfrecords/color/" 
MODEL_NAME="model.ckpt" 

def train(batch_x,batch_y): 
    image_holder = tf.placeholder(tf.float32, [batch_size, 56, 56, 3], name='x-input') 
    label_holder = tf.placeholder(tf.int32, [batch_size], name='y-input') 
    image_input = tf.reshape(image_holder, [-1, 56, 56, 3]) 

    y=color_inference.inference(image_holder) 
    global_step=tf.Variable(0,trainable=False) 

    def loss(logits, labels):
        labels = tf.cast(labels, tf.int64)
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels, name='cross_entropy_per_example')

        cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
        tf.add_to_collection('losses', cross_entropy_mean)
        return tf.add_n(tf.get_collection('losses'), name='total_loss')

    loss = loss(y, label_holder) 
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss) 

    saver=tf.train.Saver() 
    init = tf.global_variables_initializer() 
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(TRAIN_STEPS):
            image_batch, label_batch = sess.run([batch_x, batch_y])
            _, loss_value, step = sess.run([train_op, loss, global_step],
                                           feed_dict={image_holder: image_batch,
                                                      label_holder: label_batch})
            if i % 100 == 0:
                format_str = ('After %d steps, loss on training batch is %.2f')
                print(format_str % (i, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=i)
        coord.request_stop()
        coord.join(threads)
def main(argv=None): 
    image, label = read_and_decode('train.tfrecords') 
    batch_image, batch_label = get_batch(image, label, batch_size, crop_size)  # generate training batches
    train(batch_image,batch_label) 
if __name__=='__main__': 
    tf.app.run() 

Answer


Add max_to_keep=30 to your Saver's constructor. By default its value is 5, which is why only 5 checkpoints are being kept.
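For example, a minimal sketch of the one-line change inside the question's train() function (everything else stays as-is):

    saver = tf.train.Saver(max_to_keep=30)  # keep up to 30 checkpoints instead of the default 5

If you want to keep every checkpoint on disk, passing max_to_keep=None also works.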


Copy that. Thank you very much (^_^)
