2017-10-15 77 views
0

我想從tensorflow中使用tf.train.shuffle_batch函數,然後我需要先使用tf.image.decode_jpeg(或其他類似的函數來加載png和jpg)加載圖像。但是我發現圖像被加載爲概率圖,這意味着像素值的最大值爲1,像素值的最小值爲0.下面是我從github回購庫更新的代碼。我不知道爲什麼像素的值被歸一化爲[0,1],並且我沒有找到張量流的相關文檔。任何人都可以幫我嗎?謝謝。爲tf.image.decode_jpeg和tf.train.shuffle_batch規範化了圖像像素值?

def load_examples(self, input_dir, flip, scale_size, batch_size, min_queue_examples): 
    input_paths = get_image_paths(input_dir) 
    with tf.name_scope("load_images"): 
     path_queue = tf.train.string_input_producer(input_paths) 
     reader = tf.WholeFileReader() 
     paths, contents = reader.read(path_queue) 
     # note this is important for truncated images 
     raw_input = tf.image.decode_jpeg(contents,try_recover_truncated = True, acceptable_fraction=0.5) 
     raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32) 
     raw_input.set_shape([None, None, 3]) 

     # break apart image pair and move to range [-1, 1] 
     width = tf.shape(raw_input)[1] # [height, width, channels] 
     a_images = preprocess(raw_input[:, :width // 2, :]) 
     b_images = raw_input[:, width // 2:, :] 

    inputs, targets = [a_images, b_images] 

    def transform(image): 
     r = image 

     r = tf.image.resize_images(r, [self.image_height, self.image_width], method=tf.image.ResizeMethod.AREA) 
     return r 
    def transform_gaze(image): 
     r = image 
     r = tf.image.resize_images(r, [self.gaze_height, self.gaze_width], method=tf.image.ResizeMethod.AREA) 
     return r 
    with tf.name_scope("input_images"): 
     input_images = transform(inputs) 

    with tf.name_scope("target_images"): 
     target_images = transform(targets) 
    total_image_count = len(input_paths) 
    # target_images = tf.image.per_image_standardization(target_images) 
    target_images = target_images[:,:,0] 
    target_images = tf.expand_dims(target_images, 2) 
    inputs_batch, targets_batch = tf.train.shuffle_batch([input_images, target_images], 
             batch_size=batch_size, 
             num_threads=1, 
             capacity=min_queue_examples + 3 * batch_size, 
             min_after_dequeue=min_queue_examples) 
    # inputs_batch, targets_batch = tf.train.batch([input_images, target_images],batch_size=batch_size) 
    return inputs_batch, targets_batch, total_image_count 

回答

2

由於是tf.image.decode_*方法所做的值,因此值爲[0,1]。一般來說,當一個方法返回一個浮點張量時,它的值應該在[0,1]範圍內,而如果返回的張量是一個uint8,則該值應該在[0,255]範圍內。

此外,當您使用tf.image.convert_image_dtype方法來轉換輸入圖像的dtype時,您正在應用該轉換規則。

如果您的輸入圖像是一個uint8圖像,並將其轉換爲float32,則這些值將縮放到[0,1]範圍內。如果你的圖像已經是一個浮點數,那麼它的值應該在這個範圍內,而且什麼都不做。

+0

嗨我還有一個問題,我添加輸入數據的圖像摘要,就像這樣:tf.summary.image('training_truth',self.targets,4)它在我看來,在張量板,圖像顯示在[0,255]範圍內。那麼這是否意味着對我的模型的圖像批處理被標準化,而張量板可視化仍然是[0,255]?謝謝 –

+0

是的,圖像彙總檢查輸入類型。如果它是浮動的,那麼它會將這些值縮放到0.255範圍內,以便可視化 – nessuno

+0

太棒了,謝謝你的回答! –