mnist_rnn


This example trains a single-layer LSTM on MNIST by treating each 28x28 image as a sequence of 28 time steps with 28 pixels per step. It targets the TensorFlow 1.x graph API (tf.placeholder, tf.contrib.rnn):

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

data_dir = "mnist"
mnist = input_data.read_data_sets(data_dir, one_hot=True)

lr = 0.001
training_iters = 100000
batch_size = 128

# network parameters
n_inputs = 28         # each image row is one input vector
n_steps = 28          # 28 rows -> 28 time steps
n_hidden_units = 128
n_classes = 10

# placeholders
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])

# weights
weights = {
    # (28, 128)
    'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),
    # (128, 10)
    'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))
}
biases = {
    # (128, )
    'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),
    'out': tf.Variable(tf.constant(0.1, shape=[n_classes, ]))
}

# RNN
def RNN(X, weights, biases):
    # X ==> (128 batch * 28 steps, 28 inputs)
    X = tf.reshape(X, [-1, n_inputs])
    # into the hidden layer
    # X_in = (128 batch * 28 steps, 128 hidden)
    X_in = tf.matmul(X, weights['in']) + biases['in']
    # X_in ==> (128 batch, 28 steps, 128 hidden)
    X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])

    # basic LSTM cell; its state is a tuple (c_state, h_state)
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0, state_is_tuple=True)
    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)

    # dynamic_rnn accepts (batch, steps, inputs) when time_major=False,
    # or (steps, batch, inputs) when time_major=True
    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, initial_state=init_state, time_major=False)

    # final_state[1] is the h_state after the last time step,
    # equivalent here to outputs[:, -1, :]
    results = tf.matmul(final_state[1], weights['out']) + biases['out']
    return results

# loss function
pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))

# optimizer
train_op = tf.train.AdamOptimizer(lr).minimize(cost)

# accuracy
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# training
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step = 0
    while step * batch_size < training_iters:
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])
        sess.run([train_op], feed_dict={x: batch_xs, y: batch_ys})
        if step % 20 == 0:
            print(sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys}))
        step += 1

Sample output (accuracy on the current training batch, printed every 20 steps):


0.210938
0.570312
0.757812
0.835938
0.882812
0.882812
0.890625
0.890625
0.898438
0.867188
0.914062
0.929688
0.914062
0.9375
0.898438
0.953125
0.945312
0.882812
0.945312
0.96875
0.9375
0.953125
0.992188
0.96875
0.976562
0.953125
0.945312
0.976562
0.953125
0.96875
0.945312
0.960938
0.945312
0.984375
0.953125
0.96875
0.992188
0.976562
0.960938
0.960938
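
The loop above only reports accuracy on the training batch it just saw. A hypothetical continuation (not in the original post) that measures test-set accuracy could run inside the same tf.Session(), after the training loop. Because init_state is built with zero_state(batch_size, ...), the batch dimension is fixed at 128, so the test images must also be fed in chunks of exactly batch_size:

    # hypothetical addition: evaluate on the held-out test set.
    # zero_state(batch_size, ...) pins the batch dimension at 128, so we
    # walk the 10000 test images in chunks of 128 (dropping the remainder).
    test_accs = []
    n_test_batches = mnist.test.num_examples // batch_size
    for i in range(n_test_batches):
        test_xs, test_ys = mnist.test.next_batch(batch_size)
        test_xs = test_xs.reshape([batch_size, n_steps, n_inputs])
        test_accs.append(sess.run(accuracy, feed_dict={x: test_xs, y: test_ys}))
    print("test accuracy: %.4f" % (sum(test_accs) / len(test_accs)))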

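Both tf.contrib and the tutorials MNIST loader were removed in TensorFlow 2.x. As a rough translation (my sketch, not part of the original post), the same model can be written with tf.keras, carrying over the hyperparameters where possible; LSTM's default unit_forget_bias=True corresponds to forget_bias=1.0 above:

import tensorflow as tf

# same idea on TF 2.x: each 28x28 image is a 28-step sequence of 28-pixel rows
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(28, 28)),           # (n_steps, n_inputs)
    tf.keras.layers.Dense(128),                      # the 'in' projection
    tf.keras.layers.LSTM(128),                       # n_hidden_units = 128
    tf.keras.layers.Dense(10, activation="softmax")  # n_classes = 10
])
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=128, epochs=2,
          validation_data=(x_test, y_test))

Note that load_data() returns integer labels, hence sparse_categorical_crossentropy rather than the one-hot encoding used in the graph version; epochs=2 roughly matches the original 100000 / 128 = 781 training steps over the 55000-image training split.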