2017-06-03 46 views
0

我想實現一個KNN 1D估計:有超過一個元素的數組的Pyplot真值是不確定

# nearest neighbors estimate 
def nearest_n(x, k, data): 
    # Order dataset 
    #data = np.sort(data, kind='mergesort') 
    nnb = [] 
    # iterate over all data and get k nearest neighbours around x 
    for n in data: 
     if nnb.__len__()<k: 
      nnb.append(n) 
     else: 
      for nb in np.arange(0,k): 
       if np.abs(x-n) < np.abs(x-nnb[nb]): 
        nnb[nb] = n 
        break 

    nnb = np.array(nnb) 
    # get volume(distance) v of k nearest neighbours around x 
    v = nnb.max() - nnb.min() 
    v = k/(data.__len__()*v) 

    return v 

interval = np.arange(-4.0, 8.0, 0.1) 
plt.figure() 
for k in (2,8,35): 
    plt.plot(interval, nearest_n(interval, k,train_data), label=str(o)) 
plt.legend() 
plt.show() 

會拋出:

File "x", line 55, in nearest_n 
    if np.abs(x-n) < np.abs(x-nnb[nb]): 
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() 

我知道錯誤來自數組輸入在plot()中,但我不知道如何在運算符中避免這種情況>/==/<

'data'來自包含浮點數的1D txt文件。

我嘗試使用矢量化:

nearest_n = np.vectorize(nearest_n) 

導致:

line 50, in nearest_n 
    for n in data: 
TypeError: 'numpy.float64' object is not iterable 

下面是一個例子,讓我們說:

data = [0.5,1.7,2.3,1.2,0.2,2.2] 
k = 2 

nearest_n(1.5)應該然後導致

nbb=[1.2,1.7] 
v = 0.5 

並返回2 /(6 * 0.5)= 2/3

該函數運行例如neares_n(2.0,4,數據),並給出0.0741586011463

+1

你能否包括預期的輸出(如果你必須手工完成,你可能需要使用較小的輸入) 。 :) – MSeifert

+0

輸出將是3個不同的概率密度圖(k = 2,8,35),s.th.來自數組[-4,8]的每個值將映射到概率[0,1] –

+0

不,我的意思是調用'nearest_n'的字面結果。例如,'nearest_n(np.arange(-4.0,8.0,0.1),2,np.array([1,2,3]))''應該返回什麼?我已經或多或少地選擇了這些值,如果需要的話插入更合適的值(如果沒有參考實現,則更容易手動計算)。 – MSeifert

回答

0

你在np.arange(-4, 8, .01)傳遞作爲x ,這是一組值。所以x - n是一個長度與x相同的數組,在這種情況下是120個元素,因爲減去一個數組和一個標量確實是逐元素減法。與nnb[nb]一樣。因此,比較的結果是一個長度爲120的數組,其布爾值取決於np.abs(x-n)的每個元素是否小於np.abs(x-nnb[nb])的對應元素。這不能直接用作條件,你需要將這些值合併爲一個布爾值(使用all()any()或者只是重新考慮代碼)。

+0

嗨,感謝我的回答,請看我的。這只是我期望pyplot工作有點不同 –

0
plt.figure() 
X = np.arange(-4.0,8.0,0.1) 
for k in [2,8,35]: 
    Y = [] 
    for n in X: 
     Y.append(nearest_n(n,k,train_data)) 
    plt.plot(X,Y,label=str(k)) 
plt.show() 

工作正常。我認爲pyplot.plot會爲我做這件事情,但我想它不會...

+0

這不是'pyplot'的問題,我不知道你爲什麼認爲它可能是?你寫了'nearest_n'來爲'x'參數取一個標量,所以如果不重寫你的代碼就不能傳入一個向量。在這裏,你正在遍歷一個向量,並且每次都將一個標量傳遞給你的函數。 – spruceb

+0

我以爲pyplot會像這樣處理矢量輸入,但我錯了 –

+0

我只是想澄清一下,因爲我不確定你是否理解了問題的根源。這個錯誤並沒有出現在plt中。plot'功能,是不是因爲你的投入'pyplot',誤差在'nearest_n'拋出是由於傳遞給函數的參數。 – spruceb

相關問題