2017-10-15 152 views
1

我正在嘗試建立與Tensorflow服務會話的通信。 以下代碼有效,但速度有點慢。有沒有辦法改進它?我懷疑這個問題出現在第四行 - 輸出結果是float_val元素的列表,我需要將它們轉換爲浮點數組並重新設置它們。Tensorflow服務響應

有沒有辦法讓在正確的形狀服務器輸出? 我已經定義了輸出簽名如下(我認爲是正確的)。

prediction_channel, request_form = setup_channel(args.server)  
request_form.inputs['images'].CopyFrom(
       tf.contrib.util.make_tensor_proto(img_transformed, shape=list(img_transformed.shape))) 
output = prediction_channel.Predict.future(request_form, 5.0) 
output = np.array(output.result().outputs['scores'].float_val).reshape(1, 16, 64, 64) 

第一行通過使用函數

打開到服務器的通道
def setup_channel(hostport): 
    host, port = hostport.split(':') 
    channel = implementations.insecure_channel(host, int(port)) 
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel) 
    request = predict_pb2.PredictRequest() 
    request.model_spec.name = 'hg' 
    request.model_spec.signature_name = 'predict_images' 
    return stub, request 

輸出簽名是:

tensor_info_x = tf.saved_model.utils.build_tensor_info(model.input_tensor) 
tensor_info_y = tf.saved_model.utils.build_tensor_info(model.predict) 

prediction_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
     inputs={'images': tensor_info_x}, 
     outputs={'scores': tensor_info_y}, 
     method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))** 

和模型預測具有(1,16的形狀,64,64)。

回答

0

我不知道你是如何處理你Predict.future(request_form, 5.0),但同樣應該適用於同步響應處理; TF提供了一個實用功能make_ndarray

res = stub.Predict(request, timeout).outputs[tensor_name] 
arr = tf.make_ndarray(res) 

arr將是正確的DIMS的NP陣列。

tensor_name是在您的簽名定義例如名稱

tf.saved_model.signature_def_utils.build_signature_def(
    inputs={'images': inp_tensor_info}, 
    outputs={'scores': out_tensor_info}, 
    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME 
) 

需要

res = stub.Predict(request, timeout).outputs['scores'] 
arr = tf.make_ndarray(res)