许多TensorFlow初学者想把自己训练的模型保存,并且还原继续训练或者用作测试。但是TensorFlow官网的介绍太不实用,网上的资料又不确定哪个是正确可行的。
今天David 9 就来带大家手把手入门亲测可用的TensorFlow保存还原模型的正确方式,使用的是网上最多的Saver的save和restore方法, 并且把关键点为大家指出。
今天介绍最为可行直接的方式来自这篇Stackoverflow:https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model 亲测可用:
保存模型:
import tensorflow as tf #Prepare to feed input, i.e. feed_dict and placeholders w1 = tf.placeholder("float", name="w1") w2 = tf.placeholder("float", name="w2") b1= tf.Variable(2.0,name="bias") feed_dict ={w1:4,w2:8} #Define a test operation that we will restore w3 = tf.add(w1,w2) w4 = tf.multiply(w3,b1,name="op_to_restore") sess = tf.Session() sess.run(tf.global_variables_initializer()) #Create a saver object which will save all the variables saver = tf.train.Saver() #Run the operation by feeding input print sess.run(w4,feed_dict) #Prints 24 which is sum of (w1+w2)*b1 #Now, save the graph saver.save(sess, 'my_test_model',global_step=1000)
必须强调的是:这里4,5,6,11行中的name=’w1′, name=’w2′, name=’bias’, name=’op_to_restore’ 千万不能省略,这是恢复还原模型的关键。
还原模型:
import tensorflow as tf sess=tf.Session() #First let's load meta graph and restore weights saver = tf.train.import_meta_graph('my_test_model-1000.meta') saver.restore(sess,tf.train.latest_checkpoint('./')) # Access saved Variables directly print(sess.run('bias:0')) # This will print 2, which is the value of bias that we saved # Now, let's access and create placeholders variables and # create feed-dict to feed new data graph = tf.get_default_graph() w1 = graph.get_tensor_by_name("w1:0") w2 = graph.get_tensor_by_name("w2:0") feed_dict ={w1:13.0,w2:17.0} #Now, access the op that you want to run. op_to_restore = graph.get_tensor_by_name("op_to_restore:0") print sess.run(op_to_restore,feed_dict) #This will print 60 which is calculated
还原当然是用restore方法,这里的18,19,23行就是刚才的name关键字指定的Tensor变量,必须找对才能进行还原恢复。
其他的关键在代码和注释中可以一眼看出, 这里不加赘述了。
参考文献:
- https://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model
- https://nathanbrixius.wordpress.com/2016/05/24/checkpointing-and-reusing-tensorflow-models/
- https://stackoverflow.com/questions/42685994/how-to-get-a-tensorflow-op-by-name
- http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
- https://www.tensorflow.org/versions/r0.11/api_docs/python/state_ops/saving_and_restoring_variables
本文章属于“David 9的博客”原创,如需转载,请联系微信: david9ml,或邮箱:)yanchao727@gmail.com
或直接扫二维码:
David 9
Latest posts by David 9 (see all)
- 修订特征已经变得切实可行, “特征矫正工程”是否会成为潮流? - 27 3 月, 2024
- 量子计算系列#2 : 量子机器学习与量子深度学习补充资料,QML,QeML,QaML - 29 2 月, 2024
- “现象意识”#2:用白盒的视角研究意识和大脑,会是什么景象?微意识,主体感,超心智,意识中层理论 - 16 2 月, 2024