2017-02-22

Answers


Consider using tf.TextLineReader, which together with tf.train.string_input_producer lets you load data from multiple files on disk (useful if your dataset is large enough that it needs to be spread across several files).

https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files

Code snippet from the link above:

filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"]) 

reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 

# Default values, in case of empty columns. Also specifies the type of the 
# decoded result. 
record_defaults = [[1], [1], [1], [1], [1]] 
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults) 
features = tf.stack([col1, col2, col3, col4]) 

with tf.Session() as sess: 
    # Start populating the filename queue. 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    for i in range(1200):
        # Retrieve a single instance:
        example, label = sess.run([features, col5])

    coord.request_stop()
    coord.join(threads)

Thank you for your answer. But what if the CSV file has **many columns**? Do I have to write out col1, col2, col3, ... and so on? And how do I read data from a binary file? – secsilm


@secsilm Yes, you need a 'col1', 'col2', etc. for each column in your CSV. Remember that 'col1' is just a variable name, so you can give it a more mnemonic name, such as 'price' or whatever fits. For binary files, see https://www.tensorflow.org/api_docs/python/tf/FixedLengthRecordReader – Insectatorious
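Since tf.decode_csv returns a plain Python list of tensors, one way to avoid naming every column is to build record_defaults programmatically and slice the result. A minimal sketch of that idea (the TensorFlow-specific lines are shown as comments; `num_columns` is an assumed, known column count):

```python
# Sketch: handle a CSV with many columns without naming each one.
num_columns = 20  # assumption: the number of columns is known in advance

# One default value per column; the default also fixes the decoded dtype.
record_defaults = [[1] for _ in range(num_columns)]

# With TensorFlow this would continue roughly as:
#   cols = tf.decode_csv(value, record_defaults=record_defaults)
#   features = tf.stack(cols[:-1])  # every column except the last
#   label = cols[-1]                # last column as the label

print(len(record_defaults))
```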


Usually you will be training batch-wise anyway, so you can load the data on the fly. For example, with images:

for bid in range(nr_batches):
    # Load only the current batch from disk:
    batch_x, batch_y = load_data_from_hd(bid)
    train_step.run(feed_dict={x: batch_x, y_: batch_y})

This way each batch is loaded just in time, and only the data needed at any given moment is in memory. Of course, training time will increase when data is loaded from the hard disk instead of from memory.
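The loop above can be sketched end to end with a hypothetical `load_data_from_hd` that returns just one batch per call (a plain-Python stand-in; no TensorFlow is needed to illustrate the idea, and the training step is left as a comment):

```python
# Toy dataset standing in for data on disk: (feature, label) pairs.
DATA = [(i, i % 2) for i in range(10)]
BATCH_SIZE = 4

def load_data_from_hd(bid):
    """Hypothetical helper: load only the slice needed for batch `bid`."""
    batch = DATA[bid * BATCH_SIZE:(bid + 1) * BATCH_SIZE]
    batch_x = [x for x, _ in batch]
    batch_y = [y for _, y in batch]
    return batch_x, batch_y

# Ceiling division so a final partial batch is not dropped.
nr_batches = (len(DATA) + BATCH_SIZE - 1) // BATCH_SIZE
for bid in range(nr_batches):
    batch_x, batch_y = load_data_from_hd(bid)
    # train_step.run(feed_dict={x: batch_x, y_: batch_y})  # TF training step
    print(bid, batch_x, batch_y)
```

The last batch may be smaller than BATCH_SIZE; the slice simply returns whatever remains.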
