MNIST(二):基于CNN的mnist识别
来源:互联网 发布:网络销售模式有哪些 编辑:程序博客网 时间:2024/05/17 08:53
在自学tensorflow的时候,经常会遇到这样的问题:
1.为啥这里用这个函数,还有这个函数在哪里定义的,我怎么查询api文档
2.为什么我看了同样是写mnist的代码,为什么实现的方法会有很大的区别
因为tensorflow现在的更新比较频繁,版本更替很快,所以很正常会看到实现的方法不同,还有一点就是各人的写代码风格不同,但是如果弄清楚实现的那几个步骤,其实也能很好理解。
至于如何查api文档,可以访问http://devdocs.io/ 页面向下拖就可以找到tensorflow,并且api文档是可以下载的,保存在浏览器的缓存中,没有网的时候也可以访问。
下面正题,利用CNN实现mnist识别:
- tensorflow里面内置了处理mnist的各种函数,方便我们操作。所以不必我们进行数据的处理
- 如果我们想显示mnist数据集里的一个图片,怎么操作呢?
print(mnist.train.images.shape) # (55000, 28 * 28)print(mnist.train.labels.shape) # (55000, 10)plt.imshow(mnist.train.images[0].reshape((28, 28)), cmap='gray')plt.title('%i' % np.argmax(mnist.train.labels[0])); plt.show()
2 . 接下来就是实现的过程:(这里借鉴了莫烦大佬的代码)
这里用了两层的conv加上后面的全连接
import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport numpy as npimport matplotlib.pyplot as plttf.set_random_seed(1)np.random.seed(1)BATCH_SIZE = 50LR = 0.001 # learning ratemnist = input_data.read_data_sets('./mnist', one_hot=True) # they has been normalized to range (0,1)test_x = mnist.test.images[:2000]test_y = mnist.test.labels[:2000]# plot one exampleprint(mnist.train.images.shape) # (55000, 28 * 28)print(mnist.train.labels.shape) # (55000, 10)plt.imshow(mnist.train.images[0].reshape((28, 28)), cmap='gray')plt.title('%i' % np.argmax(mnist.train.labels[0])); plt.show()tf_x = tf.placeholder(tf.float32, [None, 28*28]) / 255.image = tf.reshape(tf_x, [-1, 28, 28, 1]) # (batch, height, width, channel)tf_y = tf.placeholder(tf.int32, [None, 10]) # input y# CNNconv1 = tf.layers.conv2d( # shape (28, 28, 1) inputs=image, filters=16, kernel_size=5, strides=1, padding='same', activation=tf.nn.relu) # -> (28, 28, 16)pool1 = tf.layers.max_pooling2d( conv1, pool_size=2, strides=2,) # -> (14, 14, 16)conv2 = tf.layers.conv2d(pool1, 32, 5, 1, 'same', activation=tf.nn.relu) # -> (14, 14, 32)pool2 = tf.layers.max_pooling2d(conv2, 2, 2) # -> (7, 7, 32)flat = tf.reshape(pool2, [-1, 7*7*32]) # -> (7*7*32, )output = tf.layers.dense(flat, 10) # output layerloss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y, logits=output) # compute costtrain_op = tf.train.AdamOptimizer(LR).minimize(loss)accuracy = tf.metrics.accuracy( # return (acc, update_op), and create 2 local variables labels=tf.argmax(tf_y, axis=1), predictions=tf.argmax(output, axis=1),)[1]sess = tf.Session()init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) # the local var is for accuracy_opsess.run(init_op) # initialize var in graphfor step in range(600): b_x, b_y = mnist.train.next_batch(BATCH_SIZE) _, loss_ = sess.run([train_op, loss], {tf_x: b_x, tf_y: b_y}) if step % 50 == 0: accuracy_, flat_representation = sess.run([accuracy, flat], {tf_x: test_x, tf_y: test_y}) print('Step:', step, '| train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_)# print 10 predictions from test datatest_output = sess.run(output, {tf_x: test_x[:10]})pred_y = np.argmax(test_output, 1)print(pred_y, 'prediction number')print(np.argmax(test_y[:10], 1), 'real number')
阅读全文
0 0
- MNIST(二):基于CNN的mnist识别
- cnn-mnist手写识别
- cnn 手写数字识别 mnist
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
- 基于Tensorflow的MNIST手写数字识别(一)
- 基于Tensorflow的MNIST手写数字识别(三)
- 基于tensorflow的MNIST手写数字识别
- 基于tensorflow的MNIST手写数字识别
- 基于tensorflow的MNIST数字识别
- 基于tensorflow的MNIST手写字识别
- Tensorflow学习笔记(二):利用CNN实现手写数字(mnist)识别
- Tensorflow系列之(二):详解CNN识别MNIST手写数字集
- MNIST手写数字的识别——CNN篇
- TensorFlow CNN 以库函数的方式实现MNIST手写识别
- Tensorflow之 CNN卷积神经网络的MNIST手写数字识别
- Tensoflow+CNN实现简单的mnist手写数字识别
- Tensorflow入门二 mnist识别(一)
- Tensorflow入门三 mnist识别(二)
- 两个int(32位)整数m和n的二进制表达中,有多少个位(bit)不同?
- spring-data-redis模块详解
- cut操作详解
- 利用AJAX后台查询数据库返回json,前台生成表格
- 物理层的功能与特性
- MNIST(二):基于CNN的mnist识别
- n个字符全排列
- ubuntu添加环境变量
- 简单的数据结构(总结)只涉及用数组或链表实现
- CAFFE学习笔记(五)用caffe跑自己的jpg数据
- html拾遗第3天,history
- 【Java】Enum实践
- PID算法的简单C语言实现
- 立即执行函数