2017-10-21 127 views
9

我有一個基本Android TensorFlowInference示例,它可以在單線程中正常運行。在多核設備上運行TensorFlow

public class InferenceExample { 

    private static final String MODEL_FILE = "file:///android_asset/model.pb"; 
    private static final String INPUT_NODE = "intput_node0"; 
    private static final String OUTPUT_NODE = "output_node0"; 
    private static final int[] INPUT_SIZE = {1, 8000, 1}; 
    public static final int CHUNK_SIZE = 8000; 
    public static final int STRIDE = 4; 
    private static final int NUM_OUTPUT_STATES = 5; 

    private static TensorFlowInferenceInterface inferenceInterface; 

    public InferenceExample(final Context context) { 
     inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE); 
    } 

    public float[] run(float[] data) { 

     float[] res = new float[CHUNK_SIZE/STRIDE * NUM_OUTPUT_STATES]; 

     inferenceInterface.feed(INPUT_NODE, data, INPUT_SIZE[0], INPUT_SIZE[1], INPUT_SIZE[2]); 
     inferenceInterface.run(new String[]{OUTPUT_NODE}); 
     inferenceInterface.fetch(OUTPUT_NODE, res); 

     return res; 
    } 
} 

的例子崩潰,各種異常,包括java.lang.ArrayIndexOutOfBoundsExceptionjava.lang.NullPointerExceptionThreadPool按照下面的例子,所以我想這不是線程安全運行時。

InferenceExample inference = new InferenceExample(context); 

ExecutorService executor = Executors.newFixedThreadPool(NUMBER_OF_CORES);  
Collection<Future<?>> futures = new LinkedList<Future<?>>(); 

for (int i = 1; i <= 100; i++) { 
    Future<?> result = executor.submit(new Runnable() { 
     public void run() { 
      inference.call(randomData); 
     } 
    }); 
    futures.add(result); 
} 

for (Future<?> future:futures) { 
    try { future.get(); } 
    catch(ExecutionException | InterruptedException e) { 
     Log.e("TF", e.getMessage()); 
    } 
} 

是否有可能利用多核Android設備與TensorFlowInferenceInterface

回答

0

TensorFlowInferenceInterface類不是線程安全的(因爲它保持到feed調用之間的狀態,runfetch

但是,它是建立在TensorFlow的Java API,其中Session類的對象是頂部線程安全的。

所以,你可能想要直接使用底層的Java API,TensorFlowInferenceInterface的構造函數創建一個Session,並與來自AssetManagercode)加載的Graph對其進行設置。

希望有所幫助。

1

爲了使InferenceExample線程安全的,我從static改變了TensorFlowInferenceInterface並取得了run方法​​:

private TensorFlowInferenceInterface inferenceInterface; 

public InferenceExample(final Context context) { 
    inferenceInterface = new TensorFlowInferenceInterface(assets, model); 
} 

public synchronized float[] run(float[] data) { ... } 

然後我循環賽跨越numThreadsInterferenceExample實例的列表。

for (int i = 1; i <= 100; i++) { 
    final int id = i % numThreads; 
    Future<?> result = executor.submit(new Runnable() { 
     public void run() { 
      list.get(id).run(data); 
     } 
    }); 
    futures.add(result); 
} 

這確實增加性能8芯裝置上但是 此峯的2 numThreads並只顯示在〜Android Studio中監視器50%的CPU使用率。

+0

我強烈建議不要這種方法。當然,你已經做到了這樣,可以同時調用'run',但只有當你不改變輸入時(通過調用'TensorFlowInferenceInterface.feed()')纔有意義。 假設你想要你的線程提供不同的輸入,以便計算可以在它們上面運行。你提出的方法對此並不安全。 – ash

+0

爲什麼對於不同的輸入不安全?通過按照'id'順序將期貨存儲在循環中的細微變化,我將知道哪個輸入與哪個輸出匹配。 –

+0

噢,對不起,我誤讀了,並沒有注意到'feed()'和'fetch()'調用在你的同步'run()'內。所以我在上面的評論中誤會了。 但是,您的方法會限制並行性,因爲這實際上是串行化使用TensorFlow會話 - 一次只能有一個線程執行模型。 – ash