2014-09-11 135 views
2

我一直在努力訓練下面的網絡,並獲得合適的權重,但它一直在運行。任何人都可以告訴我代碼中可能有什麼錯誤?這裏{8,1}是輸入,{-1}}預期使用一個符號函數輸出。單層感知器培訓?

import java.util.Arrays; 

public class ANN { 

    public static void main(String args[]) { 

     double threshold = 1.2; 
     double learningRate = 0.08; 

     // Init weights 

     double[] weights = { -1.4, 1.8 }; 

     int[][][] trainingData = { 
      {{8, 1}, {-1}}, 
      {{3, 2}, {-1}}, 
      {{6, 3}, {-1}}, 
      {{1, 4}, {-1}}, 
      {{9, 5}, {1}}, 
      {{5, 6}, {1}}, 
      {{2, 7}, {1}}, 
      {{4, 8}, {1}}, 
      {{7, 9}, {1}}, 
     }; 

     // Start training loop 
     while (true) { 
      int errorCount = 0; 
      // Loop over training data 
      for (int i = 0; i < trainingData.length; i++) { 
       System.out.println("Starting weights: " + Arrays.toString(weights)); 
       // Calculate weighted input 
       double weightedSum = 0; 
       for (int ii = 0; ii < trainingData[i][0].length; ii++) { 
        weightedSum += trainingData[i][0][ii] * weights[ii]; 
       } 

       // Calculate output 
       int output = 0; 
       if (threshold <= weightedSum) { 
        output = 1; 
       } 

       System.out.println("Target output: " + trainingData[i][1][0] 
         + ", " + "Actual Output: " + output); 

       // Calculate error 
       int error = trainingData[i][1][0] - output; 
       System.out.println("Error: " + error); 
       // Increase error count for incorrect output 
       if (error != 0) { 
        errorCount++; 
       } 

       // Update weights 
       for (int ii = 0; ii < trainingData[i][0].length; ii++) { 
        weights[ii] += learningRate * error 
          * trainingData[i][0][ii]; 
       } 

       System.out.println("New weights: " + Arrays.toString(weights)); 
       System.out.println(); 
      } 

      // If there are no errors, stop 
      if (errorCount == 0) { 
       System.out 
         .println("Final weights: " + Arrays.toString(weights)); 
       System.exit(0); 
      } 
     } 
    } 

} 

編輯:我認爲問題出現在計算輸出的代碼片段中。它應該翻轉,以便如果總和大於閾值,則輸出爲1,否則爲0。

// Calculate output 
       int output = 0; 
       if (weightedSum > threshold) { 
        output = 1; 
       } 

回答

1

我遇到你的代碼,並添加了一行之前的(ERRORCOUNT == 0)檢查:

System.out.println(errorCount); 

這看似6和7之間振盪,這意味着神經網絡總是無論訓練數量如何,都會生成對訓練數據的無效估計。如果訓練數據無法達到100%的正確率,那麼預計這種訓練將永遠持續。

希望這有助於!

1

您的錯誤可能是正面和負面的。在第一次運行中,錯誤是-1。因此,errorCount遞增,退出循環的代碼從不執行。

完整培訓的條件應該基於錯誤本身,而不是errorCount。當錯誤達到最低水平(您將根據您的輸入設置)時,培訓將被視爲已完成。