2017-02-28 123 views

I am using TensorFlow version 0.12.1 and following this doc. How can I get this very simple distributed training example to work?

What I want to do is add 1 to count in each worker.

My goal is to print a result greater than 1, but I only ever get 1.

import tensorflow as tf 

FLAGS = tf.app.flags.FLAGS 
tf.app.flags.DEFINE_string('job_name', '', '') 
tf.app.flags.DEFINE_string('ps_hosts', '','') 
tf.app.flags.DEFINE_string('worker_hosts', '','') 
tf.app.flags.DEFINE_integer('task_index', 0, '') 

ps_hosts = FLAGS.ps_hosts.split(',') 
worker_hosts = FLAGS.worker_hosts.split(',') 
cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts,'worker': worker_hosts}) 
server = tf.train.Server(
        {'ps': ps_hosts,'worker': worker_hosts}, 
        job_name=FLAGS.job_name, 
        task_index=FLAGS.task_index) 

if FLAGS.job_name == 'ps': 
    server.join() 

with tf.device(tf.train.replica_device_setter(
       worker_device="/job:worker/task:%d" % FLAGS.task_index, 
       cluster=cluster_spec)): 
    count = tf.Variable(0) 
    count = tf.add(count,tf.constant(1)) 
    init = tf.global_variables_initializer() 

sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                         logdir="./checkpoint/",
                         init_op=init,
                         summary_op=None,
                         saver=None,
                         global_step=None,
                         save_model_secs=60)

with sv.managed_session(server.target) as sess: 
    sess.run(init) 
    step = 1 
    while step <= 999999999:
        result = sess.run(count)
        if step % 10000 == 0:
            print(result)
        if result >= 2:
            print("!!!!!!!!")
        step += 1
    print("Finished!")

sv.stop() 

Answer


The problem is actually independent of distributed execution, and stems from these two lines:

count = tf.Variable(0)
count = tf.add(count,tf.constant(1))

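The tf.add() op is a purely functional op: each time it runs, it creates a new tensor holding its output rather than modifying its input. The second line also rebinds the Python name count to that output tensor, so sess.run(count) always computes 0 + 1 and the variable itself never changes. If you want to increment the value, and have the increment be visible across workers, you must use the tf.Variable.assign_add() method instead, as follows: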

count = tf.Variable(0)
increment_count = count.assign_add(1)

Then call sess.run(increment_count) inside your training loop to increment the value of the count variable.
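To make the difference concrete without needing a cluster, here is a plain-Python analogy. It is not TensorFlow code: the Counter class and its add and assign_add methods are hypothetical stand-ins that only mimic the semantics described above, with a functional add that leaves the stored value alone and a stateful assign_add that updates it.

```python
class Counter(object):
    """Hypothetical stand-in for a variable that holds an integer."""

    def __init__(self, value):
        self.value = value

    def add(self, delta):
        # Functional: compute a fresh result, leave the stored value untouched.
        return self.value + delta

    def assign_add(self, delta):
        # Stateful: update the stored value, then return the new value.
        self.value += delta
        return self.value


def run_functional(steps):
    # Mirrors repeatedly running the output of an add op.
    counter = Counter(0)
    return [counter.add(1) for _ in range(steps)]


def run_stateful(steps):
    # Mirrors repeatedly running an assign_add op.
    counter = Counter(0)
    return [counter.assign_add(1) for _ in range(steps)]


functional_results = run_functional(3)
stateful_results = run_stateful(3)
print(functional_results)  # [1, 1, 1]
print(stateful_results)    # [1, 2, 3]
```

The functional version returns 1 on every step, which matches what the question observes. The stateful version keeps counting because each call changes the shared state that the next call reads.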
