2017-06-13 53 views
0

我目前正在學習神經網絡背後的理論,並且我想學習如何對這些模型進行編碼。所以我開始看TensorFlow。在Python中使用TensorFlow的XOR Neural Netowrk

我發現了一個非常有趣的應用程序,我想編程,但是我目前無法使其工作,並且我不知道爲什麼!

的例子來自Deep Learning, Goodfellow et al 2016第171 - 177

import tensorflow as tf 

T = 1. 
F = 0. 
train_in = [ 
    [T, T], 
    [T, F], 
    [F, T], 
    [F, F], 
] 
train_out = [ 
    [F], 
    [T], 
    [T], 
    [F], 
] 
w1 = tf.Variable(tf.random_normal([2, 2])) 
b1 = tf.Variable(tf.zeros([2])) 

w2 = tf.Variable(tf.random_normal([2, 1])) 
b2 = tf.Variable(tf.zeros([1])) 

out1 = tf.nn.relu(tf.matmul(train_in, w1) + b1) 
out2 = tf.nn.relu(tf.matmul(out1, w2) + b2) 

error = tf.subtract(train_out, out2) 
mse = tf.reduce_mean(tf.square(error)) 

train = tf.train.GradientDescentOptimizer(0.01).minimize(mse) 

sess = tf.Session() 
tf.global_variables_initializer() 

err = 1.0 
target = 0.01 
epoch = 0 
max_epochs = 1000 

while err > target and epoch < max_epochs: 
    epoch += 1 
    err, _ = sess.run([mse, train]) 

print("epoch:", epoch, "mse:", err) 
print("result: ", out2) 

運行的代碼時,我得到Pycharm以下錯誤信息:Screenshot

回答

0

爲了運行初始化運,你應寫:

sess.run(tf.global_variables_initializer()) 

而不是:

tf.global_variables_initializer() 

這裏是一個工作版本:

import tensorflow as tf 

T = 1. 
F = 0. 
train_in = [ 
    [T, T], 
    [T, F], 
    [F, T], 
    [F, F], 
] 
train_out = [ 
    [F], 
    [T], 
    [T], 
    [F], 
] 
w1 = tf.Variable(tf.random_normal([2, 2])) 
b1 = tf.Variable(tf.zeros([2])) 

w2 = tf.Variable(tf.random_normal([2, 1])) 
b2 = tf.Variable(tf.zeros([1])) 

out1 = tf.nn.relu(tf.matmul(train_in, w1) + b1) 
out2 = tf.nn.relu(tf.matmul(out1, w2) + b2) 

error = tf.subtract(train_out, out2) 
mse = tf.reduce_mean(tf.square(error)) 

train = tf.train.GradientDescentOptimizer(0.01).minimize(mse) 

sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 

err = 1.0 
target = 0.01 
epoch = 0 
max_epochs = 1000 

while err > target and epoch < max_epochs: 
    epoch += 1 
    err, _ = sess.run([mse, train]) 

print("epoch:", epoch, "mse:", err) 
print("result: ", out2)