python - How to copy parameters from a global model to a thread-specific model


Context:

I'm new to TensorFlow, and I'm trying to implement some of the algorithms in this paper, which require copying a global shared model to a local thread-specific model.

Question:

What is the best/correct way to accomplish the above task? I've provided a dummy example of the way I'm doing it below, along with the error I'm getting. Can someone please explain why the error occurs?

```python
import tensorflow as tf
import threading

class ExampleModel(object):
  def __init__(self, graph):
    with graph.as_default():
      self.w = tf.Variable(tf.constant(1, shape=[1,2]))

sess = tf.InteractiveSession()
graph = tf.get_default_graph()
global_network = ExampleModel(graph)
sess.run(tf.initialize_all_variables())

def example(i):
  global global_network, graph
  local_network = ExampleModel(graph)
  sess.run(local_network.w.assign(global_network.w))

threads = []
for i in range(5):
  t = threading.Thread(target=example, args=(i,))
  threads.append(t)

for t in threads:
  t.start()
```

Error:

```
Exception in thread Thread-3:
Traceback (most recent call last):
  File "/users/kennyhsu5/anaconda/lib/python2.7/threading.py", line 801, in __bootstrap_inner
    self.run()
  File "/users/kennyhsu5/anaconda/lib/python2.7/threading.py", line 754, in run
    self.__target(*self.__args, **self.__kwargs)
  File "tmp.py", line 16, in example
    local_network = ExampleModel(graph)
  File "tmp.py", line 7, in __init__
    self.w = tf.Variable(tf.constant(1, shape=[1,2]))
  File "/users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 211, in __init__
    dtype=dtype)
  File "/users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 319, in _init_from_args
    self._snapshot = array_ops.identity(self._variable, name="read")
  File "/users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2976, in __exit__
    self._graph._pop_control_dependencies_controller(self)
  File "/users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2996, in _pop_control_dependencies_controller
    assert self._control_dependencies_stack[-1] is controller
AssertionError
```

Answer:

From the documentation of the tf.Graph class in TensorFlow:

Important note: This class is not thread-safe for graph construction. All operations should be created from a single thread, or external synchronization must be provided. Unless otherwise specified, all methods are not thread-safe.
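The note above permits a second option besides single-threaded construction: external synchronization. Below is a minimal, framework-free sketch of that idea. `ToyGraph` is a hypothetical stand-in for a non-thread-safe graph (it is not TensorFlow's API); every worker takes one shared `threading.Lock` before adding ops, so only one thread mutates the graph at a time. Applying this to TensorFlow would mean holding the lock around both the `tf.Variable` creation and the `.assign(...)` call; that is an assumption drawn from the documentation note, not something tested here.

```python
import threading


class ToyGraph:
    """Hypothetical stand-in for a non-thread-safe graph (not TensorFlow's API).

    Each op pushes a marker onto a shared stack and asserts that the marker
    is still on top before popping it, similar in spirit to the check that
    failed in the traceback.
    """

    def __init__(self):
        self._stack = []
        self.ops = []

    def add_op(self, name):
        marker = object()
        self._stack.append(marker)
        self.ops.append(name)
        # Without synchronization, another thread could push between the
        # append above and this check, breaking the invariant.
        assert self._stack[-1] is marker
        self._stack.pop()


graph = ToyGraph()
graph_lock = threading.Lock()  # the "external synchronization"


def build_local_model(i):
    # Only one thread at a time may add ops to the shared graph.
    with graph_lock:
        graph.add_op(f"local_{i}/w")
        graph.add_op(f"local_{i}/assign")


threads = [threading.Thread(target=build_local_model, args=(i,)) for i in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(len(graph.ops))  # 10
```

The lock serializes graph construction, so it removes the race but also removes any parallelism in that phase; only the work done outside the lock runs concurrently.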

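The self.w = ... declaration and the local_network.w.assign(...) operation are what cause the error: both add new operations to the shared graph, and in your code they run in several threads at once. The traceback shows the likely symptom: while one thread is still building its variable, another thread pushes onto the same graph-wide control-dependencies stack, so when the first thread finishes it finds another thread's controller on top and the assertion in _pop_control_dependencies_controller fails.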
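To see how concurrent construction can trip the assertion at the end of the traceback, here is a deterministic, single-threaded replay of one bad interleaving. `ControllerStack` is a hypothetical model of a graph's control-dependencies stack written for illustration; it is not TensorFlow code, and the real implementation differs in detail.

```python
class ControllerStack:
    """Hypothetical model of a graph-wide controller stack (not TensorFlow code)."""

    def __init__(self):
        self._stack = []

    def enter(self, controller):
        self._stack.append(controller)

    def exit(self, controller):
        # Same kind of check as `assert stack[-1] is controller` in the traceback.
        if self._stack[-1] is not controller:
            raise AssertionError("top of stack belongs to another controller")
        self._stack.pop()


stack = ControllerStack()
a, b = object(), object()  # controllers opened by "thread A" and "thread B"

# One possible interleaving when two threads build ops in the same graph:
stack.enter(a)  # thread A starts building its variable
stack.enter(b)  # thread B starts building its variable
try:
    stack.exit(a)  # thread A finishes first, but B's controller is on top
    failed = False
except AssertionError:
    failed = True

print(failed)  # True
```

Because the stack is shared by every thread using the graph, the "last in, first out" assumption only holds when a single thread at a time is constructing ops.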

I know this partly defeats the purpose of multithreading, but you can fix the above code by moving these declarations to the main thread. The threads are then used only to run the operations, as the note above prescribes. For example:

```python
import tensorflow as tf
import threading

class ExampleModel(object):
  def __init__(self, graph):
    with graph.as_default():
      self.w = tf.Variable(tf.constant(1, shape=[1,2]))

sess = tf.InteractiveSession()
graph = tf.get_default_graph()
global_network = ExampleModel(graph)
sess.run(tf.global_variables_initializer())

def example(sess, assign_w):
  # Worker threads only run ops that already exist; they never add to the graph.
  sess.run(assign_w)

threads = []
for i in range(5):
  # Graph construction (new variable + assign op) happens in the main thread.
  local_network = ExampleModel(graph)
  assign_w = local_network.w.assign(global_network.w)
  t = threading.Thread(target=example, args=(sess, assign_w))
  threads.append(t)

for t in threads:
  t.start()
for t in threads:
  t.join()  # wait until every local copy has been made
```
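For readers without a TensorFlow 1.x setup, the same split can be sketched in plain Python. `Model`, `make_assign_op`, and `example` are hypothetical stand-ins (a list of floats plays the role of `w`, and a closure plays the role of the `assign` op); the point is only the structure: the main thread builds every assign operation up front, and each worker, which receives its op through `args` rather than a global, just executes it.

```python
import threading


class Model:
    """Hypothetical framework-free model holding a flat list of weights."""

    def __init__(self, weights):
        self.w = list(weights)


def make_assign_op(local, global_model):
    # Built in the main thread, analogous to local_network.w.assign(global_network.w).
    def assign():
        local.w[:] = global_model.w  # copy the values, do not alias the list
    return assign


def example(assign_op):
    # The worker only runs an op it was handed; it creates nothing shared.
    assign_op()


global_network = Model([1.0, 1.0])
local_networks = []
threads = []
for i in range(5):
    local = Model([0.0, 0.0])
    local_networks.append(local)
    op = make_assign_op(local, global_network)
    threads.append(threading.Thread(target=example, args=(op,)))

for t in threads:
    t.start()
for t in threads:
    t.join()

print([m.w for m in local_networks])
# [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
```

Copying with `local.w[:] = ...` keeps each local model's values independent of the global one, much as `assign` copies a value into a variable: later changes to the global weights are not visible locally until the op runs again.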

I'd advise passing variables to the thread via the args parameter rather than using global.

Finally, consider using tf.global_variables_initializer instead of the deprecated tf.initialize_all_variables.

