python - tensorflow.train.import_meta_graph does not work?


I am trying to save and restore a graph, but the simplest example does not work as expected (this was done using version 0.9.0 or 0.10.0 on Linux 64 without CUDA, using Python 2.7 or 3.5.2).

First I save the graph like this:

import tensorflow as tf

v1 = tf.placeholder('float32')
v2 = tf.placeholder('float32')
v3 = tf.mul(v1, v2)
c1 = tf.constant(22.0)
v4 = tf.add(v3, c1)

sess = tf.Session()
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})

g1 = tf.train.export_meta_graph("file")
## alternately tried:
## g1 = tf.train.export_meta_graph("file", collection_list=["v4"])

This creates a non-empty file "file" and sets g1 to what looks like a proper graph definition.

Then I try to restore the graph:

import tensorflow as tf

g = tf.train.import_meta_graph("file")

This works without error, but it does not return anything at all.

Can anyone provide the necessary code to save the graph for "v4" and restore it, so that running it in a new session produces the same result?

To reuse a MetaGraphDef, you need to record the names of the interesting tensors in the original graph. For example, in the first program, set an explicit name argument in the definitions of v1, v2, and v4:

v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
# ...
v4 = tf.add(v3, c1, name="v4")

Then you can use the string names of the tensors in the original graph in your call to sess.run(). For example, the following snippet should work:

import tensorflow as tf

_ = tf.train.import_meta_graph("./file")

sess = tf.Session()
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
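The strings fed and fetched above follow TensorFlow's tensor-naming convention: "<op_name>:<output_index>", where the index selects one of the operation's outputs (":0" for the first). As a minimal, TensorFlow-free sketch of that convention (the helper split_tensor_name is hypothetical, purely for illustration):

```python
def split_tensor_name(tensor_name):
    """Split a tensor name like "v4:0" into (op_name, output_index).

    A name without an explicit index (e.g. "v4") defaults to output 0,
    matching how TensorFlow resolves bare operation names.
    """
    if ":" in tensor_name:
        op_name, index = tensor_name.rsplit(":", 1)
        return op_name, int(index)
    return tensor_name, 0

print(split_tensor_name("v4:0"))        # -> ('v4', 0)
print(split_tensor_name("scope/v1:1"))  # -> ('scope/v1', 1)
```

This is why naming the ops "v1", "v2", and "v4" in the first program makes the tensors "v1:0", "v2:0", and "v4:0" addressable after import.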

Alternatively, you can use tf.get_default_graph().get_tensor_by_name() to get tf.Tensor objects for the tensors of interest, which you can then pass to sess.run():

import tensorflow as tf

_ = tf.train.import_meta_graph("./file")
g = tf.get_default_graph()

v1 = g.get_tensor_by_name("v1:0")
v2 = g.get_tensor_by_name("v2:0")
v4 = g.get_tensor_by_name("v4:0")

sess = tf.Session()
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})

UPDATE: Based on the discussion in the comments, here is a complete example of saving and loading, including saving the variable contents. It illustrates the saving of a variable by doubling the value of the variable vx in a separate operation.

Saving:

import tensorflow as tf

v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v3 = tf.mul(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")

saver = tf.train.Saver([vx])
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(vx.assign(tf.add(vx, vx)))
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
print(result)
saver.save(sess, "./model_ex1")

Restoring:

import tensorflow as tf

saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
print(result)

The bottom line is that, in order to make use of a saved model, you must remember the names of at least some of its nodes (e.g. a training op, an input placeholder, an evaluation tensor, etc.). The MetaGraphDef stores the list of variables contained in the model, and helps to restore these from a checkpoint, but you are required to reconstruct the tensors/operations used in training/evaluating the model yourself.
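If you have forgotten the node names, one way to rediscover them (a sketch of my own, not part of the original answer) is to import the MetaGraphDef into a fresh graph and list its operations. Note this is the graph-mode API shown throughout this post; on TensorFlow 2.x it lives under tf.compat.v1, which the shim below accounts for:

```python
import tensorflow as tf

# On TF 2.x the graph-mode API used in this post lives under tf.compat.v1.
if hasattr(tf, "compat") and hasattr(tf.compat, "v1"):
    tf = tf.compat.v1
    tf.disable_eager_execution()

# Build and export a tiny graph (a stand-in for the "file" saved above).
with tf.Graph().as_default():
    v1 = tf.placeholder(tf.float32, name="v1")
    v4 = tf.add(v1, tf.constant(22.0), name="v4")
    tf.train.export_meta_graph("./inspect_example.meta")

# Import it into a fresh graph and list the operation names.
with tf.Graph().as_default() as g:
    tf.train.import_meta_graph("./inspect_example.meta")
    names = [op.name for op in g.get_operations()]
    print(names)
```

The printed list includes 'v1' and 'v4', so the tensors "v1:0" and "v4:0" can then be fed and fetched by name as shown earlier.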

