0
我有以下Tensorflow代碼:張量數據類型爲字符串?
import datetime
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
import tensorflow as tf
image_width = 202
image_height = 180
num_channels = 3
filenames = tf.train.match_filenames_once("./train/Resized/*.jpg")
def label(label_string):
if label_string == 'cat': label = [1,0]
if label_string == 'dog': label = [0,1]
return label
def read_image(filename_queue):
image_reader = tf.WholeFileReader()
key, image_filename = image_reader.read(filename_queue)
image = tf.image.decode_jpeg(image_filename)
image.set_shape((image_height, image_width, 3))
name = os.path.basename(image_filename) # example "dog.2148.jpg"
s = name.split('.')
label_string = s[0]
label = label(label_string)
return image, label
def input_pipeline(filenames, batch_size, num_epochs=None):
filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
image, label = read_image(filename_queue)
min_after_dequeue = 1000
capacity = min_after_dequeue + 3 * batch_size
image_batch, label_batch = tf.train.shuffle_batch(
[image, label], batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return image_batch, label_batch
image_batch, label_batch = input_pipeline(filenames, 10)
的最後一條語句失敗,出現以下錯誤:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-21-0224ec735c33> in <module>()
----> 1 image_batch, label_batch = input_pipeline(filenames, 10)
<ipython-input-20-277e29dc1ae3> in input_pipeline(filenames, batch_size, num_epochs)
1 def input_pipeline(filenames, batch_size, num_epochs=None):
2 filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)
----> 3 image, label = read_image(filename_queue)
4 min_after_dequeue = 1000
5 capacity = min_after_dequeue + 3 * batch_size
<ipython-input-19-ffe4ec8c3e25> in read_image(filename_queue)
5 image.set_shape((image_height, image_width, 3))
6
----> 7 name = os.path.basename(image_filename) # example "dog.2148.jpg"
8 s = name.split('.')
9 label_string = s[0]
C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in basename(p)
230 def basename(p):
231 """Returns the final component of a pathname"""
--> 232 return split(p)[1]
233
234
C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in split(p)
202
203 seps = _get_bothseps(p)
--> 204 d, p = splitdrive(p)
205 # set i to index beyond p's last slash
206 i = len(p)
C:\local\Anaconda3-4.1.1-Windows-x86_64\envs\cntk-py35\lib\ntpath.py in splitdrive(p)
137
138 """
--> 139 if len(p) >= 2:
140 if isinstance(p, bytes):
141 sep = b'\\'
TypeError: object of type 'Tensor' has no len()
我認爲這個問題是關係到張量數據類型與字符串數據類型。我如何正確地向os.path.basename函數表明image_filename是一個字符串?
好的,但它也必須可以對Tensorflow管道內的那種變量做一些簡單處理嗎? – OlavT