2016-10-03 181 views
0

使用Tensorflow時,第一步是構建數據圖並使用會話來運行它。而在我的練習中,如MNIST tutorial。它首先界定損失功能和優化,用下面的代碼(和MLP模型之前定義):Tensorflow:它如何訓練模型?

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) #define cross entropy error function 

loss = tf.reduce_mean(cross_entropy, name='xentropy_mean') #define loss 

optimizer = tf.train.GradientDescentOptimizer(learning_rate) #define optimizer 

global_step = tf.Variable(0, name='global_step', trainable=False) #learning rate 

train_op = optimizer.minimize(loss, global_step=global_step) #train operation in the graph 

培訓過程:

train_step =tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 


for i in range(1000): 
    batch_xs, batch_ys = mnist.train.next_batch(100) 
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 

那怎麼Tensorflow在這種情況下進行了培訓。但我的問題是,Tensorflow是如何知道需要訓練和更新的重量?我的意思是,在訓練代碼中,我們只通過輸出ycross_entropy,但是對於optimizerloss,我們沒有傳遞關於結構直接的任何信息。另外,我們使用字典將批量數據送到train_step,但是train_step沒有直接使用這些數據。 Tensorflow如何知道將這些數據用作輸入?

我的問題,我想這可能是所有這些變量或常量存儲在張量。像tf.matmul()這樣的操作應該是Tensorflow操作類的「子類」(我還沒有檢查代碼)。 Tensorflow可能有一些機制來識別張量之間的關係(tf.Variable(),tf.constant())和操作(tf.mul(),tf.div() ...)。我想,它可以檢查tf.xxxx()的超類,以確定它是否是張量或操作。這個假設提出了我的第二個問題:我應該儘可能使用Tensorflow的'tf.xxx'函數來確保tensorflow可以構建正確的數據流圖,即使有時它比普通的Python方法更復雜,或者Numpy支持的某些函數比Tensorflow?

我的最後一個問題是:Tensorflow和C++之間有什麼關係?我聽到有人說Tensorflow比普通的Python更快,因爲它使用C或C++作爲後端。有沒有將Tensorflow Python代碼轉換爲C/C++的轉換機制?

如果有人可以在使用Tensorflow編碼時分享一些調試習慣,我也會很優雅,因爲目前我只是設置了一些終端(Ubuntu)來測試我的代碼的每個部分/功能。

+0

您可以用C庫擴展python,這是一種可能的方式,只是C庫的Python API。 – Marcus

+0

@Marcus Yep,沒錯。我不知道Python版本Tensorflow的能力,是否比普通的純Python編碼Numpy或Scipy更快? –

回答

1

你做傳遞有關的結構信息Tensorflow當你定義你的損失:

loss = tf.reduce_mean(cross_entropy, name='xentropy_mean') 

注意與Tensorflow你建立操作的圖形,每一次你在代碼中使用的操作中的一個節點圖表。

當您定義您的loss時,您將傳遞存儲在cross_entropy中的操作,這取決於y_yy_是您輸入的佔位符,而yy = tf.nn.softmax(tf.matmul(x, W) + b)的結果。看看我要去哪裏?操作loss包含建立模型過程輸入所需的所有信息,因爲它取決於操作cross_entropy,這取決於y_y,這取決於輸入x和模型權重W

所以,當你調用

sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) 

Tensorflow非常瞭解操作是否在運行train_step來計算,並確切地知道哪裏放的操作圖表您是路過feed_dict數據。

至於Tensorflow如何知道應該訓練哪些變量,答案很簡單。它在可訓練的操作圖中訓練任何tf.Variable()。請注意,在定義global_step時,如何設置trainable=False是因爲您不想計算該變量的梯度w.r.t。