2017-09-13 144 views
0

我目前正在學習梯度下降,所以我寫了一段使用線性迴歸梯度下降的代碼。然而,我得到的路線並不是最好的路線。我用梯度下降和最小平方誤差迴歸計算了線性迴歸的誤差。無論使用什麼數據,最小平方誤差總會給我一個更低的誤差。我決定看看這兩個斜坡,並且攔截了兩個。使用梯度下降的y截距總是非常接近於零,就好像它沒有正確改變。我覺得這很奇怪,我不知道發生了什麼。我以某種方式不正確地實現漸變下降?Y截距不會改變線性迴歸梯度下降

import matplotlib.pyplot as plt 
datax=[] 
datay=[] 
def gradient(b_current,m_current,learningRate): 
    bgradient=0 
    mgradient=0 
    N=float(len(datax)) 
    for i in range(0,len(datax)): 
     bgradient+= (-2/N)*(datay[i]-((m_current*datax[i])+b_current)) 
     mgradient+= (-2/N)*datax[i]*(datay[i]-((m_current*datax[i])+b_current)) 
    newb=b_current-(bgradient*learningRate) 
    newm=m_current-(mgradient*learningRate) 
    return newm,newb 
def basic_linear_regression(x, y): 
    # Basic computations to save a little time. 
    length = len(x) 
    sum_x = sum(x) 
    sum_y = sum(y) 

    # sigma x^2, and sigma xy respectively. 
    sum_x_squared = sum(map(lambda a: a * a, x)) 
    sum_of_products = sum([x[i] * y[i] for i in range(length)]) 

    # Magic formulae! 
    a = (sum_of_products - (sum_x * sum_y)/length)/(sum_x_squared - ((sum_x ** 2)/length)) 
    b = (sum_y - a * sum_x)/length 
    return a, b 

def error(m,b,datax,datay): 
    error=0 
    for i in range(0,len(datax)): 
     error+=(datay[i]-(m*datax[i]+b)) 
    return error/len(datax) 
def run(): 
    m=0 
    b=0 
    iterations=1000 
    learningRate=.00001 
    for i in range(0,iterations): 
     m,b=gradient(b,m,learningRate) 

    print(m,b) 
    c,d=basic_linear_regression(datax,datay) 
    print(c,d) 
    gradientdescent=error(m,b,datax,datay) 
    leastsquarederrors=error(c,d,datax,datay) 
    print(gradientdescent) 
    print(leastsquarederrors) 
    plt.scatter(datax,datay) 
    plt.plot([0,300],[b,300*m+b]) 
    plt.axis('equal') 
    plt.show() 

run() 

回答

0

我看過學習率有時會在0.01的範圍內。這可能是你學習率爲0.00001後需要超過1000次迭代的原因,除非你的數據集很小。學習率越低,收斂所需的迭代次數越多。

我注意到的另一件事是,你正在修復迭代的次數。您永遠無法知道您的成本函數是否將在第1000次迭代時處於/接近全局最小值。特別是如此低的學習速度,如果您需要超過1000次迭代,該怎麼辦?爲了解決這個問題 - 嘗試使用while循環並在此循環中添加一個計算成本函數(delta J)的差異並保持循環直到(delta J < Threshold),其中閾值通常保持非常低(範圍爲0.01或0.001)。然後,當你跳出while循環並且將其與從OLS方法中獲得的值進行比較時,請使用cost函數。