Python - tensorflow.train.import_meta_graph does not work?
I am trying to save and restore a graph, but the simplest example does not work as expected (this is done using version 0.9.0 or 0.10.0 on Linux 64 without CUDA, using Python 2.7 or 3.5.2).
First I save the graph like this:
```python
import tensorflow as tf

v1 = tf.placeholder('float32')
v2 = tf.placeholder('float32')
v3 = tf.mul(v1, v2)
c1 = tf.constant(22.0)
v4 = tf.add(v3, c1)
sess = tf.Session()
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
g1 = tf.train.export_meta_graph("file")
## alternately tried:
## g1 = tf.train.export_meta_graph("file", collection_list=["v4"])
```
This creates a file "file" that is non-empty, and also sets g1 to something that looks like a proper graph definition.
Then I try to restore this graph:
```python
import tensorflow as tf

g = tf.train.import_meta_graph("file")
```
This works without error, but does not return anything at all.
Can anyone provide the necessary code to save the graph for "v4" and restore it so that running it in a new session produces the same result?
To reuse a MetaGraphDef, you will need to record the names of the interesting tensors in your original graph. For example, in the first program, set an explicit name argument in the definition of v1, v2 and v4:
```python
v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
# ...
v4 = tf.add(v3, c1, name="v4")
```
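Without a name argument, TensorFlow derives an op's name from its type (Placeholder, Mul, Const, Add, ...) and appends _1, _2, ... to keep names unique within the graph. In the question's first program the tensors were therefore only reachable under names like "Placeholder_1:0" or "Add:0", which are easy to get wrong. The toy sketch below mimics that uniquifying scheme in plain Python; UniqueNamer and default_tensor_names are hypothetical helpers, not TensorFlow's internal code, and they ignore edge cases such as explicit names that collide with generated ones.

```python
class UniqueNamer:
    """Toy stand-in for graph-level name uniquifying (hypothetical, not TF code)."""

    def __init__(self):
        self._counts = {}  # base name -> how many times it has been handed out

    def unique_name(self, base):
        count = self._counts.get(base, 0)
        self._counts[base] = count + 1
        # The first use keeps the bare name; later uses get a _N suffix.
        return base if count == 0 else "%s_%d" % (base, count)


def default_tensor_names(op_types):
    """Names of output 0 for ops created in order without a name argument."""
    namer = UniqueNamer()
    return [namer.unique_name(op_type) + ":0" for op_type in op_types]


# Op types created by the question's first program, in creation order.
print(default_tensor_names(["Placeholder", "Placeholder", "Mul", "Const", "Add"]))
# ['Placeholder:0', 'Placeholder_1:0', 'Mul:0', 'Const:0', 'Add:0']
```

Passing name="v4" sidesteps the guessing entirely, which is why the answer recommends it.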
Then, you can use the string names of the tensors in the original graph in your call to sess.run(). For example, the following snippet should work:
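```python
import tensorflow as tf

# Adds the saved graph to the default graph. The return value (a Saver, or None
# when the graph has no variables, as here) is not needed in this example.
_ = tf.train.import_meta_graph("./file")
sess = tf.Session()
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
```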
Alternatively, you can use tf.get_default_graph().get_tensor_by_name() to get tf.Tensor objects for the tensors of interest, which you can then pass to sess.run():
```python
import tensorflow as tf

_ = tf.train.import_meta_graph("./file")
g = tf.get_default_graph()
# Look up the tensors by "<op_name>:<output_index>".
v1 = g.get_tensor_by_name("v1:0")
v2 = g.get_tensor_by_name("v2:0")
v4 = g.get_tensor_by_name("v4:0")
sess = tf.Session()
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
```
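Both snippets use strings of the form <op_name>:<output_index>: "v4:0" is output 0 of the operation named v4. That is why the fetch and feed strings carry a ":0" suffix while the name argument used when building the graph does not; a bare "v4" names the operation rather than its output tensor. The pure-Python helpers below, split_tensor_name and tensor_name, are hypothetical and only make the convention explicit; they are not part of TensorFlow.

```python
def split_tensor_name(name):
    """Split "<op_name>:<output_index>" into (op_name, index). Hypothetical helper."""
    op_name, sep, index = name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError("expected '<op_name>:<output_index>', got %r" % name)
    return op_name, int(index)


def tensor_name(op_name, output_index=0):
    """Build the string you would pass to sess.run() or get_tensor_by_name()."""
    return "%s:%d" % (op_name, output_index)


print(split_tensor_name("v4:0"))        # ('v4', 0)
print(split_tensor_name("scope/v4:1"))  # ('scope/v4', 1)
print(tensor_name("v1"))                # v1:0
```

Using rpartition splits on the last colon, so op names that contain name-scope slashes are kept intact.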
UPDATE: Based on the discussion in the comments, here is a complete example of saving and loading, including saving the variable contents. It illustrates saving a variable by doubling the value of the variable vx in a separate operation.
Saving:
```python
import tensorflow as tf

v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v3 = tf.mul(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx])
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(vx.assign(tf.add(vx, vx)))  # double vx before saving
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
print(result)
# Writes the checkpoint and, by default, the MetaGraphDef to ./model_ex1.meta
saver.save(sess, "./model_ex1")
```
Restoring:
```python
import tensorflow as tf

# Rebuilds the graph and returns a Saver for the variables it contains.
saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
print(result)
```
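Both programs should print the same value, because the checkpoint stores vx after it was doubled. As a quick sanity check of what to expect: vx starts at 10.0, the assign op doubles it to 20.0, so v4 = 12.0 * 3.3 + 20.0, roughly 59.6. TensorFlow computes in float32 here, so the printed digits may differ slightly. The plain-Python arithmetic below only reproduces the expected value; it does not call TensorFlow.

```python
import math

vx = 10.0
vx = vx + vx              # the vx.assign(tf.add(vx, vx)) step: 10.0 -> 20.0
v3 = 12.0 * 3.3           # v3 = v1 * v2 with the fed values
v4 = v3 + vx              # v4 = v3 + vx, the value restored from the checkpoint
assert math.isclose(v4, 59.6)
print(round(v4, 4))       # 59.6
```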
The bottom line is that, in order to make use of a saved model, you must remember the names of at least some of its nodes (e.g. a training op, an input placeholder, an evaluation tensor, etc.). The MetaGraphDef stores the list of variables contained in the model and helps to restore these from a checkpoint, but you are required to reconstruct the tensors/operations used in training/evaluating the model yourself.
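To make that division of labour concrete, here is a toy, plain-Python analogy, not TensorFlow's file formats: the "meta graph" is a JSON description of named nodes, the "checkpoint" is a separate JSON dict of variable values, and after loading both you can still only evaluate what you can name. All functions and file suffixes here (save_model, load_model, run, .meta.json, .ckpt.json) are hypothetical and chosen for illustration only.

```python
import json
import os
import tempfile

# Toy graph structure: node name -> [op, input node names]. Mirrors the answer's example.
GRAPH = {
    "v1": ["placeholder", []],
    "v2": ["placeholder", []],
    "Mul": ["mul", ["v1", "v2"]],
    "vx": ["variable", []],
    "v4": ["add", ["Mul", "vx"]],
}


def save_model(prefix, graph, variables):
    # Structure and values go to separate files, loosely like .meta vs. checkpoint data.
    with open(prefix + ".meta.json", "w") as f:
        json.dump(graph, f)
    with open(prefix + ".ckpt.json", "w") as f:
        json.dump(variables, f)


def load_model(prefix):
    with open(prefix + ".meta.json") as f:
        graph = json.load(f)
    with open(prefix + ".ckpt.json") as f:
        variables = json.load(f)
    return graph, variables


def run(graph, variables, fetch, feed):
    """Evaluate the node named `fetch`; placeholders must be fed by name."""
    op, inputs = graph[fetch]
    if op == "placeholder":
        return feed[fetch]
    if op == "variable":
        return variables[fetch]
    args = [run(graph, variables, name, feed) for name in inputs]
    if op == "mul":
        return args[0] * args[1]
    if op == "add":
        return args[0] + args[1]
    raise ValueError("unknown op %r" % op)


prefix = os.path.join(tempfile.mkdtemp(), "model_ex1")
save_model(prefix, GRAPH, {"vx": 20.0})
graph, variables = load_model(prefix)
# Without knowing the names "v4", "v1" and "v2" there would be nothing to fetch or feed.
result = run(graph, variables, "v4", {"v1": 12.0, "v2": 3.3})
print(result)
```

Restoring the values alone (variables) would not tell you which node to fetch; the node names are the piece you have to carry over yourself.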