2017-02-27 90 views
0

Tensor data type as a string?

I have the following TensorFlow code:

import datetime 
import matplotlib.pyplot as plt 
import numpy as np 
import os 
from PIL import Image 
import tensorflow as tf 

image_width = 202 
image_height = 180 
num_channels = 3 

filenames = tf.train.match_filenames_once("./train/Resized/*.jpg") 

def label(label_string): 
    if label_string == 'cat': label = [1,0] 
    if label_string == 'dog': label = [0,1] 

    return label 

def read_image(filename_queue): 
    image_reader = tf.WholeFileReader() 
    key, image_filename = image_reader.read(filename_queue) 
    image = tf.image.decode_jpeg(image_filename) 
    image.set_shape((image_height, image_width, 3)) 

    name = os.path.basename(image_filename) # example "dog.2148.jpg" 
    s = name.split('.') 
    label_string = s[0] 
    label = label(label_string) 

    return image, label 

def input_pipeline(filenames, batch_size, num_epochs=None): 
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True) 
    image, label = read_image(filename_queue) 
    min_after_dequeue = 1000 
    capacity = min_after_dequeue + 3 * batch_size 
    image_batch, label_batch = tf.train.shuffle_batch(
     [image, label], batch_size=batch_size, capacity=capacity, 
     min_after_dequeue=min_after_dequeue) 
    return image_batch, label_batch 

image_batch, label_batch = input_pipeline(filenames, 10) 

The last statement fails with the following error:

--------------------------------------------------------------------------- 
TypeError         Traceback (most recent call last) 
<ipython-input-21-0224ec735c33> in <module>() 
----> 1 image_batch, label_batch = input_pipeline(filenames, 10) 

<ipython-input-20-277e29dc1ae3> in input_pipeline(filenames, batch_size, num_epochs) 
     1 def input_pipeline(filenames, batch_size, num_epochs=None): 
     2  filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True) 
----> 3  image, label = read_image(filename_queue) 
     4  min_after_dequeue = 1000 
     5  capacity = min_after_dequeue + 3 * batch_size 

<ipython-input-19-ffe4ec8c3e25> in read_image(filename_queue) 
     5  image.set_shape((image_height, image_width, 3)) 
     6 
----> 7  name = os.path.basename(image_filename) # example "dog.2148.jpg" 
     8  s = name.split('.') 
     9  label_string = s[0] 

C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in basename(p) 
    230 def basename(p): 
    231  """Returns the final component of a pathname""" 
--> 232  return split(p)[1] 
    233 
    234 

C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in split(p) 
    202 
    203  seps = _get_bothseps(p) 
--> 204  d, p = splitdrive(p) 
    205  # set i to index beyond p's last slash 
    206  i = len(p) 

C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in splitdrive(p) 
    137 
    138  """ 
--> 139  if len(p) >= 2: 
    140   if isinstance(p, bytes): 
    141    sep = b'\\' 

TypeError: object of type 'Tensor' has no len() 

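I think the problem is related to the difference between the Tensor data type and the string data type. How do I correctly tell os.path.basename that image_filename is a string?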

Answer

0

The problem is that match_filenames_once returns

A variable that is initialized to the list of files matching pattern.

(see: https://www.tensorflow.org/api_docs/python/tf/train/match_filenames_once).

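os.path.basename and string.split are functions that work on Python strings, not on tensors. Both values returned by image_reader.read are tensors as well, so neither of them can be passed to these functions directly.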
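To make the distinction concrete, here is a minimal sketch in plain Python (no TensorFlow needed). The path is modelled on the "dog.2148.jpg" example from the question, and FakeTensor is a hypothetical stand-in for any object that is not a string; it is not a TensorFlow class.

```python
import os

# An ordinary Python string: both functions work as expected.
path = "./train/Resized/dog.2148.jpg"
name = os.path.basename(path)        # "dog.2148.jpg"
label_string = name.split(".")[0]    # "dog"


class FakeTensor:
    """Hypothetical stand-in for a non-string object such as a tensor."""


# Passing a non-string object raises TypeError, as in the traceback above.
try:
    os.path.basename(FakeTensor())
    failed = False
except TypeError:
    failed = True

print(name, label_string, failed)
```

The exact TypeError message depends on the Python version and platform, but the cause is the same: the function receives an object that is not a string.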

What I suggest is to load the images outside the tensor pipeline, which I think makes labeling easier.
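One way to read this suggestion is to build the file list and the labels in plain Python before anything is handed to TensorFlow. The sketch below assumes the "<class>.<id>.jpg" naming scheme from the question; label_for, files_and_labels and list_training_files are hypothetical helper names, not part of any library.

```python
import glob
import os

# One-hot labels as in the question's label() function.
ONE_HOT = {"cat": [1, 0], "dog": [0, 1]}


def label_for(path):
    """Return the one-hot label encoded in a name like 'dog.2148.jpg'.

    Named label_for rather than label so that a local variable called
    `label` cannot shadow the function.
    """
    name = os.path.basename(path)        # e.g. "dog.2148.jpg"
    class_name = name.split(".")[0]      # e.g. "dog"
    return ONE_HOT[class_name]           # KeyError for an unknown class


def files_and_labels(paths):
    """Return the sorted paths and a parallel list of labels."""
    paths = sorted(paths)
    return paths, [label_for(p) for p in paths]


def list_training_files(pattern="./train/Resized/*.jpg"):
    """Collect matching files with glob instead of match_filenames_once."""
    return files_and_labels(glob.glob(pattern))
```

The two parallel lists are ordinary Python objects, so they can be inspected or filtered before training. In the TF 1.x queue API they could then be passed to something like tf.train.slice_input_producer so that each file name stays paired with its label; that step is an assumption about how the pipeline would continue and is not shown in the original answer.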

+0

OK, but surely it must also be possible to do some simple processing on that kind of variable inside the TensorFlow pipeline? – OlavT