Summer TensorFlow Practice: MNIST
Source: Internet · Editor: 程序博客网 · Date: 2024/06/09 07:43
1. Download the data
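```python
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# This post uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session).
# Choose ONE of the two options below:
# Option 1: download directly through the API (TensorFlow 1.x only)
from tensorflow.examples.tutorials.mnist import input_data
# Option 2: import a local copy of input_data.py from the working directory
# import input_data
print("packs loaded.")

print("download and extract mnist dataset...")
# one_hot=True stores every label as a one-hot vector
mnist = input_data.read_data_sets('data/', one_hot=True)
print()
print("type of mnist is %s" % (type(mnist)))
print("number of train data is %d" % (mnist.train.num_examples))
print("number of test data is %d" % (mnist.test.num_examples))
```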
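To connect the data to `n_input = 784` in the next step: to my understanding, `read_data_sets` with its default arguments (`reshape=True`, `dtype=float32`) stores each 28x28 grayscale image as one flat row of 784 floats scaled to [0, 1]. The plain-Python sketch below mimics that flattening on a hypothetical image; `flatten_and_scale` is an illustrative helper, not a TensorFlow function.

```python
def flatten_and_scale(image):
    """Row-major flatten of a 2-D list of 0-255 pixel values, scaled to [0, 1]."""
    return [pixel / 255.0 for row in image for pixel in row]

# Hypothetical all-black 28x28 image with one white pixel in the middle
image = [[0] * 28 for _ in range(28)]
image[14][14] = 255

vec = flatten_and_scale(image)
# Pixel (row 14, col 14) lands at index 14 * 28 + 14 in the flat vector
print(len(vec), vec[14 * 28 + 14])  # 784 1.0
```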
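One-hot encoding:
What a one-hot code is: a one-hot encoding is an n-bit sequence that contains exactly one 1, with all other bits 0. It can also be called a binary vector, since it contains only 0s and 1s. It is commonly used to describe one particular state of a state machine. With `one_hot = True`, each MNIST label (a digit from 0 to 9) becomes a 10-element vector; for example, the digit 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].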
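A minimal plain-Python sketch of one-hot encoding and decoding, assuming integer labels from 0 to 9; `one_hot` and `from_one_hot` are hypothetical helper names, not TensorFlow functions. Decoding is simply taking the index of the largest entry, which is the same idea as `tf.argmax(y, 1)` applied to the labels later in this post.

```python
def one_hot(label, n_classes=10):
    """Return a list of length n_classes with a 1 at index `label`, 0 elsewhere."""
    if not 0 <= label < n_classes:
        raise ValueError("label out of range")
    vec = [0] * n_classes
    vec[label] = 1
    return vec

def from_one_hot(vec):
    """Recover the class index: the position of the single 1 (i.e. the argmax)."""
    return max(range(len(vec)), key=lambda i: vec[i])

print(one_hot(3))  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```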
2. Build the network topology
- Define the network structure
```python
# Network topology
n_hidden1 = 256   # units in the first hidden layer
n_hidden2 = 128   # units in the second hidden layer
n_input = 784     # 28 * 28 pixels per image
n_classes = 10    # digits 0-9

# Inputs and outputs: the batch size is not fixed and can be any value,
# so the first dimension is None
x = tf.placeholder("float", [None, n_input])
y = tf.placeholder("float", [None, n_classes])

# Network parameters
stddev = 0.1
weights = {
    # [n_input, n_hidden1] is the shape of weight matrix w1
    'w1': tf.Variable(tf.random_normal([n_input, n_hidden1], stddev=stddev)),
    'w2': tf.Variable(tf.random_normal([n_hidden1, n_hidden2], stddev=stddev)),
    'out': tf.Variable(tf.random_normal([n_hidden2, n_classes], stddev=stddev))
}
biases = {
    'b1': tf.Variable(tf.random_normal([n_hidden1])),
    'b2': tf.Variable(tf.random_normal([n_hidden2])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}
print("network ready!")
```
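Each weight matrix maps one layer to the next, so its shape is [size of previous layer, size of next layer], and each bias vector has one entry per unit of the next layer. The plain-Python sketch below derives those shapes from the layer sizes used above and counts the trainable parameters; `layer_shapes` and `count_params` are hypothetical helpers for illustration only.

```python
def layer_shapes(sizes):
    """For sizes [n_input, n_hidden1, ..., n_classes], return (weight_shape, bias_shape) per layer."""
    return [((n_in, n_out), (n_out,)) for n_in, n_out in zip(sizes[:-1], sizes[1:])]

def count_params(sizes):
    """Total number of weights plus biases across all layers."""
    total = 0
    for (rows, cols), (b,) in layer_shapes(sizes):
        total += rows * cols + b
    return total

sizes = [784, 256, 128, 10]  # n_input, n_hidden1, n_hidden2, n_classes
for w_shape, b_shape in layer_shapes(sizes):
    print("weight", w_shape, "bias", b_shape)
print("trainable parameters:", count_params(sizes))  # 235146
```

Most of the parameters sit in the first layer (784 x 256 weights), which is typical for a fully connected network on raw pixels.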
3. Define the model, loss, optimizer, and accuracy for training
```python
# MLP definition
def multilayer_perceptron(_X, _weights, _biases):
    # Each hidden layer computes sigmoid(x * w + b); use the function's
    # arguments rather than the global `weights` dictionary
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(_X, _weights['w1']), _biases['b1']))
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, _weights['w2']), _biases['b2']))
    # The output layer returns raw logits; softmax is applied inside the loss
    return tf.matmul(layer_2, _weights['out']) + _biases['out']

# Prediction
pred = multilayer_perceptron(x, weights, biases)

# Loss and optimizer
# This is a classification task, so the loss is the cross entropy of a
# softmax output layer.
# tf.nn.softmax_cross_entropy_with_logits(labels=..., logits=...):
#   logits: output of the last layer, shape [batch_size, n_classes]
#   labels: ground-truth one-hot labels, same shape
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)

# Accuracy
# argmax(f(x)) is the x0 at which f(x) reaches its maximum f(x0);
# tf.argmax(t, 1) returns that index for each row
# tf.equal(A, B) compares corresponding elements and returns True where they
# are equal, False otherwise; the result has the same shape as A
# tf.cast(x, dtype) converts types (True -> 1.0, False -> 0.0)
corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(corr, "float"))

# Initializer
init = tf.global_variables_initializer()
print("functions ready")
```
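To make the `cost` and `accr` nodes concrete, here is a plain-Python sketch of what they compute for a small batch: the softmax cross entropy of each example averaged over the batch, and accuracy as the mean of argmax matches. It uses toy assumptions (3 classes, made-up logits) and is not TensorFlow's actual implementation; the helper names are hypothetical.

```python
import math

def softmax(logits):
    """Numerically stable softmax of one row of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def softmax_cross_entropy(logits, one_hot_label):
    """-sum_k y_k * log(softmax(logits)_k) for a single example."""
    probs = softmax(logits)
    return -sum(y * math.log(p) for y, p in zip(one_hot_label, probs) if y)

def argmax(row):
    return max(range(len(row)), key=lambda i: row[i])

def accuracy(batch_logits, batch_labels):
    """Mean of 1.0/0.0 argmax matches, like reduce_mean(cast(equal(...)))."""
    correct = [1.0 if argmax(p) == argmax(y) else 0.0
               for p, y in zip(batch_logits, batch_labels)]
    return sum(correct) / len(correct)

logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.2]]  # toy "pred" for a batch of 2
labels = [[1, 0, 0], [0, 0, 1]]              # toy one-hot "y"
cost = sum(softmax_cross_entropy(l, y) for l, y in zip(logits, labels)) / len(logits)
print("mean cost: %.4f" % cost)
print("accuracy:", accuracy(logits, labels))  # 0.5
```

The second example is misclassified (argmax 1 vs. label 2), so accuracy is 0.5 and that example contributes most of the cost.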
4. Train the network
```python
# Train for 20 epochs in total
training_epochs = 20
# Each mini-batch contains 100 examples
batch_size = 100
# Report results every 4 epochs: 20 / 4 = 5 reports in total
display_step = 4

# Launch the graph
sess = tf.Session()
sess.run(init)

# Optimize
for epoch in range(training_epochs):
    avg_cost = 0.
    # Number of batches = total training examples / batch size
    total_batch = int(mnist.train.num_examples / batch_size)
    # Iterate over the batches
    for i in range(total_batch):
        # next_batch is provided by the built-in dataset object;
        # for your own dataset you have to write a similar function
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        feeds = {x: batch_xs, y: batch_ys}
        sess.run(optm, feed_dict=feeds)
        avg_cost += sess.run(cost, feed_dict=feeds)
    avg_cost = avg_cost / total_batch
    # Display
    if (epoch + 1) % display_step == 0:
        print("epoch: %03d/%03d cost: %.9f" % (epoch + 1, training_epochs, avg_cost))
        # Training accuracy is measured on the last mini-batch only
        feeds = {x: batch_xs, y: batch_ys}
        train_acc = sess.run(accr, feed_dict=feeds)
        print("train accuracy: %.3f" % (train_acc))
        feeds = {x: mnist.test.images, y: mnist.test.labels}
        test_acc = sess.run(accr, feed_dict=feeds)
        print("test accuracy: %.3f" % (test_acc))
print("optimization finished!")
```
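The comment in the loop notes that `next_batch` only exists on the built-in dataset object. For your own data, a minimal sketch of a replacement might look like this; `iterate_minibatches` is a hypothetical helper (not part of TensorFlow) that shuffles once per epoch and, like `int(mnist.train.num_examples / batch_size)` in the loop, drops the final incomplete batch.

```python
import random

def iterate_minibatches(xs, ys, batch_size, shuffle=True, seed=None):
    """Yield (batch_xs, batch_ys) pairs that cover one epoch of (xs, ys)."""
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")
    idx = list(range(len(xs)))
    if shuffle:
        # Shuffle indices so inputs and labels stay aligned
        random.Random(seed).shuffle(idx)
    # Only full batches are produced; leftover examples are skipped this epoch
    for start in range(0, len(idx) - batch_size + 1, batch_size):
        chunk = idx[start:start + batch_size]
        yield [xs[i] for i in chunk], [ys[i] for i in chunk]

# Usage with toy data: 10 examples, batches of 3 -> 3 full batches per epoch
xs = list(range(10))
ys = [2 * v for v in xs]
for batch_xs, batch_ys in iterate_minibatches(xs, ys, 3, seed=0):
    print(batch_xs, batch_ys)
```

Inside the training loop, `for batch_xs, batch_ys in iterate_minibatches(...)` would replace the `for i in range(total_batch)` loop and its `next_batch` call.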
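Conclusion
The code details deserve careful study. In particular, keep the concepts of epoch, batch, and display_step apart and understand how they are arranged during training: one epoch is a full pass over the training set, each epoch is split into batches of batch_size examples (one optimizer step per batch), and results are printed once every display_step epochs.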
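As a worked check of how the three settings interact, the sketch below computes the schedule for this post's values, assuming the 55,000-example training split that `read_data_sets` produces by default (it holds out 5,000 of the 60,000 MNIST training images for validation). `training_schedule` is a hypothetical helper.

```python
def training_schedule(num_examples, batch_size, training_epochs, display_step):
    """Summarize how epochs, batches and display_step fit together."""
    # Same as int(num_examples / batch_size) for positive integers
    total_batch = num_examples // batch_size
    # Epochs (1-based) after which results are printed
    report_epochs = [e + 1 for e in range(training_epochs) if (e + 1) % display_step == 0]
    return {
        "batches_per_epoch": total_batch,
        "total_updates": total_batch * training_epochs,
        "report_epochs": report_epochs,
    }

sched = training_schedule(55000, 100, 20, 4)
print(sched)
# {'batches_per_epoch': 550, 'total_updates': 11000, 'report_epochs': [4, 8, 12, 16, 20]}
```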