2017-10-05

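Normal equation implementation in Python/NumPy

I have written some beginner code to calculate the coefficients of a simple linear model using the normal equation.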

# Modules 
import numpy as np 

# Loading data set 
X, y = np.loadtxt('ex1data3.txt', delimiter=',', unpack=True) 

data = np.genfromtxt('ex1data3.txt', delimiter=',') 

def normalEquation(X, y): 
    m = int(np.size(data[:, 1])) 

    # This is the parameter vector (2 elements: intercept and slope)
    # that will contain my minimized values 
    theta = [] 

    # I create a bias_vector to add to my newly created X vector 
    bias_vector = np.ones((m, 1)) 

    # I need to reshape my original X(m,) vector so that I can 
    # manipulate it with my bias_vector; they need to share the same 
    # dimensions. 
    X = np.reshape(X, (m, 1)) 

    # I combine these two vectors together to get a (m, 2) matrix 
    X = np.append(bias_vector, X, axis=1) 

    # Normal Equation: 
    # theta = inv(X^T * X) * X^T * y 

    # For convenience I create a new, transposed X matrix 
    X_transpose = np.transpose(X) 

    # Calculating theta 
    theta = np.linalg.inv(X_transpose.dot(X)) 
    theta = theta.dot(X_transpose) 
    theta = theta.dot(y) 

    return theta 

p = normalEquation(X, y) 

print(p) 

The small data set I am using can be found here:

http://www.lauradhamilton.com/tutorial-linear-regression-with-octave

Using the code above I get the coefficients [-0.34390603; 0.2124426] instead of [24.9660; 3.3058]. Can anyone help clarify where I went wrong?

You have your X and y the wrong way round compared with the example! If I swap them, I get the answer you expected. – jeremycg

Answers

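Your implementation is correct. You have only swapped X and y (look closely at how they define x and y), which is why you get a different result.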

Calling normalEquation(y, X) gives [ 24.96601443 3.30576144], as it should.
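To illustrate why the argument order matters, here is a minimal sketch on synthetic data (not the ex1data3.txt file from the question). The function name `normal_equation`, the random seed, and the generating coefficients 25 and 3 are assumptions chosen purely for illustration. Regressing y on x recovers values close to the generating coefficients, while regressing x on y fits a different line altogether.

```python
import numpy as np

def normal_equation(x, y):
    """Return [intercept, slope] for y ~ x via theta = inv(X^T X) X^T y."""
    # Prepend a column of ones so the first coefficient is the intercept
    X = np.column_stack([np.ones(x.shape[0]), x])
    return np.linalg.inv(X.T @ X) @ X.T @ y

# Synthetic data: y = 25 + 3x + noise (assumed model, for illustration only)
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=200)
y = 25.0 + 3.0 * x + rng.normal(0, 1.0, size=200)

theta_correct = normal_equation(x, y)  # regress y on x
theta_swapped = normal_equation(y, x)  # regress x on y (arguments swapped)

print(theta_correct)
print(theta_swapped)
```

The two results differ because least squares minimizes the error in whichever variable is passed as the target, so swapping the arguments solves a different problem rather than the same one written another way.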

Oh, the shame! Thank you both for your responses. – PS94

Thanks for the suggestion @Maxim – PS94

You can implement the normal equation as shown below:

import numpy as np 

X = 2 * np.random.rand(100, 1) 
y = 4 + 3 * X + np.random.randn(100, 1) 

X_b = np.c_[np.ones((100, 1)), X] # add x0 = 1 to each instance 
theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y) 

X_new = np.array([[0], [2]]) 
X_new_b = np.c_[np.ones((2, 1)), X_new] # add x0 = 1 to each instance 
y_predict = X_new_b.dot(theta_best) 
y_predict 
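Explicitly inverting X^T X works for small, well-conditioned problems like this one. As an alternative sketch (not part of the answer above), `np.linalg.lstsq` solves the same least-squares problem without forming the inverse. The data generation mirrors the snippet above; the fixed seed and the variable names `theta_inv` and `theta_lstsq` are assumptions added for reproducibility and comparison.

```python
import numpy as np

# Synthetic data in the same shape as the answer's example (seed is an assumption)
rng = np.random.default_rng(42)
X = 2 * rng.random((100, 1))
y = 4 + 3 * X + rng.standard_normal((100, 1))

X_b = np.c_[np.ones((100, 1)), X]  # add x0 = 1 to each instance

# Normal equation with an explicit inverse
theta_inv = np.linalg.inv(X_b.T @ X_b) @ X_b.T @ y

# Same least-squares problem solved without forming the inverse
theta_lstsq, residuals, rank, singular_values = np.linalg.lstsq(X_b, y, rcond=None)

print(theta_inv.ravel())
print(theta_lstsq.ravel())
```

On data like this the two approaches should agree to floating-point precision; `lstsq` is generally the safer choice when columns of X are nearly collinear.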

You can also use ['add_constant'](http://www.statsmodels.org/dev/generated/statsmodels.tools.tools.add_constant.html). –