tensorflow学习——tf.get_collection(), tf.identity()

来源:互联网 发布:怎样申请做淘宝模特 编辑:程序博客网 时间:2024/05/21 17:53

0、将神经网络生成pb文件,测试程序
以下是程序的关键代码,详细内容见链接。

# Export a frozen .pb graph from a trained checkpoint.
import os

import fully_conected as model
import tensorflow as tf


def export_graph(model_name):
    """Freeze the latest checkpoint found in ``log`` and save it as a .pb file.

    The inference graph is rebuilt from scratch, the newest checkpoint in the
    ``log`` directory is restored into it, every variable is converted to a
    constant, and the resulting GraphDef is serialized to ``log/<model_name>``.

    Args:
        model_name: File name of the exported graph inside ``log``
            (for example ``'mnist.pb'``).
    """
    graph = tf.Graph()
    with graph.as_default():
        input_image = tf.placeholder(
            tf.float32, shape=[None, 28 * 28], name='inputdata')
        # The network has to be rebuilt here so that the checkpoint variables
        # have matching graph nodes to be restored into.
        logits = model.inference(input_image)
        # Only the node name matters: 'outputdata' is the node that is kept
        # when the graph is frozen and is looked up by name when it is loaded.
        tf.nn.softmax(logits, name='outputdata')
        restore_saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        latest_ckpt = tf.train.latest_checkpoint('log')
        restore_saver.restore(sess, latest_ckpt)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ['outputdata'])
        # Honor the requested file name instead of a hard-coded one.
        with tf.gfile.GFile(os.path.join('log', model_name), "wb") as f:
            f.write(output_graph_def.SerializeToString())


export_graph('mnist.pb')
# Load the saved .pb file and evaluate it on the MNIST test set.
from __future__ import absolute_import, unicode_literals

from datasets_mnist import read_data_sets
import tensorflow as tf

train, validation, test = read_data_sets("datasets/", one_hot=True)

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    output_graph_path = 'log/mnist.pb'

    # Parse the serialized GraphDef, then import it with an empty name prefix
    # so that the tensors keep their original names.
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
    tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:
        # tf.initialize_all_variables() is deprecated; use the same
        # initializer as the other snippets.
        tf.global_variables_initializer().run()

        # Look the input/output tensors up by the names given at export time.
        input_x = sess.graph.get_tensor_by_name("inputdata:0")
        output = sess.graph.get_tensor_by_name("outputdata:0")

        y_conv_2 = sess.run(output, {input_x: test.images})
        print( "y_conv_2", y_conv_2)

        # Test trained model: compare the predicted class with the label.
        y__2 = test.labels
        correct_prediction_2 = tf.equal(
            tf.argmax(y_conv_2, 1), tf.argmax(y__2, 1))
        print ("correct_prediction_2", correct_prediction_2 )
        accuracy_2 = tf.reduce_mean(tf.cast(correct_prediction_2, "float"))
        print ("accuracy_2", accuracy_2)
        print ("check accuracy %g" % accuracy_2.eval())

1、tf.get_collection获取训练变量

# Collect the trainable variables of each sub-network by variable scope.
# Passing scope=... to tf.get_collection filters the collection by name
# prefix, which replaces filtering tf.trainable_variables() by hand with
# var.name.startswith('generator') / var.name.startswith('discriminator').
trainable_key = tf.GraphKeys.TRAINABLE_VARIABLES
g_vars = tf.get_collection(trainable_key, scope='generator')
d_vars = tf.get_collection(trainable_key, scope='discriminator')

2、 tf.identity()

import tensorflow as tf

# Demo: fetch the tensor produced by tf.identity inside control_dependencies.
x = tf.Variable(1.0)
increment_x = tf.assign_add(x, 1)

with tf.control_dependencies([increment_x]):
    y = x  # plain Python alias of the variable
    z = tf.identity(x, name='x')  # new op created inside the context

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    for _ in range(5):
        print(sess.run(z))

输出是:2,3,4,5,6(tf.identity 在 control_dependencies 作用域内创建了一个新的 op,因此每次取 z 的值之前都会先执行一次 assign_add 操作)

import tensorflow as tf

# Demo: fetch the plain alias `y` instead of the tf.identity tensor.
x = tf.Variable(1.0)
increment_x = tf.assign_add(x, 1)

with tf.control_dependencies([increment_x]):
    y = x  # plain Python alias of the variable
    z = tf.identity(x, name='x')  # new op created inside the context

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    for _ in range(5):
        print(sess.run(y))

输出是:1,1,1,1,1(y = x 只是 Python 层面的变量别名,并没有在 control_dependencies 作用域内创建新的 op,所以取 y 的值时不会触发 assign_add 操作)

阅读全文
0 1