2016-11-28 181 views
12

我使用Python和Keras(目前使用Theano後端,但我沒有切換問題)。我有一個神經網絡,可以並行加載和處理多個信息源。目前,我一直在單獨的進程中運行每個進程,並從文件加載自己的網絡副本。這看起來像是浪費了RAM,所以我認爲讓所有線程使用一個網絡實例的單個多線程進程會更高效。但是,我想知道Keras是否與後端線程安全。如果我在不同的線程同時在兩個不同的輸入上運行.predict(x),我是否會遇到競爭條件或其他問題?Keras線程安全嗎?

謝謝

回答

6

是的,Keras是線程安全的,如果你稍微注意它的話。

事實上,在強化學習中,有一種算法叫Asynchronous Advantage Actor Critics (A3C),其中每個代理都依賴於同一個神經網絡來告訴他們他們應該在給定狀態下應該做什麼。換句話說,每個線程同時調用model.predict,就像你的問題一樣。 Keras的一個示例實現是here

你應該,但是,格外注意這行,如果你看着代碼: model._make_predict_function() # have to initialize before threading

這是從來沒有在Keras文檔提到,但它的必要,使其同時工作。簡而言之,_make_predict_function是編譯predict函數的函數。在多線程設置中,您必須手動調用此函數以提前編譯predict,否則在第一次運行函數之前不會編譯predict函數,這在多個線程立即調用時會出現問題。你可以看到一個詳細的解釋here

到目前爲止,我還沒有遇到過在Keras中多線程的任何其他問題。