mnist_rnn
来源:互联网 发布:淘宝查购买顺序 编辑:程序博客网 时间:2024/06/17 05:26
# Train a single-layer LSTM classifier on MNIST (TensorFlow 1.x graph API).
# Each 28x28 image is fed to the LSTM as a sequence of n_steps rows, each row
# an n_inputs-pixel vector; class logits come from the final hidden state.
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

data_dir = "mnist"
mnist = input_data.read_data_sets(data_dir, one_hot=True)

# Hyperparameters
lr = 0.001
training_iters = 100000  # total number of training examples to consume
batch_size = 128

# Network parameters
n_inputs = 28         # pixels per image row (one time-step input)
n_steps = 28          # image rows, treated as time steps
n_hidden_units = 128
n_classes = 10

# Placeholders: x is (batch, n_steps, n_inputs); y holds one-hot labels.
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])

weights = {
    # (28, 128): projects each input row into the hidden dimension
    'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),
    # (128, 10): maps the final hidden state to class logits
    'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes])),
}
biases = {
    # (128,)
    'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),
    # (10,)
    'out': tf.Variable(tf.constant(0.1, shape=[n_classes, ])),
}


def RNN(X, weights, biases):
    """Build the LSTM graph and return unnormalized class logits.

    Args:
        X: float32 tensor of shape (batch, n_steps, n_inputs). The batch
            dimension may be any size.
        weights: dict with 'in' of shape (n_inputs, n_hidden_units) and
            'out' of shape (n_hidden_units, n_classes).
        biases: dict with 'in' of shape (n_hidden_units,) and 'out' of
            shape (n_classes,).

    Returns:
        Tensor of shape (batch, n_classes) holding the logits.
    """
    # Flatten to (batch * n_steps, n_inputs) so a single matmul projects
    # every row of every image.
    X = tf.reshape(X, [-1, n_inputs])
    X_in = tf.matmul(X, weights['in']) + biases['in']
    # Restore the time axis: (batch, n_steps, n_hidden_units).
    X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])

    lstm_cell = tf.contrib.rnn.BasicLSTMCell(
        n_hidden_units, forget_bias=1.0, state_is_tuple=True)
    # No explicit initial_state: dynamic_rnn creates a zero state matching the
    # batch actually fed in. A fixed zero_state(batch_size, ...) would pin the
    # graph to exactly `batch_size` examples and break evaluation on any other
    # batch size (e.g. the full test set).
    _, final_state = tf.nn.dynamic_rnn(
        lstm_cell, X_in, dtype=tf.float32, time_major=False)
    # With state_is_tuple=True the state is an LSTMStateTuple (c, h);
    # classify from the hidden state h.
    return tf.matmul(final_state.h, weights['out']) + biases['out']


# Loss and optimizer
pred = RNN(x, weights, biases)
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)

# Accuracy: fraction of examples whose argmax logit matches the label.
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step = 0
    while step * batch_size < training_iters:
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])
        sess.run([train_op], feed_dict={x: batch_xs, y: batch_ys})
        if step % 20 == 0:
            # Accuracy on the batch just trained on -- an optimistic estimate.
            # print(...) with a single argument works on Python 2 and 3.
            print(sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys}))
        step += 1

    # Held-out evaluation; possible because the graph accepts any batch size.
    test_xs = mnist.test.images.reshape([-1, n_steps, n_inputs])
    test_acc = sess.run(accuracy,
                        feed_dict={x: test_xs, y: mnist.test.labels})
    print("test accuracy: %.4f" % test_acc)
0.210938
0.570312
0.757812
0.835938
0.882812
0.882812
0.890625
0.890625
0.898438
0.867188
0.914062
0.929688
0.914062
0.9375
0.898438
0.953125
0.945312
0.882812
0.945312
0.96875
0.9375
0.953125
0.992188
0.96875
0.976562
0.953125
0.945312
0.976562
0.953125
0.96875
0.945312
0.960938
0.945312
0.984375
0.953125
0.96875
0.992188
0.976562
0.960938
0.960938
阅读全文
0 0
- mnist_rnn
- 多版本python如何使用pip
- win10 ssd+普通硬盘安装centos7 无法引导启动linux
- 软件开发过程大观——软件开发过程改进为什么能帮助软件质量提升?
- gulp安装以及使用
- python通过sublime运行不同版本python
- mnist_rnn
- SDUT-3402 数据结构实验之排序五:归并求逆序数
- Access迁移Mysql错误07002:1-3010:[Microsoft][ODBC Microsoft
- python使用sublime中文输出问题
- JVM内存回收策略
- python爬虫中文输出问题以及不即时输出问题
- maven的阿里镜像
- Could not read from System Tables. You must grant SELECT access on all system tables for the databas
- python使用xlwt和xlrd模块操作excel