TensorFlow入门(二)简单前馈网络实现 mnist 分类
来源:互联网 发布:数据迁移重要程度 编辑:程序博客网 时间:2024/06/05 03:30
在本教程中,我们来实现一个非常简单的两层全连接网络来完成MNIST数据的分类问题。
输入[-1,28*28], FC1 有 1024 个neurons, FC2 有 10 个neurons。这么简单的一个全连接网络,结果测试准确率达到了 0.98。还是非常棒的!!!
import numpy as npimport tensorflow as tf# 设置按需使用GPUconfig = tf.ConfigProto()config.gpu_options.allow_growth = Truesess = tf.InteractiveSession(config=config)
1. 导入数据
# 用tensorflow 导入数据from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets('MNIST_data', one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gzExtracting MNIST_data/train-labels-idx1-ubyte.gzExtracting MNIST_data/t10k-images-idx3-ubyte.gzExtracting MNIST_data/t10k-labels-idx1-ubyte.gz
print 'training data shape ', mnist.train.images.shapeprint 'training label shape ', mnist.train.labels.shape
training data shape (55000, 784)training label shape (55000, 10)
2. 构建网络
# 权值初始化def weight_variable(shape): # 用正态分布来初始化权值 initial = tf.truncated_normal(shape, stddev=0.1) return tf.Variable(initial)def bias_variable(shape): # 本例中用relu激活函数,所以用一个很小的正偏置较好 initial = tf.constant(0.1, shape=shape) return tf.Variable(initial)# input_layerX_ = tf.placeholder(tf.float32, [None, 784])y_ = tf.placeholder(tf.float32, [None, 10])# FC1W_fc1 = weight_variable([784, 1024])b_fc1 = bias_variable([1024])h_fc1 = tf.nn.relu(tf.matmul(X_, W_fc1) + b_fc1)# FC2W_fc2 = weight_variable([1024, 10])b_fc2 = bias_variable([10])y_pre = tf.nn.softmax(tf.matmul(h_fc1, W_fc2) + b_fc2)
3. 训练和评估
# 1.损失函数:cross_entropycross_entropy = -tf.reduce_sum(y_ * tf.log(y_pre))# 2.优化函数:AdamOptimizer, 优化速度要比 GradientOptimizer 快很多train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)# 3.预测结果评估# 预测值中最大值(1)即分类结果,是否等于原始标签中的(1)的位置。argmax()取最大值所在的下标correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.arg_max(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 开始运行sess.run(tf.global_variables_initializer())# 这大概迭代了不到 10 个 epoch, 训练准确率已经达到了0.98for i in range(5000): X_batch, y_batch = mnist.train.next_batch(batch_size=100) train_step.run(feed_dict={X_: X_batch, y_: y_batch}) if (i+1) % 200 == 0: train_accuracy = accuracy.eval(feed_dict={X_: mnist.train.images, y_: mnist.train.labels}) print "step %d, training acc %g" % (i+1, train_accuracy) if (i+1) % 1000 == 0: test_accuracy = accuracy.eval(feed_dict={X_: mnist.test.images, y_: mnist.test.labels}) print "= " * 10, "step %d, testing acc %g" % (i+1, test_accuracy)
step 200, training acc 0.937364step 400, training acc 0.965818step 600, training acc 0.973364step 800, training acc 0.977709step 1000, training acc 0.981528= = = = = = = = = = step 1000, testing acc 0.9688step 1200, training acc 0.988437step 1400, training acc 0.988728step 1600, training acc 0.987491step 1800, training acc 0.993873step 2000, training acc 0.992527= = = = = = = = = = step 2000, testing acc 0.9789step 2200, training acc 0.995309step 2400, training acc 0.995455step 2600, training acc 0.9952step 2800, training acc 0.996073step 3000, training acc 0.9964= = = = = = = = = = step 3000, testing acc 0.9778step 3200, training acc 0.996709step 3400, training acc 0.998109step 3600, training acc 0.997455step 3800, training acc 0.995055step 4000, training acc 0.997291= = = = = = = = = = step 4000, testing acc 0.9808step 4200, training acc 0.997746step 4400, training acc 0.996073step 4600, training acc 0.998564step 4800, training acc 0.997946step 5000, training acc 0.998673= = = = = = = = = = step 5000, testing acc 0.98
0 0
- TensorFlow入门(二)简单前馈网络实现 mnist 分类
- TensorFlow入门(二)简单前馈网络实现 mnist 分类
- tensorflow实现简单卷积网络进行mnist分类
- TensorFlow入门(三)多层 CNNs 实现 mnist分类
- TensorFlow入门(三)多层 CNNs 实现 mnist分类
- TensorFlow实现 mnist分类
- 深度学习框架TensorFlow学习(二)----简单实现Mnist
- TensorFlow学习笔记(二)MNIST入门
- Tensorflow入门二 mnist识别(一)
- Tensorflow入门三 mnist识别(二)
- 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门
- 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门
- 深入浅出TensorFlow(二):TensorFlow解决MNIST问题入门
- TensorFlow学习笔记(4)----完整的工程示例:全连接前馈网络识别MNIST
- 代码,逻辑回归(logistic_regression)实现mnist分类(TensorFlow实现)
- Tensorflow MNIST机器学习入门 分类学习
- tensorflow实战1:lstm实现mnist分类
- tensorflow实现线性分类之MNIST
- 提高mysql千万级大数据SQL查询优化30条经验
- 有趣的开源软件:语音识别工具Kaldi (二)
- 接触Matlab10年后的一个总结,随时使用Matlab要掌握的一些要点
- 亲历者讲述一个程序员如何变成精神病人的
- Groovy 使用完全解析
- TensorFlow入门(二)简单前馈网络实现 mnist 分类
- 解决在安装了jdk1.6.0_26,又安装JDK1.8.2后导致iReport无法启动加载的问题
- 抄袭某神的自定义Button
- input标签没有</input>,不然会有问题
- 算法提高 矩阵乘法
- 1.txt linux 用法 和jar包执行,kafka执行,flume
- Tomcat中配置JNDI数据源详解
- Laravel实现多个视图共享相同的数据
- Mysq创建触发器