2017-12-18 119 views
1

從tfrecords文件導入數據時出現問題。在tfrecords每個樣品由feautures矢量與lenght 100和長度13。我使用下面的代碼來導入來自tfrecords數據的一個熱標籤矢量的,指的是正式指南https://www.tensorflow.org/programmers_guide/datasets從tfrecords導入數據時,批處理標籤順序錯誤

def read_data(examples): 
    features = {"features": tf.FixedLenFeature([seq_len], tf.int64), 
       "label": tf.FixedLenFeature([category], tf.int64)} 
    parsed_features = tf.parse_single_example(examples, features) 
    return parsed_features['features'], parsed_features['label'] 

# get next batch of data and label 
def next_batch(filename, batch_size): 
    data = tf.data.TFRecordDataset(filename) 
    data = data.map(read_data) 
    data = data.batch(batch_size) 
    iterator = data.make_one_shot_iterator() 
    next_data, next_label = iterator.get_next() 
    return next_data, next_label 

with tf.Session() as sess: 
    filetrain = 'train.tfrecords' 
    next_data, next_label = next_batch(filetrain, num_example_train) 
    sess.run(tf.global_variables_initializer()) 

    data = sess.run(next_data) 
    label = sess.run(next_label) 

問題批次後標籤的順序會出錯。如果我刪除了代碼'data = data.batch',一切正常。

我認爲一個可能的原因是功能和標籤是獨立分批的。所以我試圖解析批處理後的例子,但得到一個錯誤「輸入序列化必須是標量」。請幫助我,如果你知道如何處理這個問題,非常感謝!

回答

1

我確定這是重複的,但我找不到其他問題,所以我會在這裏回答。

您的問題是撥打sess.run()兩次的數據和標籤。 無論何時您致電sess.run,您的圖表評估爲(即,新的批次被提取並貫穿圖表,直到全部作爲第一個參數傳遞給run的列表中張量的值已知)。

這樣做,您的datalabel是指兩個不同的批次(因此他們看起來錯了)。

你需要讓他們在相同的呼叫與:

data, label = sess.run([next_data, next_label]) 
+0

這正是問題!感謝您的明確解釋! –

+0

不客氣。請將答案標記爲已解決,以便將來能夠鏈接到此重複內容:) – GPhilo