4
我試圖使用TensorFlow複製完全卷積網絡結果。我用Marvin Teichmann's implementation from github。我只需要寫培訓包裝。我創建了兩個共享變量和兩個輸入隊列的圖形,一個用於訓練,一個用於驗證。爲了測試我的培訓包裝,我使用了兩個簡短的培訓和驗證文件列表,並且在每個培訓時期後立即進行驗證。我還從輸入隊列中打印出每個圖像的形狀,以檢查是否得到正確的輸入。但是,在我開始訓練後,似乎只有訓練隊列中的圖像正在排隊。因此,訓練和驗證圖都從訓練隊列中獲取輸入,並且驗證隊列從不被訪問。任何人都可以幫助解釋和解決這個問題?Tensorflow培訓和驗證輸入隊列分隔
下面是相關的代碼的一部分:
def get_data(image_name_list, num_epochs, scope_name, num_class = NUM_CLASS):
with tf.variable_scope(scope_name) as scope:
images_path = [os.path.join(DATASET_DIR, i+'.jpg') for i in image_name_list]
gts_path = [os.path.join(GT_DIR, i+'.png') for i in image_name_list]
seed = random.randint(0, 2147483647)
image_name_queue = tf.train.string_input_producer(images_path, num_epochs=num_epochs, shuffle=False, seed = seed)
gt_name_queue = tf.train.string_input_producer(gts_path, num_epochs=num_epochs, shuffle=False, seed = seed)
reader = tf.WholeFileReader()
image_key, image_value = reader.read(image_name_queue)
my_image = tf.image.decode_jpeg(image_value)
my_image = tf.cast(my_image, tf.float32)
my_image = tf.expand_dims(my_image, 0)
gt_key, gt_value = reader.read(gt_name_queue)
# gt stands for ground truth
my_gt = tf.cast(tf.image.decode_png(gt_value, channels = 1), tf.float32)
my_gt = tf.one_hot(tf.cast(my_gt, tf.int32), NUM_CLASS)
return my_image, my_gt
train_image, train_gt = get_data(train_files, NUM_EPOCH, 'training')
val_image, val_gt = get_data(val_files, NUM_EPOCH, 'validation')
with tf.variable_scope('FCN16') as scope:
train_vgg16_fcn = fcn16_vgg.FCN16VGG()
train_vgg16_fcn.build(train_image, train=True, num_classes=NUM_CLASS, keep_prob = KEEP_PROB)
scope.reuse_variables()
val_vgg16_fcn = fcn16_vgg.FCN16VGG()
val_vgg16_fcn.build(val_image, train=False, num_classes=NUM_CLASS, keep_prob = 1)
"""
Define the loss, evaluation metric, summary, saver in the computation graph. Initialize variables and start a session.
"""
for epoch in range(starting_epoch, NUM_EPOCH):
for i in range(train_num):
_, loss_value, shape = sess.run([train_op, train_entropy_loss, tf.shape(train_image)])
print shape
for i in range(val_num):
loss_value, shape = sess.run([val_entropy_loss, tf.shape(val_image)])
print shape
您是否找到答案? – thigi
我沒有一個好的答案,但建議在單獨的過程中運行評估。它更容易和更清潔。如果您不想這樣做,您可以創建兩個不同的圖表和會話,並將您的驗證輸入隊列與此關聯起來。 –