2017-08-02 73 views
1

我一直在努力實現的Python生成到使用Keras.js圖書館網站基本Keras模型。現在,我的模型中訓練,並出口到model.jsonmodel_weights.bufmodel_metadata.json文件。現在,我基本上從github頁面複製並粘貼測試代碼,以查看模型是否會在瀏覽器中加載,但不幸的是我收到錯誤。這是測試代碼。 (編輯:我修正了一些錯誤,請參閱下面其餘的)實施Keras型號爲網站與Keras.js

var model = new KerasJS.Model({ 
    filepaths: { 
     model: 'dist/model.json', 
     weights: 'dist/model_weights.buf', 
     metadata: 'dist/model_metadata.json' 
    }, 
    gpu: true 
}); 

    model.ready() 
    .then(function() { 
    console.log("1"); 
    // input data object keyed by names of the input layers 
    // or `input` for Sequential models 
    // values are the flattened Float32Array data 
    // (input tensor shapes are specified in the model config) 
    var inputData = { 
     'input_1': new Float32Array(data) 
    }; 
    console.log("2 " + inputData); 
    // make predictions 
    return model.predict(inputData); 
    }) 
    .then(function(outputData) { 
    // outputData is an object keyed by names of the output layers 
    // or `output` for Sequential models 
    // e.g., 
    // outputData['fc1000'] 
    console.log("3 " + outputData); 
    }) 
    .catch(function(err) { 
    console.log(err); 
    // handle error 
    }); 

編輯:所以我改變了我的計劃圍繞一點與JS 5對應的(這是我的一個愚蠢的錯誤),並現在我遇到了一個不同的錯誤。該錯誤被捕獲並記錄。我得到的錯誤是:Error: predict() must take an object where the keys are the named inputs of the model: input.我相信這個問題是因爲我data變量是不正確的格式。我想,如果我的模型參加了號的28x28陣列,然後data也應該是一個28x28陣列,以便能夠正確地「預測」正確的輸出。但是,我相信我錯過了一些東西,這就是錯誤被拋出的原因。 This問題與我的非常相似,但是它在python中而不是JS。再次,任何幫助將不勝感激。

回答

0

好了,我想通了,爲什麼這是怎麼回事。有兩個問題。首先,data數組需要變平,所以我編寫了一個快速函數來獲取2D輸入並將其「變平」爲一個長度爲784的1D數組。然後,因爲我使用了Sequential模型,數據的鍵名不應該是'input_1',而只是'input'。這擺脫了所有的錯誤。

現在,要獲取輸出信息,我們可以將它存儲在如下所示的數組中:var out = outputData['output']。因爲我用MNIST數據集,out是長度爲10的一維陣列,其包含每個數字是所述用戶編寫的位的概率。從那裏,你可以簡單地找到具有最高概率的數字,並將其用作模型的預測。