tensorflow1.1/循环神经网络手写数字啊识别
来源:互联网 发布:bat的程序员什么水平 编辑:程序博客网 时间:2024/05/16 12:33
环境:tensorflow 1.1 , python 3 , matplotlib 2.02
#coding:utf-8"""tensorflow 1.1python 3 matplotlib 2.02"""import tensorflow as tfimport input_dataimport numpy as npimport matplotlib.pyplot as plt#设置随机种子tf.set_random_seed(100)np.random.seed(100)BATCH_SIZE = 64TIME_STEP = 28INPUT_SIZE = 28learning_rate = 0.01mnist = input_data.read_data_sets('mnist/',one_hot=True)test_x = mnist.test.images[:2000]test_y = mnist.test.labels[:2000]#查看图片plt.imshow(mnist.test.images[3].reshape((28,28)))plt.title('the picture is %i' %np.argmax(mnist.test.labels[3]),fontdict={'size':16,'color':'red'})plt.show()xs = tf.placeholder(tf.float32,[None,TIME_STEP*INPUT_SIZE])ys = tf.placeholder(tf.int32,[None,10])#输入神经网络前将形状(None,28*28) --->(None,28,28)x = tf.reshape(xs,[-1,TIME_STEP,INPUT_SIZE])#构建循环神经网络#tf.contrib.rnn.BasicLSTMCell()构建循环神经网络的cellrnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=64)#tf.nn.dynamic_rnn返回outputs和states,其中states包含主要state和次要state#输入神经网络的形状(batch,time_step,input)时间参数不在第一个维度,所以time_major=Falseoutputs,states = tf.nn.dynamic_rnn(rnn_cell,x,initial_state=None,dtype=tf.float32,time_major=False)#将最后一个time_step的输出作为输出output = tf.layers.dense(outputs[:,-1,:],10)#计算lossloss = tf.losses.softmax_cross_entropy(onehot_labels=ys,logits=output)train = tf.train.AdamOptimizer(learning_rate).minimize(loss)#计算accuracy,返回两个值acc和uodate_opaccuracy = tf.metrics.accuracy(labels=tf.argmax(ys,axis=1),predictions=tf.argmax(output,axis=1))[1]with tf.Session() as sess: init = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer()) sess.run(init) for step in range(1000): batch_x,batch_y = mnist.train.next_batch(BATCH_SIZE) _,c = sess.run([train,loss],feed_dict={xs:batch_x,ys:batch_y}) if step % 100 == 0: acc = sess.run(accuracy,feed_dict={xs:test_x,ys:test_y}) print('= = = = = = > > > > > >','epoch: ',int(step/100),'train loss : %.4f' %c,'test accuracy: %.3f' %acc)
结果
阅读全文
0 0
- tensorflow1.1/循环神经网络手写数字啊识别
- tensorflow1.1/构建卷积神经网络识别手写数字
- 使用神经网络识别手写数字
- 利用神经网络识别手写数字
- 初识神经网络--识别手写数字
- 神经网络:简单手写数字识别神经网络
- 第1章使用神经网络识别手写数字
- 第一章 用神经网络来识别手写数字(1)
- tensorflow1.1/构建双向神经网络识别mnist
- tensorflow1.1/构建卷积神经网络识别文本
- 神经网络用于手写数字识别更新版
- 卷积神经网络(cnn) 手写数字识别
- 卷积神经网络CNN 手写数字识别
- 用BP人工神经网络识别手写数字
- 卷积神经网络(cnn) 手写数字识别
- 神经网络实现手写数字识别(MNIST)
- 神经网络-tensorflow实现mnist手写数字识别
- 使用神经网络识别手写数字--原理部分
- ++ --
- Android RxJava2.0的简单使用
- Spark成长之路(5)-消息队列
- day08-JavaWeb之http协议request-response
- Android 解决java.lang.IllegalStateException: Can not perform this action after onSaveInstanceState异常
- tensorflow1.1/循环神经网络手写数字啊识别
- URL长链接转短链接
- [译]The Python Tutorial#8. Errors and Exceptions
- 斐波那契数列
- Java语言高编——面向对象-抽象类
- 清除Adnroid (安卓)手机微信浏览器的缓存
- bugku ctf 一段base64 wirteup
- ListView的使用
- 详解数据库中的视图、临时表