tensorflow 实现rnn

来源:互联网 发布:帝国cms快速仿站 编辑:程序博客网 时间:2024/06/05 10:10

本文主要用于讨论官方自带的一个例子,rnn,用于多对一的情况,也就是用于图像分类,
先看下tensorflow的版本,下面的代码在低版本中可能运行不了:


  1. In [1]: import tensorflow as tf
  2. In [2]: tf.__version__
  3. Out[2]: '1.1.0-rc0'

代码:

  1. import tensorflow as tf
  2. from tensorflow.contrib import rnn
  3. from tensorflow.examples.tutorials.mnist import input_data
  4. mnist=input_data.read_data_sets("../MNIST_data",one_hot=True)
  5. lr=0.001
  6. training_iters=1000000
  7. batch_size=128
  8. n_inputs=28
  9. n_step=28
  10. n_hidden_unis=128
  11. n_classes=10
  12. x=tf.placeholder(tf.float32,[None,n_step,n_inputs])
  13. y=tf.placeholder(tf.float32,[None,n_classes])
  14. weights={
  15. 'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_unis])),
  16. 'out':tf.Variable(tf.random_normal([n_hidden_unis,n_classes]))
  17. }
  18. biases={
  19. 'in':tf.Variable(tf.constant(0.1,shape=[n_hidden_unis,])),
  20. 'out':tf.Variable(tf.constant(0.1,shape=[n_classes,]))
  21. }
  22. def RNN(_X,weights,biases):
  23. _X = tf.transpose(_X, [1, 0, 2]) # permute n_steps and batch_size
  24. _X = tf.reshape(_X, [-1, n_inputs]) # (n_steps*batch_size, n_input)
  25. _X = tf.matmul(_X, weights['in']) + biases['in']
  26. lstm_cell = rnn.BasicLSTMCell(n_hidden_unis, forget_bias=1.0)
  27. _init_state=lstm_cell.zero_state(batch_size,dtype=tf.float32)
  28. _X = tf.split(_X, n_step,0 ) # n_steps * (batch_size, n_hidden)
  29. outputs, states = rnn.static_rnn(lstm_cell, _X, initial_state=_init_state)
  30. # Get inner loop last output
  31. return tf.matmul(outputs[-1], weights['out']) + biases['out']
  32. pred=RNN(x,weights,biases)
  33. cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
  34. train_op=tf.train.AdagradOptimizer(lr).minimize(cost)
  35. correct_pred=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
  36. accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))
  37. init=tf.global_variables_initializer()
  38. with tf.Session() as sess:
  39. sess.run(init)
  40. step=0
  41. while step*batch_size<training_iters:
  42. batch_xs,batch_ys=mnist.train.next_batch(batch_size)
  43. batch_xs=batch_xs.reshape([batch_size,n_step,n_inputs])
  44. sess.run(train_op,feed_dict={
  45. x:batch_xs,
  46. y:batch_ys
  47. })
  48. if step%20==0:
  49. print(sess.run(accuracy,feed_dict={
  50. x: batch_xs,
  51. y: batch_ys
  52. }))
  53. step+=1

结果:

  1. 0.921875
  2. 0.953125
  3. 0.96875
  4. 0.953125
  5. 0.960938
  6. 0.976562
  7. 0.96875
  8. 0.960938
  9. 0.960938
  10. 0.945312
  11. 0.960938
原创粉丝点击