Tensorflow框架下识别手写字神经网络代码

来源:互联网 发布:网络泛在化 编辑:程序博客网 时间:2024/05/01 22:45

不借助任何架构的神经网络代码在代码可读性上能够很好的表达出神经网络代码是如何工作的,但是代码运行效率却很低.或者说对硬件的要求很高,因为python语言的运行效率很低.
Google的tensorflow架构很好的在硬件设备上搭建神经网络的代码,该架构在各个开源社区有无数教程.可以去社区了解tensorflow的架构与基础.
(一) Tensorflow加载数据集
Tensorflow数据集的加载进行和模块化的打包:

from tensorflow.examples.tutorials.mnist import input_dataif os.path.exists('/ysk/code/tensorflowcode/MNIST_data/'):    mnist = input_data.read_data_sets("/ysk/code/tensorflowcode/MNIST_data/", one_hot=True)else:    print("file not exist!")

以上返回一个mnist,包括60000行的训练数据集和10000行的测试数据集.训练集中,训练的图片叫做mnist.train.images,训练的标签叫做mnist.train.labels
(二) 构建计算图

#x表示输入,行数代表训练图片张数,784为一个图片的像素,因此每一行为一个训练的图片像素x = tf.placeholder("float", [None, 784])#W代表第一层之间的权重,这里的权重的维度与不使用架构的维度是行列相反的W = tf.Variable(tf.zeros([784, 10]))b = tf.Variable(tf.zeros([10]))#这里的y是对输出层进行了softmax的分类.y = tf.nn.softmax(tf.matmul(x, W) + b)#这里的y_表示训练集中的标签,是真是的数字y_ = tf.placeholder("float", [None, 10])#代价函数使用的是交叉熵代价函数,简单介绍交叉熵代价函数表示的意义为利用预测输出(y)来表示真正的数字(y_)的困难程度,也就是说使用我们经过神经网络训练的输出来与真实的标签数字进行对比.交叉熵的值越大,说明用训练出来的数据表示真实数据的困难程度就越大,因此神经网络的主要目标就是降低交熵的数值.cross_entropy = -tf.reduce_sum(y_*tf.log(y))#训练采用的算法可以自己选择,参考博客`https://www.cnblogs.com/ranjiewen/p/5938944.html`,该博客说明了各种梯度下降函数的含义以及梯度下降的快慢.train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)#train_step = tf.train.AdadeltaOptimizer(0.01).minimize(cross_entropy)#train_step = tf.train.RMSPropOptimizer(0.01).minimize(cross_entropy)#train_step = tf.train.AdagradOptimizer(0.01).minimize(cross_entropy)#tensorflow中一定要将所有的变量进行初始化,只有初始化后的节点才能被应用.但是tensorflow中初始化函数使用的语句为下面:init_op = tf.global_variables_initializer()sess = tf.Session()sess.run(init_op)for i in range(1000):    batch_xs, batch_ys = mnist.train.next_batch(100)    sess.run(train_step, feed_dict={x:batch_xs, y_:batch_ys})    #equal语句是将输出的[100, 10]的矩阵与标签的[100, 10]的矩阵进行比较,返回的矩阵是如:[true, false, false]形式的矩阵.    cross_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))    #因为生成的矩阵为上面个的形式,所以需要将矩阵形式转换为真正的数值形式    accuracy = tf.reduce_mean(tf.cast(cross_prediction, "float"))    #在训练完以后要对神经网络的准确性进行测试.    test_acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels})    if i % 10 == 0:        print("After %d steps trainging, the accuracy is %g " %(i, test_acc))