Keras新手在這裏。我在一個非常大的CSV文件上做了一些深入的學習實驗(keras 2.x,tensorflow作爲背景,python3.5)。如何創建一個在Keras模式下讀取一個巨型數據框的線程安全生成器fit_generator
將CSV加載到Pandas數據框後,我需要讀取數據幀以將數據轉換爲X_train,y_train/label。因爲轉換後的X_train非常大,不適合內存。我開始使用generator和model.fit_generator()。我已經瞭解到,通過創建一個線程安全的生成器,我可以使用多個工作器,並使用use_multiprocessing = True,以便更高效。然而,在我的情況下,在內部生成器中它總是讀取相同的數據幀,我想知道如何使它成爲線程安全的,因爲相同的數據/行不會被多個生成器實例讀取並生成?沒有線程安全
我的電流發生器的實現是這樣的:
data = pd.read_csv("data.csv", header=0, delimiter="\t", quoting=3, encoding="utf-8")
y = data.label
X_train, X_test, y_train, y_test = train_test_split(data, y, test_size=0.2)
def data_genereator(data, batch_size):
num_rows = int(data.shape[0])
# Initialize a counter
counter = 0
while True:
for content, label in zip(data['content'], data['label']):
X_train[counter%batch_size] = transform(content)
y_train[counter%batch_size] = np.asarray(label)
counter = counter + 1
if(counter%batch_size == 0):
yield X_train, y_train
training_generator = data_genereator(X_train, batch_size=1024)
validation_generator = data_genereator(X_test, batch_size=1024)
model = Sequential()
model.add(LSTM(64, input_shape=(1000, 2400), return_sequences=False,
kernel_initializer='he_normal', dropout=0.15, recurrent_dropout=0.15, implementation=2))
model.add(Dropout(0.3))
model.add(Dense(1, activation = 'sigmoid'))
model.compile(loss='binary_crossentropy',
optimizer='adam',
metrics=['accuracy'])
model.fit_generator(training_generator,
steps_per_epoch=8000,
validation_data=validation_generator,
epochs=3,
verbose=1,
workers=1,
use_multiprocessing=False,
validation_steps=2000)
我可能是完全錯誤的,但想要得到我的作品和use_multiprocessing參數的瞭解您的反饋,是多個生成器實例(如生產商)將被啓動以將數據饋送到由model.fit_generator()函數創建/維護的隊列中,同時將數據從隊列中抓取到GPU以用於訓練(消費者)。如果使用GPU進行培訓不是瓶頸,那麼發電機可以生產/生產的數據越多,整個過程就會越快。我默認了max_queue_size = 10,一旦生成器是線程安全的,如何定義正確的max_queue_size?
另外,有沒有一種方法可以衡量天氣發生器(生產者)或GPU培訓(消費者)的瓶頸? 我使用verbose = 1來打印狀態欄,以及單個線程生成器產生多少行。現在,它總是喜歡:
行數的產量=(max_queue_size +步數已處理)的batch_size *
所以我真的不能告訴如果發電機太慢喂在數據中或GPU訓練是瓶頸的時候,似乎稍後隊列總是滿員,但我不確定,任何洞察力都非常感謝。謝謝!
Keras建議您使用'Sequence'此:https://keras.io/utils/ –
還是提到[這裏](https://開頭stanford.edu/~shervine/blog/keras-generator-multiprocessing.html),使用一個簡單的鎖定機制使迭代器/生成器線程安全 – scarecrow
感謝Daniel,再次:)我沒有發現除https以外的太多示例: //gist.github.com/alxndrkalinin/6cc4228e9178ec4af7b2696a0d1ad5a1,會試試看。在我使用model.fit_generator()時,我注意到,在第二個時期,由於已經完成了半個步驟,準確度開始下降,它一直下降得很厲害,並且從未再次上升。你能否對這種情況有所瞭解?這是否在同一個時代過度適應?您能不能請我糾正我對Queue,多處理工作者和吞吐量瓶頸測量的理解? –