单向RNN和双向RNN在mnist数据集上的分类实验

来源:互联网 发布:可以看央视的网络电视 编辑:程序博客网 时间:2024/05/22 06:59

RNN用于图像分类的思路很奇特,不明觉厉:把28×28的图像按行看作长度为28的序列,每个时间步输入一行(28个像素)。具体可以参考相关论文。下面是单向RNN和双向RNN(BiRNN)的实验:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# created by fhqplzj on 2017/06/19 10:28 PM
"""Compare a unidirectional and a bidirectional LSTM classifier on MNIST.

Every image is reshaped to ``(n_steps, n_input)`` and consumed by the
recurrent network one row per time step; the output at the last time step is
projected onto ``n_classes`` logits.  The two models are built, trained and
evaluated one after the other in the same default graph.

The script relies on the graph-mode ``tf.placeholder`` / ``tf.Session`` API
and on ``tensorflow.contrib.rnn``.
"""
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('/Users/fhqplzj/github/TensorFlow-Examples/examples/3_NeuralNetworks/data',
                                  one_hot=True)

# Training hyper-parameters.
learning_rate = 0.001
training_iters = 100000  # upper bound on the number of samples seen (step * batch_size)
batch_size = 128
display_step = 10  # report metrics every `display_step` training steps

# Network dimensions.
n_input = 28  # features fed to the cell at each time step
n_steps = 28  # number of time steps per sample
n_hidden = 128  # LSTM units per direction
n_classes = 10

x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])

# 'out1' projects the unidirectional output (n_hidden wide); 'out2' projects
# the bidirectional output (2 * n_hidden wide).  The bias is shared by both.
weights = {
    'out1': tf.Variable(tf.random_normal([n_hidden, n_classes])),
    'out2': tf.Variable(tf.random_normal([2 * n_hidden, n_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}


def RNN(x, weights, biases):
    """Build a single-direction LSTM classifier and return its logits.

    :param x: input tensor of shape ``[batch, n_steps, n_input]``.
    :param weights: dict providing the ``'out1'`` projection matrix.
    :param biases: dict providing the ``'out'`` bias vector.
    :return: logits tensor computed from the last time step's output.
    """
    # static_rnn expects a list of n_steps tensors, one per time step.
    x = tf.unstack(x, n_steps, 1)
    lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
    return tf.matmul(outputs[-1], weights['out1']) + biases['out']


def BiRNN(x, weights, biases):
    """Build a bidirectional LSTM classifier and return its logits.

    :param x: input tensor of shape ``[batch, n_steps, n_input]``.
    :param weights: dict providing the ``'out2'`` projection matrix.
    :param biases: dict providing the ``'out'`` bias vector.
    :return: logits tensor computed from the last time step's output.
    """
    x = tf.unstack(x, n_steps, 1)
    lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # Only the per-step outputs are needed; the final states are discarded.
    outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x, dtype=tf.float32)
    return tf.matmul(outputs[-1], weights['out2']) + biases['out']


for func in (RNN, BiRNN):
    # `__name__` is available on both Python 2 and 3; `func_name` exists
    # only on Python 2.
    print(func.__name__.center(100, '+'))
    pred = func(x, weights, biases)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    # Created after the model so it also covers the variables added above.
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        step = 1
        while step * batch_size < training_iters:
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            batch_x = batch_x.reshape((-1, n_steps, n_input))
            feed = {
                x: batch_x,
                y: batch_y
            }
            sess.run(optimizer, feed_dict=feed)
            if step % display_step == 0:
                # Fetch both metrics in one run so the forward pass on this
                # batch is evaluated once instead of twice.
                acc, loss = sess.run([accuracy, cost], feed_dict=feed)
                print('acc={:.6f},cost={:.6f}'.format(acc, loss))
            step += 1
        print('Optimization Finished!')
        # Evaluation uses a single batch of `total_len` test samples.
        total_len = 128
        test_x, test_y = mnist.test.next_batch(total_len)
        test_x = test_x.reshape((-1, n_steps, n_input))
        print(sess.run(accuracy, feed_dict={
            x: test_x,
            y: test_y
        }))



原创粉丝点击