2016-08-22 141 views
1

T.meanthis example中有什麼意義?如果實現是矢量化的,我認爲T.mean是有意義的。這裏的輸入xytrain(x, y)是標量,而cost只發現單個輸入的平方誤差,並迭代數據。theano中的線性迴歸

cost = T.mean(T.sqr(y - Y)) 
gradient = T.grad(cost=cost, wrt=w) 
updates = [[w, w - gradient * 0.01]] 

train = theano.function(inputs=[X, Y], outputs=cost, updates=updates, allow_input_downcast=True) 

for i in range(100): 
    for x, y in zip(trX, trY): 
     train(x, y) 

print w.get_value() 

刪除T.mean對輸出模式沒有影響。

回答

1

你是對的,T.mean在這裏沒有意義。成本函數一次對單個訓練樣本進行操作,所以「均方誤差」實際上只是樣本的平方誤差。

本示例通過stochastic gradient descent(一種用於在線優化的算法)實現線性迴歸。 SGD像樣本一樣逐個迭代樣本。但是,在更復雜的情況下,數據集通常是processed in mini-batches,這會提供更好的性能和收斂性能。

我認爲T.mean留在這個例子中作爲小批量梯度下降的人工產物,或者更明確地說成本函數是MSE。