经典手写数字mnist数据集识别

来源:互联网 发布:跟兄弟连学php 电子书 编辑:程序博客网 时间:2024/06/07 20:58

今天是我的第一篇博客,就从最基础的内容——用神经网络识别 MNIST 手写数字数据集——开始……本博客的资源来源于网络,整理出来供自己和刚开始接触机器学习、深度学习的同学参考,如有雷同请自行忽略……
以下三段程序适合初学者学习,不包含图片预处理和可视化部分,使用 CPU 运算。
mnist_inference.py 代码部分:主要定义了神经网络的结构参数和前向传播过程。(先上传代码,后期会加上注释。)

# -*- coding: utf-8 -*-
"""Network definition for the MNIST example: layer sizes and forward pass.

Created on Mon Jul 10 11:36:35 2017

@author: cxl
"""
import tensorflow as tf

# Number of input features per example (784 = 28 * 28 flattened pixels).
INPUT_NODE = 784
# Number of output classes (digits 0-9).
OUTPUT_NODE = 10
# Width of the single hidden layer.
LAYER1_NODE = 500


def get_weight_variable(shape, regularizer):
    """Create the "weights" variable in the current variable scope.

    Args:
        shape: Shape of the weight matrix, e.g. ``[in_dim, out_dim]``.
        regularizer: Callable mapping a tensor to a scalar penalty, or
            ``None`` to skip regularization. When given, the penalty is
            added to the ``'losses'`` collection so the training code can
            sum it into the total loss.

    Returns:
        The weights variable, initialized from a truncated normal
        distribution with stddev 0.1.
    """
    weights = tf.get_variable(
        "weights", shape,
        initializer=tf.truncated_normal_initializer(stddev=0.1))
    if regularizer is not None:
        tf.add_to_collection('losses', regularizer(weights))
    return weights


def inference(input_tensor, regularizer):
    """Build the forward pass of a two-layer fully connected network.

    Variables are created with ``tf.get_variable`` under the scopes
    ``layer1`` and ``layer2``, so calling this twice in the same graph
    requires variable reuse to be enabled.

    Args:
        input_tensor: Float tensor of shape ``[batch, INPUT_NODE]``.
        regularizer: Optional weight regularizer forwarded to
            ``get_weight_variable``; pass ``None`` when no penalty is wanted
            (e.g. at evaluation time).

    Returns:
        Unnormalized logits of shape ``[batch, OUTPUT_NODE]``.
    """
    with tf.variable_scope('layer1'):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable(
            "biases", [LAYER1_NODE],
            initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)

    with tf.variable_scope('layer2'):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable(
            "biases", [OUTPUT_NODE],
            initializer=tf.constant_initializer(0.0))
        # Linear output: the loss op applies softmax itself, and argmax-based
        # accuracy does not need probabilities.
        layer2 = tf.matmul(layer1, weights) + biases

    return layer2

mnist_train.py 代码部分:主要定义了神经网络的训练过程。

# -*- coding: utf-8 -*-
"""Training script for the MNIST example.

Created on Mon Jul 10 15:45:22 2017

@author: cxl
"""
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_inference

BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
# Coefficient of the L2 weight penalty (name misspelled; kept for compatibility).
REGULARAZTION_RATE = 0.0001
TRAINING_STEPS = 30000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./path/to/model/"
MODEL_NAME = "model.ckpt"


def train(mnist):
    """Train the network and checkpoint it every 1000 steps.

    Args:
        mnist: Dataset object from ``input_data.read_data_sets`` loaded with
            ``one_hot=True``; batches are drawn from ``mnist.train``.
    """
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE],
                       name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE],
                        name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)

    # Shadow (moving-average) copies of every trainable variable; the
    # evaluation script restores these instead of the raw weights.
    variable_averages = tf.train.ExponentialMovingAverage(
        MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Labels are one-hot, but the sparse loss expects class indices.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # Add the L2 penalties collected under 'losses' by the model code.
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

    # The rate is multiplied by LEARNING_RATE_DECAY once per
    # num_examples / BATCH_SIZE steps (one epoch), applied smoothly.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step,
        mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)
    # Run the gradient step and the moving-average update as one op.
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    # tf.train.Saver.save refuses to write when the parent directory is
    # missing, so create it up front.
    if not os.path.isdir(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)

    with tf.Session() as sess:
        # Replaces the deprecated tf.initialize_all_variables().
        tf.global_variables_initializer().run()
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print("After %d training step(s), loss on training "
                      "batch is %g." % (step, loss_value))
                # global_step is appended to the file name ("model.ckpt-N").
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                           global_step=global_step)


def main(argv=None):
    """Entry point for ``tf.app.run``: load MNIST and start training."""
    mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
    train(mnist)


if __name__ == '__main__':
    tf.app.run()

mnist_eval.py 代码部分:主要定义了神经网络的评估过程——每隔 10 秒加载最新保存的模型,并计算其在验证集上的正确率。

# -*- coding: utf-8 -*-
"""Evaluation script for the MNIST example.

Polls the training checkpoint directory and reports the validation
accuracy of the moving-average weights.

Created on Mon Jul 10 23:29:34 2017

@author: cxl
"""
import os
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_inference
import mnist_train

# Seconds to wait between checkpoint polls.
EVAL_INTERVAL_SECS = 10


def evaluate(mnist):
    """Repeatedly evaluate the latest checkpoint on the validation set.

    Loops forever, sleeping EVAL_INTERVAL_SECS between evaluations, and
    returns only when no checkpoint is found in
    ``mnist_train.MODEL_SAVE_PATH``.

    Args:
        mnist: Dataset object from ``input_data.read_data_sets`` loaded with
            ``one_hot=True``; ``mnist.validation`` is used for scoring.
    """
    # Build in a private graph so the tf.get_variable calls inside
    # inference() cannot clash with variables already in the default graph
    # (e.g. when evaluate() is called more than once in one process).
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE],
                           name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE],
                            name='y-input')
        validate_feed = {x: mnist.validation.images,
                         y_: mnist.validation.labels}

        # No regularizer: the penalty only matters for training.
        y = mnist_inference.inference(x, None)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Map each variable to its moving-average shadow so restore() loads
        # the averaged weights written by the training script.
        variable_averages = tf.train.ExponentialMovingAverage(
            mnist_train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(
                    mnist_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # Checkpoints are saved with global_step, so the file
                    # name ends in "-<step>"; basename keeps this portable.
                    global_step = os.path.basename(
                        ckpt.model_checkpoint_path).split('-')[-1]
                    accuracy_score = sess.run(accuracy,
                                              feed_dict=validate_feed)
                    print("After %s training step(s), validation "
                          "accuracy = %g" % (global_step, accuracy_score))
                else:
                    print("No checkpoint file found")
                    return
            time.sleep(EVAL_INTERVAL_SECS)


def main(argv=None):
    """Entry point for ``tf.app.run``: load MNIST and start evaluating."""
    mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
    evaluate(mnist)


if __name__ == '__main__':
    tf.app.run()

第一次写博客,以后我会把自己学习机器学习/深度学习的过程都记录下来,供自己和有兴趣但没有基础的小伙伴们一起学习、一起进步。我也会不断提高自己的博客质量和代码水平……

阅读全文
0 0