從他們抄寫單詞的圖像列表,我想使用tf.train.slice_input_producer
創建和讀取稀疏序列標籤(用於tf.nn.ctc_loss
),避免如何在Tensorflow中爲CTC丟失生成/讀取稀疏序列標籤?
在
TFRecord
格式序列化預包裝的訓練數據到磁盤的
tf.py_func
表觀侷限性,任何不必要或過早填充和
將整個數據集讀取到RAM中。
的主要問題似乎是一個字符串被轉換爲標籤的需要tf.nn.ctc_loss
的序列(SparseTensor
)。
例如,在字符集(有序)範圍[A-Z]
中,我想將文本標籤字符串"BAD"
轉換爲序列標籤類列表[1,0,3]
。
每個示例圖像我想讀包含文本的文件名的一部分,所以它是直接的提取和做轉化率直線上升蟒蛇。 (如果有一種方法來計算TensorFlow內做到這一點,我還沒有發現它。)
以前的幾個問題,瞄了一眼這些問題,但我一直沒能對他們成功整合。例如,
Tensorflow read images with labels 顯示與分立,分類標籤, 我已經開始與作爲模型一個簡單的框架。
How to load sparse data with TensorFlow? 很好地解釋了用於裝載稀疏數據的方法,但假定 預包裝
tf.train.Example
秒。
有沒有辦法整合這些方法?
另一個例子(SO問題#38012743)顯示了我如何推遲從字符串到列表的轉換,直到解除文件名出隊權之後,但它依賴於tf.py_func
,它有一些注意事項。 (我應該擔心它們嗎?)
我認識到「SparseTensors不能很好地處理隊列」(每個tf文檔),所以在批處理之前可能需要對結果做一些voodoo(序列化?) ,甚至在計算髮生的地方返工;我對此表示歡迎。
按照MarvMind的提綱,這是一個基本框架,包含我想要的計算(遍歷包含示例文件名的行,提取每個標籤字符串並轉換爲序列),但是我沒有成功確定「Tensorflow」 。
謝謝你正確的「調整」,對我的目標來說是一個更合適的策略,或者指示tf.py_func
不會破壞培訓效率或下游的其他東西(例如,,加載訓練有素的模型以供將來使用)。
編輯(+7小時)我找到了缺少的操作來修補東西了。雖然仍然需要驗證這與CTC_Loss下游連接,但我已檢查以下編輯的版本是否正確批量並讀取圖像和稀疏張量。
out_charset="ABCDEFGHIJKLMNOPQRSTUVWXYZ"
def input_pipeline(data_filename):
filenames,seq_labels = _get_image_filenames_labels(data_filename)
data_queue = tf.train.slice_input_producer([filenames, seq_labels])
image,label = _read_data_format(data_queue)
image,label = tf.train.batch([image,label],batch_size=2,dynamic_pad=True)
label = tf.deserialize_many_sparse(label,tf.int32)
return image,label
def _get_image_filenames_labels(data_filename):
filenames = []
labels = []
with open(data_filename)) as f:
for line in f:
# Carve out the ground truth string and file path from
# lines formatted like:
# ./241/7/158_NETWORK_51375.jpg 51375
filename = line.split(' ',1)[0][2:] # split off "./" and number
# Extract label string embedded within image filename
# between underscores, e.g. NETWORK
text = os.path.basename(filename).split('_',2)[1]
# Transform string text to sequence of indices using charset, e.g.,
# NETWORK -> [13, 4, 19, 22, 14, 17, 10]
indices = [[i] for i in range(0,len(text))]
values = [out_charset.index(c) for c in list(text)]
shape = [len(text)]
label = tf.SparseTensorValue(indices,values,shape)
label = tf.convert_to_tensor_or_sparse_tensor(label)
label = tf.serialize_sparse(label) # needed for batching
# Add data to lists for conversion
filenames.append(filename)
labels.append(label)
filenames = tf.convert_to_tensor(filenames)
labels = tf.convert_to_tensor_or_sparse_tensor(labels)
return filenames, labels
def _read_data_format(data_queue):
label = data_queue[1]
raw_image = tf.read_file(data_queue[0])
image = tf.image.decode_jpeg(raw_image,channels=1)
return image,label