基于TensorFlow的MNIST(手写图像识别)的一点经验

来源:互联网 发布:血源诅咒帅哥捏脸数据 编辑:程序博客网 时间:2024/06/18 15:34

最近要弄个简单的用于图像识别和提取图像特征值的 简单的bp神经网络,经过朋友推荐选择了TensorFlow(以下简称tf),tf是一款谷歌开源的神经网络库,功能非常强大,虽然笔者能力有限还没有仔细探索。
依据官网的教程一步步安装了tf库,注意,官网的教程是基于Python2.7的,并且需要最linux环境下运行,但是最新版本的tf已经支持py3.5了,并且在win环境下也能跑,笔者对linux不熟悉,所以选取了win作为开发环境。另推荐一款较好用的库管理文件,Anacondar,可以有效的管理不同的库和不同版本的Python。

闲话少说,进入正题。
关于MNIST的代码网上有许多,或者照着官网给的例子一步步敲也能敲出来,不需要什么数学功底,就是有细节要注意。
官网给的中文文档其实是有些过时了的,它给的例子里要求我们下载一段叫做 input_data.py的程序,这段程序的作用是将训练集导入,然后又给了一段代码,from …..input_data import *,笔者一开始想的是我的input_data已经在同一个目录下了,可不可以直接导入呢?就像这样 import inport_data ,结果总是报错,后来发现,由于tf版本更新的关系,原来官网提供的那段 input_data,已经过时了,api接口对不上了,新版本的tf中集成了这个函数 from tensorflow.examples.tutorials.mnist import input_data,
这才是新版本的接口。
接下来就可以直接运行了。附上笔者的代码
`#--coding:utf-8--
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(“MNIST_data/”, one_hot=True)

import tensorflow as tf

x = tf.placeholder(“float”,[None,784])

W = tf.Variable(tf.zeros([784,10]))

b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x,W) + b)

y_ = tf.placeholder(“float”, [None,10])

cross_entropy = -tf.reduce_sum(y_*tf.log(y))

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

init = tf.initialize_all_variables()

sess = tf.Session()

sess.run(init)

for i in range(1000):

batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, “float”))

print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))`

另外还所需四个训练集,在官网有的下载。

0 0
原创粉丝点击