2017-10-05 203 views
0

我想在python中使用 numpy實現向量化邏輯迴歸。我的成本函數(CF)似乎工作正常。但是梯度計算有一個 問題。它返回3x100陣列,而它的 應該返回3x1。我認爲(hypo-y)部分存在問題。Python Numpy邏輯迴歸

def sigmoid(a): 
    return 1/(1+np.exp(-a))  

def CF(theta,X,y): 
    m=len(y) 
    hypo=sigmoid(np.matmul(X,theta)) 
    J=(-1./m)*((np.matmul(y.T,np.log(hypo)))+(np.matmul((1-y).T,np.log(1-hypo)))) 
    return(J) 

def gr(theta,X,y): 
    m=len(y) 
    hypo=sigmoid(np.matmul(X,theta)) 

    grad=(1/m)*(np.matmul(X.T,(hypo-y))) 

    return(grad) 

X是100x3 arrray,y是100X1和theta是一個3×1 arrray。看來兩個功能獨立工作,然而,這優化功能提供了一個錯誤:

optim = minimize(CF, theta, method='BFGS', jac=gr, args=(X,y)) 

The error: "ValueError: shapes (3,100) and (3,100) not aligned: 100 (dim 1) != 3 (dim 0)"

+1

請說明如何用示例輸入調用函數。我認爲這與形成的形狀有很大關係。 – kazemakase

+0

我的X輸入是100X3陣列,y輸入是100X1,θ輸入是3X1陣列。現在看起來兩個函數都單獨工作,但是這個優化函數給出了一個錯誤:optim = minimize(CF,theta,method ='BFGS',jac = gr,args =(X,y))錯誤:「ValueError:shapes(3,100 )和(3,100)不對齊:100(dim 1)!= 3(dim 0)「感謝您的關注! – efeatikkan

回答

0

I think there is a problem with the (hypo-y) part.

點上!

hypo是形狀(100,)y是形狀(100, 1)。在元素方面-操作中,根據numpy的broadcasting ruleshypo被廣播以形成(1, 100)。這導致(100, 100)陣列,這導致矩陣乘法導致(3, 100)陣列。

hypo = sigmoid(np.matmul(X, theta)).reshape(-1, 1) # -1 means automatic size on first dimension 

還有一個問題::scipy.optimize.minimize(我假定您使用)期望梯度爲形狀(k,)但陣列通過使hypo成相同的形狀y

修復此函數gr返回形狀爲(k, 1)的矢量。這是很容易解決:

return grad.reshape(-1) 

最終的功能變得

def gr(theta,X,y): 
    m=len(y) 
    hypo=sigmoid(np.matmul(X,theta)).reshape(-1, 1) 
    grad=(1/m)*(np.matmul(X.T,(hypo-y))) 
    return grad.reshape(-1) 

和玩具數據工程運行它(我沒有檢查數學或結果的合理性):

theta = np.reshape([1, 2, 3], 3, 1)  
X = np.random.randn(100, 3) 
y = np.round(np.random.rand(100, 1))  

optim = minimize(CF, theta, method='BFGS', jac=gr, args=(X,y)) 
print(optim) 
#  fun: 0.6830931976615066 
# hess_inv: array([[ 4.51307367, -0.13048255, 0.9400538 ], 
#  [-0.13048255, 3.53320257, 0.32364498], 
#  [ 0.9400538 , 0.32364498, 5.08740428]]) 
#  jac: array([ -9.20709950e-07, 3.34459058e-08, 2.21354905e-07]) 
# message: 'Optimization terminated successfully.' 
#  nfev: 15 
#  nit: 13 
#  njev: 15 
# status: 0 
# success: True 
#  x: array([-0.07794477, 0.14840167, 0.24572182]) 
+0

非常感謝,它現在正常工作! – efeatikkan