nn优化研究(二)
来源:互联网 发布:大数据o2o概念股龙头 编辑:程序博客网 时间:2024/06/05 02:54
在优化问题中我碰到一个很奇怪的现象,在起初迭代的时候准确率还能到95%,迭代几次后就变为了9.8%。不知道是什么鬼,等有结论了再更新。先复现一下程序
from tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tfsess = tf.InteractiveSession()mnist = input_data.read_data_sets('MNIST_data',one_hot=True)#参数(每个批次数据量的大小) batch_size = 100 #计算共有多少个批次 n_batch = mnist.train.num_examples // batch_size #定义两个placeholder x = tf.placeholder(tf.float32,[None,784]) y = tf.placeholder(tf.float32,[None,10]) #创建神经网络层 w1 = tf.Variable(tf.random_normal([784,100])) b1 = tf.Variable(tf.zeros([100])) l1 = tf.nn.tanh(tf.matmul(x,w1)+b1) w2 = tf.Variable(tf.random_normal([100,10])) b2 = tf.Variable(tf.zeros([10])) prediction = tf.nn.softmax(tf.matmul(l1,w2)+b2) #代价函数 #loss = tf.reduce_mean(tf.square(prediction-y)) loss =-tf.reduce_sum(y*tf.log(prediction))#优化算法 train_step = tf.train.GradientDescentOptimizer(0.02).minimize(loss) #准确率结果计算 correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #运行模型 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for epoch in range(100): for batch in range(n_batch): batch_xs,batch_ys = mnist.train.next_batch(batch_size) sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) trainacc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) acc = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels}) if (epoch+1)%10 == 0: print ("epoch= "+str(epoch)+" accuray= "+str(acc) + " train accuray="+str(trainacc))epoch= 9 accuray= 0.966564 train accuray=0.9398epoch= 19 accuray= 0.991164 train accuray=0.9546epoch= 29 accuray= 0.998891 train accuray=0.954epoch= 39 accuray= 0.999764 train accuray=0.9523epoch= 49 accuray= 0.0989818 train accuray=0.098epoch= 59 accuray= 0.0989818 train accuray=0.098epoch= 69 accuray= 0.0989818 train accuray=0.098epoch= 79 accuray= 0.0989818 train accuray=0.098epoch= 89 accuray= 0.0989818 train accuray=0.098epoch= 99 accuray= 0.0989818 train accuray=0.098
0 0
- nn优化研究(二)
- nn优化研究
- tensorflow(二):tf.nn.conv2d
- torch学习(二) nn类结构-Module
- 应用机器学习(二):k-NN 分类器
- 智力研究(二)
- 智力题研究(二)
- 智力题研究(二)
- CEF研究(二)
- Acoustic研究(二)
- CEF研究(二)
- NN(BP)算法
- tf.nn.sparse_softmax_cross_entropy_with_logits()
- nn
- nn
- nn
- 神经网络的前向传播和误差反向传播(NN,RNN,LSTM)(二)
- 优化GeoServer的运行------GeoServer研究随笔二
- 批量复制文件并改名
- 【GDOI2017第四轮模拟day2】绝版题
- LA 3708
- [JZOJ5088]最小边权和
- phpmysql登陆报错 #1862
- nn优化研究(二)
- Beam学习笔记(3):Flink Streaming Pipeline Translator
- HDU 3697 A hard Aoshu Problem (搜索)
- 《Redis实战》读后感
- Windows下安装TensorFlow
- 浅析name==null, "".equals(name)和name.length==0三者的区别
- 简介Opencv在Python中的使用
- c++单元测试框架Catch
- 顺序表(约瑟夫环)