TensorFlow + MNIST (two convolutional layers, two fully connected layers)

Source: Internet  Editor: 程序博客网  Time: 2024/06/08 15:49

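I've been away for several days: I installed Linux and set up a deep learning environment. I'm now getting into TensorFlow, which is a pretty convenient platform at the moment. I wrote a separate blog post about setting up the environment.
My software environment:
Ubuntu 16.04 LTS + Python 3.6 + TensorFlow 1.2
My hardware:
i7-4720HQ @ 2.60 GHz × 8 + 950M GPU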
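I jumped straight into TensorFlow's beginner tutorial, MNIST handwritten digit recognition. The official TensorFlow documentation builds an introductory CNN for recognizing handwritten digits, but it never draws the network structure, which I think makes it somewhat hard for beginners to follow.
So I drew a rough sketch: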
[Figure: sketch of the network structure]
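In short, the network has two convolutional layers (the first is a convolution with 32 5×5 kernels followed by a pooling layer; the second is a convolution with 64 5×5 kernels followed by a pooling layer) and two fully connected layers. The first three layers all use the ReLU activation, max(0, x), and the last layer uses softmax to output scores for the 10 digit classes.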
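The tensor shapes behind this description can be traced with a little arithmetic. The sketch below is plain Python (no TensorFlow needed) and assumes what the tutorial code uses: stride-1 convolutions and 2×2, stride-2 max pooling, both with SAME padding. The helper names `same_out` and `trace` are hypothetical.

```python
import math

def same_out(size, stride):
    # With padding='SAME', the output size is ceil(input_size / stride).
    return math.ceil(size / stride)

def trace(height=28, width=28, channels=1):
    """Return per-layer output shapes and per-layer parameter counts."""
    h, w, c = height, width, channels
    layers, params = [], {}
    # Two conv + pool stages: 32 and then 64 kernels of size 5x5.
    for i, (k, out_c) in enumerate([(5, 32), (5, 64)], start=1):
        params["conv%d" % i] = k * k * c * out_c + out_c  # weights + biases
        h, w, c = same_out(h, 1), same_out(w, 1), out_c   # stride-1 conv keeps H, W
        layers.append(("conv%d" % i, (h, w, c)))
        h, w = same_out(h, 2), same_out(w, 2)             # 2x2 pool, stride 2
        layers.append(("pool%d" % i, (h, w, c)))
    flat = h * w * c                                      # flatten for the FC layers
    layers.append(("flatten", (flat,)))
    params["fc1"] = flat * 1024 + 1024
    layers.append(("fc1", (1024,)))
    params["fc2"] = 1024 * 10 + 10
    layers.append(("fc2", (10,)))
    return layers, params

layers, params = trace()
for name, shape in layers:
    print(name, shape)
print("total parameters:", sum(params.values()))  # total parameters: 3274634
```

The 28×28 input shrinks to 14×14 and then 7×7, which is where the `7*7*64` in the code comes from. Roughly 98% of the 3,274,634 parameters sit in the first fully connected layer, which is also why dropout is applied there.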

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 25 11:59:53 2017

@author: matthew
"""
import input_data
import tensorflow as tf

# Download (if needed) and load MNIST; labels are one-hot vectors of length 10
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
sess = tf.InteractiveSession()

# Inputs: flattened 28x28 images (784 values) and one-hot labels (10 classes)
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

# Initialize weights with small random values and biases with a small
# positive constant (helps avoid "dead" ReLU units)
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

# Convolution with stride 1 and SAME padding (output keeps the input size)
def conv2d(x, w):
    return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')

# 2x2 max pooling with stride 2 (halves height and width)
def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

# First layer: a conv layer and a pooling layer
# conv kernel: [5, 5, 1, 32] (5x5 window, 1 input channel, 32 output channels)
w_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x_image = tf.reshape(x, [-1, 28, 28, 1])  # -1: infer the batch dimension
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)  # 28x28 -> 14x14

# Second layer: a conv layer and a pooling layer
# conv kernel: [5, 5, 32, 64]
w_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)  # 14x14 -> 7x7

# Third layer: fully connected
# after two poolings each feature map is 7x7 (with 64 channels)
w_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
# flatten the feature maps into one vector per image
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

# Dropout on the fully connected layer (keep_prob is fed at run time)
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# Fourth layer: fully connected output layer, softmax over 10 classes
w_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_logits = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
y_conv = tf.nn.softmax(y_logits)

# Cost function: cross-entropy computed directly from the logits.
# Writing it as -tf.reduce_sum(y_*tf.log(y_conv)) can hit log(0) = -inf
# and produce a NaN loss once a softmax output underflows to 0.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_logits))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

sess.run(tf.global_variables_initializer())
for i in range(20000):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
        # accuracy on the current batch, with dropout disabled
        train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
        print("step %d, training accuracy %g" % (i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

print("test accuracy %g" % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
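The loss in this script is the cross-entropy between the one-hot labels and the softmax output. Written literally as -sum(y * log(softmax(z))), it breaks whenever some softmax probability underflows to exactly 0: log(0) is -inf, and 0 * -inf is NaN, so the loss becomes NaN even when the zero-probability class is not the true one. This is the usual reason to prefer a fused op such as `tf.nn.softmax_cross_entropy_with_logits`, which takes the logits directly. The plain-Python sketch below compares the literal formula with the standard log-sum-exp identity log softmax(z)_i = z_i - log Σ_j exp(z_j); the function names are hypothetical, and this is not TensorFlow's actual implementation.

```python
import math

def softmax(z):
    # Subtracting the max keeps exp() from overflowing.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

def naive_xent(y, z):
    # Literal -sum(y * log(softmax(z))); log(0) is treated as -inf, as in TensorFlow.
    total = 0.0
    for yi, pi in zip(y, softmax(z)):
        log_p = math.log(pi) if pi > 0 else float("-inf")
        total += yi * log_p  # 0 * -inf gives nan
    return -total

def stable_xent(y, z):
    # log softmax(z)_i = z_i - logsumexp(z): no log of a value rounded to zero.
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return -sum(yi * (zi - lse) for yi, zi in zip(y, z))

y = [1.0, 0.0, 0.0]            # true class is 0
z = [0.0, 0.0, -1000.0]        # class 2 gets a probability that underflows to 0
print(naive_xent(y, z))        # nan
print(stable_xent(y, z))       # about 0.6931 (= log 2)
```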

This is the official TensorFlow file for downloading the MNIST data (the input_data module imported by the script above):

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
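With `one_hot=True`, `read_data_sets` delivers each label as a 10-element vector with a single 1 at the digit's index, and the training script's accuracy node compares the argmax of the predictions with the argmax of those labels. A plain-Python sketch of both ideas (the helpers `one_hot`, `argmax` and `accuracy` are hypothetical, not part of the loader):

```python
def one_hot(label, num_classes=10):
    # e.g. digit 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

def argmax(values):
    # index of the largest value (the first one on ties)
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(predictions, labels):
    # mirrors tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) + reduce_mean
    correct = [argmax(p) == argmax(t) for p, t in zip(predictions, labels)]
    return sum(correct) / len(correct)

labels = [one_hot(3), one_hot(7)]
predictions = [[0.01] * 3 + [0.91] + [0.01] * 6,  # predicts 3 (correct)
               [0.5, 0.1] + [0.05] * 8]           # predicts 0 (label is 7)
print(accuracy(predictions, labels))              # 0.5
```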

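Dropout deserves a mention. Applied to the fully connected layer, it randomly zeroes a fraction of that layer's outputs during training (here each unit is kept with probability keep_prob = 0.5), which improves the network's generalization and helps prevent overfitting.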
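The dropout used here is the "inverted" variant that `tf.nn.dropout` documents in TensorFlow 1.x: each element is kept with probability keep_prob and scaled by 1/keep_prob, otherwise set to 0. A minimal plain-Python sketch of that behavior (the `dropout` function below is a hypothetical stand-in, not the TensorFlow op):

```python
import random

def dropout(activations, keep_prob, rng=random):
    # Evaluation mode: keep_prob = 1.0 leaves the activations unchanged.
    if keep_prob >= 1.0:
        return list(activations)
    # Keep each unit with probability keep_prob and scale it up by
    # 1/keep_prob so the expected value of every unit is unchanged.
    return [a / keep_prob if rng.random() < keep_prob else 0.0
            for a in activations]

rng = random.Random(0)
h = [1.0] * 10000
out = dropout(h, 0.5, rng)
kept = sum(1 for v in out if v != 0.0)
print(kept, sum(out) / len(out))  # roughly 5000 kept; mean stays close to 1.0
```

Because the kept units are scaled up during training, the expected activation already matches what the next layer sees at test time, which is why the script can simply feed keep_prob: 1.0 for evaluation with no extra rescaling.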
On my machine, training took almost an hour, and the final test accuracy was 0.993 (a generalization error of about 0.7%).
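To put that hour in context, here is a rough count of how much data the training loop above processes. It assumes the loader's default split, which holds out 5,000 of the 60,000 MNIST training images for validation and leaves 55,000 in `mnist.train`; the variable names are illustrative.

```python
steps = 20000         # iterations of the training loop
batch_size = 50       # mnist.train.next_batch(50)
train_images = 55000  # assumed size of mnist.train (60,000 minus 5,000 for validation)

examples_seen = steps * batch_size        # images processed during training
epochs = examples_seen / train_images     # full passes over the training set
print(examples_seen, round(epochs, 1))    # 1000000 18.2
```

So the run amounts to about 18 passes over the training set, one million forward and backward passes through the network in total.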
