我想使用一個無法加載到內存中的大型數據集來訓練帶有TensorFlow的模型。但我不知道我應該做什麼。如果我想要使用無法使用TensorFlow加載到內存中的大型數據集,該怎麼辦?
我已閱讀了一些關於TFRecords
文件格式和官方文檔的好帖子。公交車我仍然無法弄清楚。
TensorFlow是否有完整的解決方案?
我想使用一個無法加載到內存中的大型數據集來訓練帶有TensorFlow的模型。但我不知道我應該做什麼。如果我想要使用無法使用TensorFlow加載到內存中的大型數據集,該怎麼辦?
我已閱讀了一些關於TFRecords
文件格式和官方文檔的好帖子。公交車我仍然無法弄清楚。
TensorFlow是否有完整的解決方案?
考慮使用tf.TextLineReader
,它與tf.train.string_input_producer
一起允許您從磁盤上的多個文件(如果您的數據集足夠大以至於需要將其分散到多個文件中)加載數據。
見https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files
代碼段從上面的鏈接:
filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
# Default values, in case of empty columns. Also specifies the type of the
# decoded result.
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(
value, record_defaults=record_defaults)
features = tf.stack([col1, col2, col3, col4])
with tf.Session() as sess:
# Start populating the filename queue.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(1200):
# Retrieve a single instance:
example, label = sess.run([features, col5])
coord.request_stop()
coord.join(threads)i in range(1200):
# Retrieve a single instance:
example, label = sess.run([features, col5])
coord.request_stop()
coord.join(threads)
通常情況下,您無論如何都會使用批處理智能培訓,因此您可以即時加載數據。例如,對於圖像:
for bid in nrBatches:
batch_x, batch_y = load_data_from_hd(bid)
train_step.run(feed_dict={x: batch_x, y_: batch_y})
因此,您可以實時加載每個批次,只加載需要在任何特定時刻加載的數據。當然你的訓練時間會增加,而使用硬盤代替內存來加載數據。
謝謝您的anwser。但是,如果CSV文件中有**列**,該怎麼辦?我必須寫很多col1,col2,col3 ...等等?以及如何從二進制文件讀取數據? – secsilm
@secsilm是的,您需要在您的CSV中爲每列添加「col1」,「col2」等。記住'col1'只是一個變量名,所以你可以給它一個更多的助記符名稱,比如'price'或者其他什麼。有關二進制文件,請參閱https://www.tensorflow.org/api_docs/python/tf/FixedLengthRecordReader – Insectatorious