Learning Tensorflow (4)
来源:互联网 发布:linux yum jdk 1.8 编辑:程序博客网 时间:2024/05/22 05:24
搞了那么多准备工作,这节开始来训练一个模型吧,官网上有个关于手写识别完整的例子,模型预测率大概是 91% , 本来打算换一个数据集来玩,结果发现预测率出奇的低(可能是数据量不够,杀鸡用了牛刀),也许是我开打的方式不对. (含泪)
先引入两个损失函数:
– 交叉熵:
– 残差平方和:
官方文档给出例子是采用交叉熵作为损失函数来训练的模型,这里我再引入残差平方和作为损失函数与原交叉熵训练的模型对一个对比参照
import matplotlib.pyplot as pltimport seaborn as snsimport input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=True)import tensorflow as tf# 训练X数据集占位符x = tf.placeholder("float", [None, 784])# 初始化权重和偏置W = tf.Variable(tf.zeros([784, 10]))b = tf.Variable(tf.zeros([10]))# 构建softmax 回归模型y = tf.nn.softmax(tf.matmul(x, W) + b)# 训练y占位符y_ = tf.placeholder("float", [None, 10])# 损失函数cross_entropy = -tf.reduce_sum(y_*tf.log(y))sum_of_square = tf.reduce_sum(tf.pow((y_ - y) , 2))# 梯度下降train_step_1 = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)train_step_2 = tf.train.GradientDescentOptimizer(0.01).minimize(sum_of_square)
代码中设置两种损失函数下的训练模型的方式,分别为交叉熵与平方损失
# 初始化准确率和迭代次数train_times = [1000, 5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000, 45000, 50000, 55000]Accuracy_1 = []Accuracy_2 = []for i in train_times: # 初始化 init = tf.initialize_all_variables() # 启动图 sess = tf.Session() sess.run(init) # 构建预测图 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) # 以交叉熵训练模型 for j in range(i): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step_1, feed_dict={x: batch_xs, y_: batch_ys}) Accuracy_1.append(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) sess.close() # 初始化 init = tf.initialize_all_variables() # 启动图 sess = tf.Session() sess.run(init) # 以平方和训练模型 for j in range(i): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step_2, feed_dict={x: batch_xs, y_: batch_ys}) Accuracy_2.append(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) sess.close()
上代码预设不同迭代次数下得到预测率,将结果存入在list当中,方便绘制趋势图来观察最佳迭代次数
# 绘图展示sns.set_style("darkgrid")plt.plot(train_times, Accuracy_1, color="deepskyblue")plt.plot(train_times, Accuracy_2, color="lightgreen")plt.title("Accuracy with train_times")plt.xlabel("train_times")plt.ylabel("Accuracy")plt.show()ax.plot_trisurf(x_data[0], x_data[1], y_[0], color="blue")plt.show()
如图示,蓝色线为交叉熵训练模型随着迭代次数变化曲线,绿色线为平方损失随着迭代次数变化的曲线,从曲线观察,从迭代次数大于10000次开始,平方损失下模型的预测率要好于交叉熵下模型预测率,平方损失模型预测变化比较稳定,而交叉熵训练的模型变化波动比较大。给出一个表格来看数据变化。
0表示迭代次数,1表示交叉熵训练模型,2表示平方损失训练模型。
如表示交叉熵最大预测值只有0.922, 而平方损失高于0.93,由此可见平方损失作为损失函数来训练模型效果更好,比官网给出的例子的模型高了一个百分点(哈哈)。
阅读全文
0 0
- Learning Tensorflow (4)
- learning tensorflow
- deep learning on tensorflow
- Tensorflow Learning note1
- Learning Tensorflow (1)
- Learning Tensorflow (2)
- Learning Tensorflow (3)
- Learning Tensorflow (5)
- Notes on learning tensorflow
- Deep Learning algorithms with TensorFlow
- TensorFlow-3-TensorBoard: Visualizing Learning
- 【Tensorflow slim】slim learning包
- 【Deep Learning】Tensorflow MNIST测试
- TensorFlow学习笔记9----TensorFlow Wide & Deep Learning Tutorial
- Deep Learning-TensorFlow (4) CNN卷积神经网络_CIFAR-10进阶图像分类模型(上)
- tensorflow线性模型以及Wide deep learning
- tensorflow线性模型以及Wide deep learning
- 学习笔记:TensorFlow Wide & Deep Learning Tutorial
- java解析json转Map
- CC2640之OAD固件升级(内置Flash)手动配置ImageB
- Android中的Intent
- clipboard.js基本使用
- SVM原理
- Learning Tensorflow (4)
- 一个星期三的下午
- 二叉搜索树的后序遍历序列
- 自定义annotation在Android中的应用
- 国家雷霆出击整治网络内容平台,企业需初心自持
- 完美字符串
- 栈的应用——表达式求值
- Android AIDL Binder框架浅析
- MySQL的引擎之MyISAM和InnoDB