從tfrecords文件導入數據時出現問題。在tfrecords每個樣品由feautures矢量與lenght 100和長度13。我使用下面的代碼來導入來自tfrecords數據的一個熱標籤矢量的,指的是正式指南https://www.tensorflow.org/programmers_guide/datasets從tfrecords導入數據時,批處理標籤順序錯誤
def read_data(examples):
features = {"features": tf.FixedLenFeature([seq_len], tf.int64),
"label": tf.FixedLenFeature([category], tf.int64)}
parsed_features = tf.parse_single_example(examples, features)
return parsed_features['features'], parsed_features['label']
# get next batch of data and label
def next_batch(filename, batch_size):
data = tf.data.TFRecordDataset(filename)
data = data.map(read_data)
data = data.batch(batch_size)
iterator = data.make_one_shot_iterator()
next_data, next_label = iterator.get_next()
return next_data, next_label
with tf.Session() as sess:
filetrain = 'train.tfrecords'
next_data, next_label = next_batch(filetrain, num_example_train)
sess.run(tf.global_variables_initializer())
data = sess.run(next_data)
label = sess.run(next_label)
問題批次後標籤的順序會出錯。如果我刪除了代碼'data = data.batch',一切正常。
我認爲一個可能的原因是功能和標籤是獨立分批的。所以我試圖解析批處理後的例子,但得到一個錯誤「輸入序列化必須是標量」。請幫助我,如果你知道如何處理這個問題,非常感謝!
這正是問題!感謝您的明確解釋! –
不客氣。請將答案標記爲已解決,以便將來能夠鏈接到此重複內容:) – GPhilo