MNIST digit recognition with a convolutional neural network (TensorFlow)

# Download the data set
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Import TensorFlow
import tensorflow as tf

# Create a session object
sess = tf.InteractiveSession()

# Placeholders (images and labels)
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])

# Weight initialization
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

# Bias initialization
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

# Convolution
def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME")

# 2x2 max pooling
def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

# First convolutional layer
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x_image = tf.reshape(x, [-1, 28, 28, 1])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

# Second convolutional layer
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# Fully connected layer
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# Output (softmax) layer: y_conv holds the logits; softmax is applied inside the loss
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

# Train and evaluate the model
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
sess.run(tf.global_variables_initializer())
for i in range(20000):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
        train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
        print("step %d,training accuracy %0.4f" % (i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
print("test accuracy %0.4f" % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
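For reference, the network above is: 5x5 convolution with 32 filters, 2x2 max pooling, 5x5 convolution with 64 filters, 2x2 max pooling, a 1024-unit fully connected layer with dropout, and a 10-way output. The two poolings shrink the 28x28 image to 7x7, which is why the fully connected layer takes 7*7*64 inputs. The sketch below expresses the same architecture with the tf.keras API; it is an illustrative rewrite assuming TensorFlow 2.x (where tf.placeholder, tf.InteractiveSession, and tensorflow.examples.tutorials.mnist no longer exist), not the code that produced the results reported here, so its accuracy may differ.

# Minimal tf.keras sketch of the same architecture (assumes TensorFlow 2.x).
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train[..., None].astype("float32") / 255.0   # add channel dim, scale to [0, 1]
x_test = x_test[..., None].astype("float32") / 255.0

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 5, padding="same", activation="relu",
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2),                      # 28x28 -> 14x14
    tf.keras.layers.Conv2D(64, 5, padding="same", activation="relu"),
    tf.keras.layers.MaxPooling2D(2),                      # 14x14 -> 7x7
    tf.keras.layers.Flatten(),                            # 7*7*64 features
    tf.keras.layers.Dense(1024, activation="relu"),
    tf.keras.layers.Dropout(0.5),                         # same drop rate as keep_prob = 0.5
    tf.keras.layers.Dense(10),                            # logits
])

model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=50, epochs=5)      # epoch count is illustrative
model.evaluate(x_test, y_test)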

Results:

Test set accuracy: 99.2%

Training accuracy (partial log, truncated):

step 0,training accuracy 0.1800
step 100,training accuracy 0.8800
step 200,training accuracy 0.8800
step 300,training accuracy 0.9200
step 400,training accuracy 0.9200
step 500,training accuracy 0.9600
step 600,training accuracy 0.9000
step 700,training accuracy 1.0000
step 800,training accuracy 0.9800
step 900,training accuracy 0.9600
step 1000,training accuracy 1.0000
step 1100,training accuracy 0.9600
step 1200,training accuracy 0.9800
step 1300,training accuracy 0.9600
step 1400,training accuracy 0.9400
step 1500,training accuracy 0.9800
step 1600,training accuracy 0.9800
step 1700,training accuracy 1.0000
step 1800,training accuracy 0.9600
step 1900,training accuracy 0.9800
step 2000,training accuracy 0.9600
step 2100,training accuracy 0.9200
step 2200,training accuracy 1.0000
step 2300,training accuracy 1.0000
step 2400,training accuracy 1.0000
step 2500,training accuracy 1.0000
step 2600,training accuracy 1.0000
step 2700,training accuracy 0.9800
step 2800,training accuracy 0.9600
step 2900,training accuracy 1.0000
step 3000,training accuracy 1.0000
step 3100,training accuracy 0.9800
step 3200,training accuracy 0.9800
step 3300,training accuracy 0.9800
step 3400,training accuracy 0.9400
step 3500,training accuracy 0.9800
step 3600,training accuracy 1.0000
step 3700,training accuracy 1.0000
step 3800,training accuracy 1.0000
step 3900,training accuracy 0.9600
step 4000,training accuracy 0.9400
............