[TensorFlow深度学习入门]实战八·简便方法实现TensorFlow模型参数保存与加载(pb方式)

[TensorFlow深度学习入门]实战八·简便方法实现TensorFlow模型参数保存与加载(pb方式)

在上篇博文中,我们探索了TensorFlow模型参数保存与加载实现方法采用的是保存ckpt的方式。这篇博文我们会使用保存为pd格式文件来实现。
首先,我会在上篇博文基础上,实现由ckpt文件如何转换为pb文件,再去探索如何在训练时直接保存pb文件,最后是如何利用pb文件复现网络与参数完成应用预测功能。

  • ckpt文件转换pd文件

ckpt2pd文件代码:

import tensorflow as tf
pd_dir = "././Saver/test1/pb_dir/MyModel.pb"
with tf.Session() as sess:    
    #加载运算图
    saver = tf.train.import_meta_graph('./Saver/test1/checkpoint_dir/MyModel.meta')
    #加载参数
    saver.restore(sess,tf.train.latest_checkpoint('./Saver/test1/checkpoint_dir'))
    graph = tf.get_default_graph()
    out_graph = tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,["in","out"])
    saver_path = tf.train.write_graph(out_graph,".",pd_dir,as_text=False)
    print("saver path: ",saver_path)

运行结果:

saver path:  ././Saver/test1/pb_dir/MyModel.pb
  • 训练保存pd文件

train文件代码

import tensorflow as tf

pd_dir = "././Saver/test2/pb_dir/MyModel.pb"



def main():
    x = tf.placeholder(dtype=tf.float32,shape=[None,2],name="in")
    #x = tf.constant([[1,2]],dtype=tf.float32)
    w1 = tf.get_variable("w1",dtype=tf.float32,initializer=tf.truncated_normal([2, 1], stddev=0.1))
    b1 = tf.get_variable("b1",initializer=tf.constant(.1, dtype=tf.float32, shape=[1, 1])) 

    y = tf.add(tf.matmul(x,w1),b1,name="out")
    
    with tf.Session() as sess:
        #获取计算图
        graph = tf.get_default_graph()
        #获取name和ops,这次代码并没有用到
        ret = graph.get_operations()
        r_names = []
        #获取name list
        for r in ret:
            r_names.append(r.name)

        srun = sess.run
        srun(tf.global_variables_initializer())
        print("y: ",srun(y,{x:[[1,2]]}))
        #存入输入与输出接口
        out_graph = tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,["in","out"])
        saver_path = tf.train.write_graph(out_graph,".",pd_dir,as_text=False)

        
        print("saver path: ",saver_path)

if __name__ == "__main__":
    main()

运行结果:

y:  [[0.14729613]]
saver path:  ./././Saver/test2/pb_dir/MyModel.pb
  • pb文件复现网络与参数

restore文件代码

import tensorflow as tf
from saver1 import pd_dir

with tf.Session() as sess:
    #用上下文管理器打开pd文件    
    with open(pd_dir,"rb") as pd_flie:
        #获取图
        graph = tf.GraphDef()
        #获取参数
        graph.ParseFromString(pd_flie.read())
        #引入输入输出接口
        ins, outs = tf.import_graph_def(graph,return_elements=["in:0","out:0"])
        #进行预测
        print("y: ",sess.run(outs,{ins:[[1,2]]}))

运行结果:

y:  [[0.14729613]]
发布了168 篇原创文章 · 获赞 398 · 访问量 19万+
展开阅读全文

在Golang上进行Tensorflow部分运行(RNN状态)

05-21

<div class="post-text" itemprop="text"> <p>I have a GRU RNN text generation model that I imported as protobuf in Golang.</p> <pre><code>model, err := tf.LoadSavedModel("poetryModel", []string{"goTag"}, nil) </code></pre> <p>Similar to <a href="https://www.tensorflow.org/tutorials/sequences/text_generation#the_prediction_loop" rel="nofollow noreferrer">the code from this Tensorflow tutorial</a>, I am running a prediction loop:</p> <pre><code>for len(generated_text) < 1000 { result, err := model.Session.Run( map[tf.Output]*tf.Tensor{ model.Graph.Operation("inputLayer_input").Output(0): tensor, }, []tf.Output{ model.Graph.Operation("outputLayer/add").Output(0), }, nil, ) ...} </code></pre> <p>However, this implementation discards all intermediate states after every loop which results in bad generated text. I tried using <a href="https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go#PartialRun" rel="nofollow noreferrer">Partial Run</a>, but it threw an error at the second Run:</p> <pre><code>pr, err := model.Session.NewPartialRun( []tf.Output{ model.Graph.Operation("inputLayer_input").Output(0), }, []tf.Output{ model.Graph.Operation("outputLayer/add").Output(0), }, []*tf.Operation{ model.Graph.Operation("outputLayer/add") }, ) if err != nil { panic(err) } ... result, err := pr.Run( map[tf.Output]*tf.Tensor{ model.Graph.Operation("inputLayer_input").Output(0): tensor, }, []tf.Output{ model.Graph.Operation("outputLayer/add").Output(0), }, nil, ) </code></pre> <p><code>Error running the session with input, err: Must run 'setup' before performing partial runs!</code></p> <p><a href="https://stackoverflow.com/questions/45142977/tensorflow-partial-run-must-run-setup-before-performing-partial-runs-desp">This question</a> is similar to this one, but in Python. Also, there is no documentation of a setup function in Go. I am new to working directly with the TF computation graph and Golang, so any help is appreciated.</p> </div> 问答

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: Age of Ai 设计师: meimeiellie

分享到微信朋友圈

×

扫一扫,手机浏览