TensorFlow notes: tf.get_collection(), tf.identity()
Source: Internet. Editor: 程序博客网. Date: 2024/05/21 17:53
0. Exporting a neural network to a .pb file, with a test program
The key code is shown below; see the link for details.
# Export a .pb file from a trained checkpoint (TensorFlow 1.x API)
import fully_conected as model
import tensorflow as tf

def export_graph(model_name):
    graph = tf.Graph()
    with graph.as_default():
        input_image = tf.placeholder(tf.float32, shape=[None, 28*28], name='inputdata')
        # The network has to be rebuilt here so the checkpoint can be restored into it
        logits = model.inference(input_image)
        y_conv = tf.nn.softmax(logits, name='outputdata')
        restore_saver = tf.train.Saver()
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        latest_ckpt = tf.train.latest_checkpoint('log')
        restore_saver.restore(sess, latest_ckpt)
        # Freeze the graph: replace variables with constants holding their restored values
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ['outputdata'])
        # tf.train.write_graph(output_graph_def, 'log', model_name, as_text=False)
        # Write to log/<model_name> (the original hardcoded 'log/mnist.pb')
        with tf.gfile.GFile('log/' + model_name, "wb") as f:
            f.write(output_graph_def.SerializeToString())

export_graph('mnist.pb')
# Test: load the saved .pb file and run inference
from __future__ import absolute_import, unicode_literals
from datasets_mnist import read_data_sets
import tensorflow as tf

train, validation, test = read_data_sets("datasets/", one_hot=True)

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    output_graph_path = 'log/mnist.pb'
    # sess.graph.add_to_collection("input", mnist.test.images)
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
        # tf.initialize_all_variables() is deprecated; a frozen graph has no variables anyway
        tf.global_variables_initializer().run()
        input_x = sess.graph.get_tensor_by_name("inputdata:0")
        output = sess.graph.get_tensor_by_name("outputdata:0")
        y_conv_2 = sess.run(output, {input_x: test.images})
        print("y_conv_2", y_conv_2)
        # Test trained model
        # y__2 = tf.placeholder("float", [None, 10])
        y__2 = test.labels
        correct_prediction_2 = tf.equal(tf.argmax(y_conv_2, 1), tf.argmax(y__2, 1))
        print("correct_prediction_2", correct_prediction_2)
        accuracy_2 = tf.reduce_mean(tf.cast(correct_prediction_2, "float"))
        print("accuracy_2", accuracy_2)
        print("check accuracy %g" % accuracy_2.eval())
1. Getting trainable variables with tf.get_collection
# train_vars = tf.trainable_variables()
# g_vars = [var for var in train_vars if var.name.startswith('generator')]
# d_vars = [var for var in train_vars if var.name.startswith('discriminator')]
g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
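According to the TensorFlow documentation, the scope argument of tf.get_collection is applied to each item's name with re.match, so a plain string such as 'generator' behaves as a prefix filter. The following is a minimal pure-Python sketch of that filtering rule, not TensorFlow's own code; FakeVar, filter_by_scope and the variable names are hypothetical stand-ins chosen for illustration.

```python
import re
from collections import namedtuple

# Hypothetical stand-in for a TensorFlow variable: only `name` matters here.
FakeVar = namedtuple("FakeVar", ["name"])


def filter_by_scope(items, scope=None):
    """Keep items whose name matches `scope` via re.match (a prefix-style match)."""
    if scope is None:
        return list(items)
    pattern = re.compile(scope)
    return [item for item in items
            if hasattr(item, "name") and pattern.match(item.name)]


# Assumed example names; real names depend on how the model was built.
collection = [
    FakeVar("generator/dense/kernel:0"),
    FakeVar("generator/dense/bias:0"),
    FakeVar("discriminator/dense/kernel:0"),
    FakeVar("generator_aux/w:0"),
]

g_vars = filter_by_scope(collection, scope="generator")
d_vars = filter_by_scope(collection, scope="discriminator")
g_only = filter_by_scope(collection, scope="generator/")

print([v.name for v in g_vars])  # includes generator_aux/w:0 as well
print([v.name for v in d_vars])
print([v.name for v in g_only])  # trailing slash excludes generator_aux
```

One consequence of the prefix match: a scope named 'generator_aux' would also be picked up by scope='generator', just as with the startswith version in the commented-out lines. Passing 'generator/' with the trailing slash avoids that.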
2. tf.identity()
import tensorflow as tf

x = tf.Variable(1.0)
x_plus_1 = tf.assign_add(x, 1)

with tf.control_dependencies([x_plus_1]):
    y = x
    z = tf.identity(x, name='x')

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for i in range(5):
        print(sess.run(z))
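Output: 2.0, 3.0, 4.0, 5.0, 6.0. tf.identity creates a new op inside the control_dependencies block, so every fetch of z first runs x_plus_1.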
import tensorflow as tf

x = tf.Variable(1.0)
x_plus_1 = tf.assign_add(x, 1)

with tf.control_dependencies([x_plus_1]):
    y = x
    z = tf.identity(x, name='x')

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for i in range(5):
        print(sess.run(y))
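Output: 1.0, 1.0, 1.0, 1.0, 1.0. The statement y = x only binds a second Python name to the same variable. No new op is created inside the control_dependencies block, so fetching y never triggers x_plus_1.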
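The difference between the two runs comes down to which objects are new graph nodes. As a rough illustration, here is a toy pure-Python model of a graph with control inputs. It is a simplified sketch, not how TensorFlow is implemented; Node, run and state are hypothetical names invented for this example.

```python
class Node:
    """A toy graph node: a compute function plus control inputs that run first."""

    def __init__(self, fn, control_inputs=()):
        self.fn = fn
        self.control_inputs = list(control_inputs)


def run(node):
    """Run every control input first, then evaluate the node itself."""
    for dep in node.control_inputs:
        run(dep)
    return node.fn()


state = {"x": 1.0}  # plays the role of the variable's storage


def _assign_add():
    state["x"] += 1.0
    return state["x"]


read_x = Node(lambda: state["x"])  # fetching the variable itself
x_plus_1 = Node(_assign_add)       # analogous to tf.assign_add(x, 1)

# Inside a "control dependencies" block, only newly created nodes get the dependency.
y = read_x                                                # Python alias, no new node
z = Node(lambda: state["x"], control_inputs=[x_plus_1])  # new node, like tf.identity

y_values = [run(y) for _ in range(5)]
state["x"] = 1.0  # reset between the two experiments
z_values = [run(z) for _ in range(5)]

print(y_values)  # [1.0, 1.0, 1.0, 1.0, 1.0]
print(z_values)  # [2.0, 3.0, 4.0, 5.0, 6.0]
```

In this model y is the very same object as read_x, which has no control inputs, while z is a separate node that forces the increment before it reads the value.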