python - How to copy parameters from a global model to a thread-specific model
Context:
I'm new to TensorFlow, and I'm trying to implement some of the algorithms in this paper, which require copying a global shared model into a local, thread-specific model.
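Without TensorFlow, "copying the global model into a thread-specific model" just means giving each worker its own snapshot of the shared parameters, which it can then change without touching the global copy. The sketch below is a minimal stdlib illustration that assumes the parameters are stored as a dict of nested lists; the names global_params and sync_local_from_global are hypothetical and are not part of the paper or of TensorFlow. In TensorFlow 1.x the equivalent step is an assign op such as local_network.w.assign(global_network.w).

```python
import copy

# Hypothetical stand-in for the global shared model's parameters.
global_params = {"w": [[1, 1]]}

def sync_local_from_global(global_params):
    """Return an independent copy of the global parameters for one worker."""
    return copy.deepcopy(global_params)

local_params = sync_local_from_global(global_params)
local_params["w"][0][0] = 5  # a worker-local update...
print(global_params["w"])    # ...does not leak back: prints [[1, 1]]
```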
Question:
What is the best/correct way to accomplish the above task? Below is a dummy example of how I'm doing it, followed by the error I'm getting. Can someone please explain why the error occurs?
    import tensorflow as tf
    import threading

    class ExampleModel(object):
        def __init__(self, graph):
            with graph.as_default():
                self.w = tf.Variable(tf.constant(1, shape=[1,2]))

    sess = tf.InteractiveSession()
    graph = tf.get_default_graph()
    global_network = ExampleModel(graph)
    sess.run(tf.initialize_all_variables())

    def example(i):
        global global_network, graph
        local_network = ExampleModel(graph)
        sess.run(local_network.w.assign(global_network.w))

    threads = []
    for i in range(5):
        t = threading.Thread(target=example, args=(i,))
        threads.append(t)

    for t in threads:
        t.start()
Error:

    Exception in thread Thread-3:
    Traceback (most recent call last):
      File "/users/kennyhsu5/anaconda/lib/python2.7/threading.py", line 801, in __bootstrap_inner
        self.run()
      File "/users/kennyhsu5/anaconda/lib/python2.7/threading.py", line 754, in run
        self.__target(*self.__args, **self.__kwargs)
      File "tmp.py", line 16, in example
        local_network = ExampleModel(graph)
      File "tmp.py", line 7, in __init__
        self.w = tf.Variable(tf.constant(1, shape=[1,2]))
      File "/users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 211, in __init__
        dtype=dtype)
      File "/users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/variables.py", line 319, in _init_from_args
        self._snapshot = array_ops.identity(self._variable, name="read")
      File "/users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2976, in __exit__
        self._graph._pop_control_dependencies_controller(self)
      File "/users/kennyhsu5/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2996, in _pop_control_dependencies_controller
        assert self._control_dependencies_stack[-1] is controller
    AssertionError
From the documentation of the tf.Graph class in TensorFlow:

Important note: This class is not thread-safe for graph construction. All operations should be created from a single thread, or external synchronization must be provided. Unless otherwise specified, all methods are not thread-safe.
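The "external synchronization" option in that note amounts to letting only one thread at a time mutate the graph. The stdlib sketch below illustrates the idea with a threading.Lock around a toy, non-thread-safe builder; ToyGraph, add_variable and graph_lock are hypothetical names. ToyGraph only mimics the push/pop-and-assert pattern visible in the traceback and is not TensorFlow's implementation. In the question's code, the analogous change would be wrapping the ExampleModel(graph) call and the .assign(...) call in with graph_lock:, which is an untested assumption here, not a verified TensorFlow recipe.

```python
import threading

class ToyGraph:
    """Hypothetical stand-in for a builder that is not thread-safe.

    Like the control-dependencies stack in the traceback, it pushes an entry
    while building and asserts that it pops back the entry it pushed.
    """
    def __init__(self):
        self._stack = []
        self.variables = []

    def add_variable(self, name):
        self._stack.append(name)
        self.variables.append(name)  # the actual "graph construction"
        popped = self._stack.pop()
        assert popped == name, "stack corrupted by another thread"

graph = ToyGraph()
graph_lock = threading.Lock()  # external synchronization for construction

def worker(i):
    # Only one thread at a time is allowed to mutate the graph.
    with graph_lock:
        graph.add_variable(f"local_w_{i}")

threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(sorted(graph.variables))
# prints ['local_w_0', 'local_w_1', 'local_w_2', 'local_w_3', 'local_w_4']
```

The lock keeps the building step correct, but it also serializes it, so it buys no parallelism for construction itself; only the later running of operations benefits from the threads.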
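Both the self.w = ... declaration and the local_network.w.assign(...) operation add new operations to the shared graph from inside worker threads, and that is what causes the error. When several threads build at once, they interleave their pushes and pops on the graph's internal control-dependencies stack, so the assertion in _pop_control_dependencies_controller fails.

I know this kills part of your multithreading implementation, but you can fix the above code by moving these declarations into the main thread. You can still use threads to run the operations, as prescribed. For example: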
    import tensorflow as tf
    import threading

    class ExampleModel(object):
        def __init__(self, graph):
            with graph.as_default():
                self.w = tf.Variable(tf.constant(1, shape=[1,2]))

    sess = tf.InteractiveSession()
    graph = tf.get_default_graph()
    global_network = ExampleModel(graph)
    sess.run(tf.global_variables_initializer())

    def example(sess, assign_w):
        # Worker threads only run ops that were built in the main thread.
        sess.run(assign_w)

    threads = []
    for i in range(5):
        # Graph construction happens here, in the main thread.
        local_network = ExampleModel(graph)
        assign_w = local_network.w.assign(global_network.w)
        t = threading.Thread(target=example, args=(sess, assign_w))
        threads.append(t)

    for t in threads:
        t.start()
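The same "build in the main thread, run in the workers" split can be shown without TensorFlow. In this stdlib sketch, which is an analogy and not TensorFlow code, an "assign op" is a hypothetical prebuilt zero-argument callable. The main thread creates every local model and its op, and each worker receives its op through args and only executes it. Like a TensorFlow assign op, the callable reads the global value when it runs, not when it is built.

```python
import threading

# Hypothetical stand-ins: the global weights are a list, each local model is
# a dict, and an "assign op" is a callable built ahead of time.
global_w = [1, 1]
local_models = []
assign_ops = []

def make_assign_op(local_model, source):
    """Build (but do not run) an op that copies `source` into `local_model`."""
    def assign():
        local_model["w"] = list(source)  # reads `source` at run time
    return assign

# All construction happens in the main thread.
for i in range(5):
    local_model = {"w": None}
    local_models.append(local_model)
    assign_ops.append(make_assign_op(local_model, global_w))

def example(assign_op):
    # Workers only run prebuilt ops; they never build new ones.
    assign_op()

threads = [threading.Thread(target=example, args=(op,)) for op in assign_ops]
for t in threads:
    t.start()
for t in threads:
    t.join()

print([m["w"] for m in local_models])
# prints [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]
```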
I advise passing variables to the thread via the args parameter rather than using global.

Finally, consider using global_variables_initializer instead of the deprecated initialize_all_variables.