nn优化研究(二)

来源:互联网 发布:大数据o2o概念股龙头 编辑:程序博客网 时间:2024/06/05 02:54

在优化问题中我碰到一个很奇怪的现象,在起初迭代的时候准确率还能到95%,迭代几次后就变为了9.8%。不知道是什么鬼,等有结论了再更新。先复现一下程序

from tensorflow.examples.tutorials.mnist import input_dataimport tensorflow as tfsess = tf.InteractiveSession()mnist = input_data.read_data_sets('MNIST_data',one_hot=True)#参数(每个批次数据量的大小)  batch_size = 100  #计算共有多少个批次  n_batch = mnist.train.num_examples // batch_size    #定义两个placeholder  x = tf.placeholder(tf.float32,[None,784])  y = tf.placeholder(tf.float32,[None,10])    #创建神经网络层  w1 = tf.Variable(tf.random_normal([784,100]))  b1 = tf.Variable(tf.zeros([100]))  l1 = tf.nn.tanh(tf.matmul(x,w1)+b1)    w2 = tf.Variable(tf.random_normal([100,10]))  b2 = tf.Variable(tf.zeros([10]))  prediction = tf.nn.softmax(tf.matmul(l1,w2)+b2)    #代价函数  #loss = tf.reduce_mean(tf.square(prediction-y))  loss =-tf.reduce_sum(y*tf.log(prediction))#优化算法  train_step = tf.train.GradientDescentOptimizer(0.02).minimize(loss)    #准确率结果计算  correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))  accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))  #运行模型  with tf.Session() as sess:      sess.run(tf.global_variables_initializer())      for epoch in range(100):          for batch in range(n_batch):              batch_xs,batch_ys = mnist.train.next_batch(batch_size)              sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})          trainacc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})         acc = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})        if (epoch+1)%10 == 0:              print ("epoch= "+str(epoch)+" accuray= "+str(acc) + " train accuray="+str(trainacc))epoch= 9 accuray= 0.966564 train accuray=0.9398epoch= 19 accuray= 0.991164 train accuray=0.9546epoch= 29 accuray= 0.998891 train accuray=0.954epoch= 39 accuray= 0.999764 train accuray=0.9523epoch= 49 accuray= 0.0989818 train accuray=0.098epoch= 59 accuray= 0.0989818 train accuray=0.098epoch= 69 accuray= 0.0989818 train accuray=0.098epoch= 79 accuray= 0.0989818 train accuray=0.098epoch= 89 accuray= 0.0989818 train accuray=0.098epoch= 99 accuray= 0.0989818 train accuray=0.098



0 0
原创粉丝点击