Building a Basic Neural Network with TensorFlow (Part 2): LeNet-5


Reference: http://yann.lecun.com/exdb/lenet/index.html
Source code download: http://pan.baidu.com/s/1cJEpum
Author: XJTU_Ironboy
Date: August 2017

II. LeNet-5

1. Introduction

  LeNet-5 is a classic CNN model that almost every introduction to CNNs mentions. It was designed to recognize handwritten and machine-printed characters, and it proved remarkably successful at handwriting recognition: in its day, most major US banks used it to read the handwritten digits on checks. Reaching that level of commercial deployment, at a time when collaboration between academia and industry was famously contentious, says a great deal about its accuracy.

2. Architecture

(Figure: the LeNet-5 network architecture)

① Input: a 32×32 image with 1 channel
② Layer 1 (C1): convolutional layer
  Input size: 32×32×1
  Kernel size: 5×5
  Strides: 1
  Padding: VALID
  Number of kernels: 6
  Activation function: ReLU
  Trainable parameters: 156 (5×5×1×6+6)
  Output size: 28×28×6
③ Layer 2 (S2): subsampling layer (also called a pooling layer)
  Input size: 28×28×6
  Pooling type: max pooling
  Pooling window size: 2×2
  Strides: 2
  Padding: VALID
  Output size: 14×14×6
④ Layer 3 (C3): convolutional layer
  Input size: 14×14×6
  Kernel size: 5×5
  Strides: 1
  Padding: VALID
  Number of kernels: 16
  Activation function: ReLU
  Trainable parameters: 2416 (5×5×6×16+16)
  Output size: 10×10×16
⑤ Layer 4 (S4): subsampling layer (also called a pooling layer)
  Input size: 10×10×16
  Pooling type: max pooling
  Pooling window size: 2×2
  Strides: 2
  Padding: VALID
  Output size: 5×5×16
⑥ Layer 5 (C5): convolutional layer (note: since the kernel size equals the input size, this layer can also be viewed as a fully connected layer)
  Input size: 5×5×16
  Kernel size: 5×5
  Strides: 1
  Padding: VALID
  Number of kernels: 120
  Activation function: ReLU
  Trainable parameters: 48120 (5×5×16×120+120)
  Output size: 1×1×120
⑦ Layer 6 (F6): fully connected layer
  Input size: 1×1×120
  Activation function: ReLU
  Number of neurons: 84
  Trainable parameters: 10164 (120×84+84)
  Output size: 1×1×84
⑧ Layer 7 (Output): output layer
  Input size: 1×1×84
  Activation function: Softmax
  Number of neurons: 10
  Trainable parameters: 850 (84×10+10)
  Output size: 1×1×10
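
As a quick sanity check on the parameter counts above, the per-layer totals can be recomputed with a few lines of plain Python (a minimal sketch, not part of the downloadable source code):

# Trainable parameters: kernel_h × kernel_w × in_channels × out_channels + biases
def conv_params(k, c_in, c_out):
    return k * k * c_in * c_out + c_out

def fc_params(n_in, n_out):
    return n_in * n_out + n_out

print(conv_params(5, 1, 6))     # C1: 156
print(conv_params(5, 6, 16))    # C3: 2416
print(conv_params(5, 16, 120))  # C5: 48120
print(fc_params(120, 84))       # F6: 10164
print(fc_params(84, 10))        # Output: 850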

3. Implementing LeNet-5 in TensorFlow

(1) Input data: the MNIST handwritten digit dataset (note: the original architecture expects 32×32×1 images, while MNIST images are 28×28×1. Why not use CIFAR-10? Its images are 32×32×3, which makes the computation much heavier, so MNIST is used as the example here)
(2) Environment: Python 3.5 + TensorFlow 1.2.0
(3) Optimizer: Adam
(4) Batch size on the training set: 128
(5) Batch size on the test set: 10000
(6) TensorFlow implementation
① First, import the tensorflow library and the MNIST handwritten digit dataset

import tensorflow as tf
# Import the MNIST handwritten digit dataset
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
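
To confirm what read_data_sets returns, a quick check like the following (an optional sketch, not in the original code) prints the dataset shapes: 55000 training images and 10000 test images, each flattened into a 784-dimensional vector, with one-hot labels because of one_hot=True:

print(mnist.train.images.shape)  # (55000, 784)
print(mnist.train.labels.shape)  # (55000, 10)
print(mnist.test.images.shape)   # (10000, 784)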

② Next, define the overall structure of LeNet-5

# Define the sizes of all LeNet-5 layers, the number of training steps, and the batch size
in_units = 784
C1_units = 6
S2_units = 6
C3_units = 16
S4_units = 16
C5_units = 120
F6_units = 84
output_units = 10
train_steps = 10000
batch_size = 128

③ Input layer

# Define the inputs
x = tf.placeholder(tf.float32, [None, in_units], name="x_input")
y_ = tf.placeholder(tf.float32, [None, output_units], name="y_input")
keep_prob = tf.placeholder(tf.float32)
input_data = tf.reshape(x, [-1, 28, 28, 1])

④ Convolutional layer

# Define a configurable convolutional layer
def conv(layer_name, input_x, Filter_size, activation_function=None):
    with tf.name_scope("conv_%s" % layer_name):
        with tf.name_scope("Filter"):
            Filter = tf.Variable(tf.random_normal(Filter_size, stddev=0.1), dtype=tf.float32, name="Filter")
            tf.summary.histogram('Filter_%s' % layer_name, Filter)
        with tf.name_scope("bias_filter"):
            bias_filter = tf.Variable(tf.random_normal([Filter_size[3]], stddev=0.1), dtype=tf.float32, name="bias_filter")
            tf.summary.histogram('bias_filter_%s' % layer_name, bias_filter)
        with tf.name_scope("conv"):
            conv = tf.nn.conv2d(input_x, Filter, strides=[1, 1, 1, 1], padding="SAME")
            tf.summary.histogram('conv_%s' % layer_name, conv)
        temp_y = tf.add(conv, bias_filter)
        with tf.name_scope("output_y"):
            if activation_function is None:
                output_y = temp_y
            else:
                output_y = activation_function(temp_y)
            tf.summary.histogram('output_y_%s' % layer_name, output_y)
        return output_y
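
As a usage sketch (the real calls appear in step ⑦ below), applying this function to the 28×28×1 input with a [5, 5, 1, 6] filter produces 6 feature maps, and because the function hard-codes SAME padding with stride 1 the spatial size is preserved:

# Hypothetical stand-alone check, assuming input_data from step ③ is defined:
demo = conv("demo", input_data, [5, 5, 1, 6], activation_function=tf.nn.relu)
print(demo.shape)  # (?, 28, 28, 6)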

⑤ Pooling layer

# Define a configurable max-pooling layer
def max_pool(layer_name, input_x, pool_size):
    with tf.name_scope("max_pool_%s" % layer_name):
        with tf.name_scope("output_y"):
            output_y = tf.nn.max_pool(input_x, pool_size, strides=[1, 2, 2, 1], padding="SAME")
            tf.summary.histogram('output_y_%s' % layer_name, output_y)
        return output_y

⑥ Fully connected layer

# Define a configurable fully connected layer
def full_connect(layer_name, input_x, output_num, activation_function=None):
    with tf.name_scope("full_connect_%s" % layer_name):
        with tf.name_scope("Weights"):
            Weights = tf.Variable(tf.random_normal([input_x.shape.as_list()[1], output_num], stddev=0.1), dtype=tf.float32, name="weight")
            tf.summary.histogram('Weights_%s' % layer_name, Weights)
        with tf.name_scope("biases"):
            biases = tf.Variable(tf.random_normal([output_num], stddev=0.1), dtype=tf.float32, name="biases")
            tf.summary.histogram('biases_%s' % layer_name, biases)
    output_temp = tf.add(tf.matmul(input_x, Weights), biases)
    with tf.name_scope("output_y"):
        if activation_function is None:
            output_y = output_temp
        else:
            output_y = activation_function(output_temp)
        tf.summary.histogram('output_y_%s' % layer_name, output_y)
    return output_y

⑦ Assemble the full network

# Build LeNet-5 by calling the layer functions defined above
output_layer1 = conv("layer1", input_data, [5, 5, 1, 6], activation_function=tf.nn.relu)
output_layer2 = max_pool("layer2", output_layer1, [1, 2, 2, 1])
output_layer3 = conv("layer3", output_layer2, [5, 5, 6, 16], activation_function=tf.nn.relu)
output_layer4 = max_pool("layer4", output_layer3, [1, 2, 2, 1])
# Flatten the feature maps before the fully connected layers
output_layer4 = tf.reshape(output_layer4, [-1, tf.cast(output_layer4.shape[1] * output_layer4.shape[2] * output_layer4.shape[3], tf.int32)])
output_layer5 = full_connect("layer5", output_layer4, 120, activation_function=tf.nn.relu)
output_layer6 = full_connect("layer6", output_layer5, 84, activation_function=tf.nn.relu)
output_layer7 = tf.nn.dropout(output_layer6, keep_prob)
with tf.name_scope("output_y"):
    y = full_connect("layer7", output_layer7, 10, activation_function=tf.nn.softmax)
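
Because the layer functions hard-code SAME padding, the tensor shapes here differ from the 32×32 table in section 2. Tracing them through the graph above (simple arithmetic, listed as a check) gives:

# input_data:    (?, 28, 28, 1)
# output_layer1: (?, 28, 28, 6)    5×5 conv, stride 1, SAME
# output_layer2: (?, 14, 14, 6)    2×2 max pool, stride 2
# output_layer3: (?, 14, 14, 16)   5×5 conv, stride 1, SAME
# output_layer4: (?, 7, 7, 16), reshaped to (?, 784)
# output_layer5: (?, 120)
# output_layer6: (?, 84)
# y:             (?, 10)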

⑧ Loss function (cross-entropy), optimizer, and accuracy

# Compute the loss (cross-entropy), the optimizer, and the accuracy
with tf.name_scope("loss"):
    loss = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]), name='cross_entropy')
    tf.summary.scalar("loss", loss)
with tf.name_scope("train"):
    train = tf.train.AdamOptimizer(1e-4).minimize(loss)
prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
with tf.name_scope("accuracy"):
    accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32), name='accuracy')
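
One caveat: taking tf.log of a softmax output can produce NaN once a predicted probability underflows to 0. A numerically safer alternative (a sketch of a substitution, not what the code above does) is to have the last full_connect layer output raw logits and use tf.nn.softmax_cross_entropy_with_logits:

# Assumes y_logits = full_connect("layer7", output_layer7, 10, activation_function=None)
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_logits),
    name='cross_entropy')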

⑨ With the graph built, start the actual training

# Initialize all the variables
init = tf.global_variables_initializer()
# Feed in the data and start training
with tf.Session() as sess:
    sess.run(init)
    merge = tf.summary.merge_all()
    writer = tf.summary.FileWriter('log/LeNet_log', sess.graph)
    for i in range(train_steps):
        batch = mnist.train.next_batch(batch_size)
        sess.run(train, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.7})
        if i % 100 == 0:
            accuracy_batch = sess.run(accuracy, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.7})
            loss_batch = sess.run(loss, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.7})
            print('After %d training steps: loss on the training batch is %g, accuracy is %g' % (i, loss_batch, accuracy_batch))
            result = sess.run(merge, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.7})
            writer.add_summary(result, i)
        if i == train_steps - 1:
            loss_test = sess.run(loss, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1})
            accuracy_test = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1})
            print('Loss on the test set is %g, accuracy is %g' % (loss_test, accuracy_test))
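
Note that the per-100-step metrics above are computed with keep_prob = 0.7, i.e. with dropout still active, which understates the model's actual training accuracy. Evaluation is normally run with dropout disabled, for example (a small adjustment, not in the original code):

accuracy_batch = sess.run(accuracy, feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
loss_batch = sess.run(loss, feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})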

(Note: some readers may wonder why I use "SAME" padding here when the original architecture uses "VALID". The reason is that we are using MNIST, whose images are small (28×28); with "VALID" padding the feature maps shrink so quickly that training tends not to converge, whereas "SAME" keeps the later layers large enough to extract useful features.)
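
To make the padding argument concrete, tracing the VALID feature-map sizes on a 28×28 input (simple arithmetic) shows the original topology does not even fit; this is exactly why the original LeNet-5 pads MNIST digits out to 32×32:

# VALID conv: out = in - kernel + 1; 2×2 pool, stride 2: out = in // 2
# 28 -> C1 -> 24 -> S2 -> 12 -> C3 -> 8 -> S4 -> 4
# C5's 5×5 kernel no longer fits a 4×4 map, so VALID breaks on 28×28 inputs;
# with 32×32 inputs the chain works: 32 -> 28 -> 14 -> 10 -> 5 -> 1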

(7) Results:

![Results screenshot 1](http://img.blog.csdn.net/20170825212319049?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvSV9Jcm9uYm95/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast)
![Results screenshot 2](http://img.blog.csdn.net/20170825212336483?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvSV9Jcm9uYm95/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast)
![Results screenshot 3](http://img.blog.csdn.net/20170825212352124?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvSV9Jcm9uYm95/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast)

**Test set accuracy: 98.2%**

(8) **TensorBoard visualization:**
  Since my previous post **[那些经典的神经网络结构–MLP](http://blog.csdn.net/i_ironboy/article/details/77547401)** already explains in detail how to launch TensorBoard, I will not repeat the steps here; readers who are not yet comfortable with it can refer to that post.

4. Summary

  The characteristics of LeNet-5 can be summarized as follows:
1) A convolutional network uses three kinds of layers in sequence: convolution, subsampling (pooling), and nonlinearity
2) Convolution is used to extract spatial features
3) Subsampling is done by spatial max pooling (max_pool)
4) ReLU nonlinearity (outputs 0 when x ≤ 0 and x when x > 0)
5) A multi-layer perceptron (MLP) serves as the final classifier
6) Sparse connection matrices between layers avoid a large computational cost
  Overall, this network is the starting point for a great many recent neural network architectures and has been a source of inspiration for the whole field.
Reference: Jianshu
Author: Warren_Liu
Link: http://www.jianshu.com/p/e7980ba12b4d
