Bidirectional LSTM Implementation


This post is again about image classification, using the MNIST data. First check the TensorFlow version, since the code below may not run on lower versions:


  In [1]: import tensorflow as tf
  In [2]: tf.__version__
  Out[2]: '1.2.1'


I implemented this with the dynamic RNN. In code, a bidirectional RNN is not much different from a single-direction RNN; the main point is to merge the final outputs with tf.concat(outputs, 2). A quick shape sketch follows, then the full implementation.
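To see why tf.concat(outputs, 2) is the right call, here is a minimal standalone sketch (my own illustration, not part of the original listing): tf.nn.bidirectional_dynamic_rnn returns its outputs as a pair (output_fw, output_bw), each shaped [batch, time, n_hidden], and concatenating on axis 2 stacks the two directions feature-wise.

  import tensorflow as tf

  x = tf.placeholder(tf.float32, [None, 28, 28])  # [batch, time, features]
  fw = tf.nn.rnn_cell.BasicLSTMCell(128)
  bw = tf.nn.rnn_cell.BasicLSTMCell(128)
  # outputs is the pair (output_fw, output_bw), each [batch, 28, 128]
  outputs, _ = tf.nn.bidirectional_dynamic_rnn(fw, bw, x, dtype=tf.float32)
  merged = tf.concat(outputs, 2)  # feature-wise concat -> [batch, 28, 256]
  print(merged.get_shape())       # prints (?, 28, 256)

With the shapes clear, here is the full code: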

  import tensorflow as tf
  import numpy as np
  from tensorflow.examples.tutorials.mnist import input_data

  mnist = input_data.read_data_sets("../MNIST_data", one_hot=True)

  # Parameter settings
  learning_rate = 0.001
  training_iters = 100000
  batch_size = 128
  display_step = 10

  # Network Parameters
  n_input = 28    # MNIST data input (img shape: 28*28)
  n_steps = 28    # timesteps
  n_hidden = 128  # hidden layer num of features
  n_class = 10    # MNIST total classes (0-9 digits)

  # tf graph input
  x = tf.placeholder(tf.float32, [None, n_steps, n_input])
  # TensorFlow LSTM cell requires 2x n_hidden length (state & cell);
  # these initial-state placeholders are only consumed by the static
  # variant kept in comments below -- the dynamic path starts from zeros.
  istate_fw = tf.placeholder(tf.float32, [None, 2 * n_hidden])
  istate_bw = tf.placeholder(tf.float32, [None, 2 * n_hidden])
  y = tf.placeholder(tf.float32, [None, n_class])

  # Define weights ('hidden' is likewise only used by the static variant)
  weights = {
      'hidden': tf.Variable(tf.random_normal([n_input, 2 * n_hidden])),
      'out': tf.Variable(tf.random_normal([2 * n_hidden, n_class]))
  }
  biases = {
      'hidden': tf.Variable(tf.random_normal([2 * n_hidden])),
      'out': tf.Variable(tf.random_normal([n_class]))
  }

  def BiRNN(_X, _istate_fw, _istate_bw, _weights, _biases, _batch_size, _seq_len):
      # Per-example sequence lengths (only needed by the static variant)
      _seq_len = tf.fill([_batch_size], tf.constant(_seq_len, dtype=tf.int64))

      # Static variant, kept for reference:
      # _X = tf.transpose(_X, [1, 0, 2])
      # _X = tf.reshape(_X, [-1, n_input])
      # _X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']
      # lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
      # lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
      # # Split data because the static rnn needs a list of inputs for the RNN inner loop
      # _X = tf.split(_X, n_steps, 0)
      # # Get lstm cell output
      # outputs = tf.nn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, _X,
      #                                          initial_state_fw=_istate_fw,
      #                                          initial_state_bw=_istate_bw,
      #                                          sequence_length=_seq_len)

      # Define lstm cells with tensorflow
      lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
      # Backward direction cell
      lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
      outputs, output_states = tf.nn.bidirectional_dynamic_rnn(
          lstm_fw_cell, lstm_bw_cell, _X, dtype=tf.float32)
      # outputs is (output_fw, output_bw); concat on axis 2 -> [batch, time, 2*n_hidden]
      outputs = tf.concat(outputs, 2)
      # Reorder to [time, batch, 2*n_hidden] so outputs[-1] is the last time step
      outputs = tf.transpose(outputs, [1, 0, 2])
      return tf.matmul(outputs[-1], _weights['out']) + _biases['out']

  pred = BiRNN(x, istate_fw, istate_bw, weights, biases, batch_size, n_steps)
  cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
  optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

  # Evaluate model
  correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

  init = tf.global_variables_initializer()

  with tf.Session() as sess:
      sess.run(init)
      step = 1
      while step * batch_size < training_iters:
          batch_xs, batch_ys = mnist.train.next_batch(batch_size)
          batch_xs = batch_xs.reshape(batch_size, n_steps, n_input)
          sess.run(optimizer, feed_dict={
              x: batch_xs, y: batch_ys,
              istate_fw: np.zeros((batch_size, 2 * n_hidden)),
              istate_bw: np.zeros((batch_size, 2 * n_hidden))
          })
          if step % display_step == 0:
              # Calculate batch accuracy
              acc = sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys,
                                                  istate_fw: np.zeros((batch_size, 2 * n_hidden)),
                                                  istate_bw: np.zeros((batch_size, 2 * n_hidden))})
              # Calculate batch loss
              loss = sess.run(cost, feed_dict={x: batch_xs, y: batch_ys,
                                               istate_fw: np.zeros((batch_size, 2 * n_hidden)),
                                               istate_bw: np.zeros((batch_size, 2 * n_hidden))})
              print("Iter " + str(step * batch_size) + ", Minibatch Loss= " + "{:.6f}".format(loss) +
                    ", Training Accuracy= " + "{:.5f}".format(acc))
          step += 1
      print("Optimization Finished!")
      # Calculate accuracy for 128 mnist test images
      test_len = 128
      test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
      test_label = mnist.test.labels[:test_len]
      print("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label,
                                                               istate_fw: np.zeros((test_len, 2 * n_hidden)),
                                                               istate_bw: np.zeros((test_len, 2 * n_hidden))}))
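One design note on the last lines of BiRNN (an aside of my own, not from the original post): after the transpose, outputs[-1] is the concatenated forward/backward output at the final time step, but at that step the backward cell has only seen one input. An alternative sketch, under the same cell and weight definitions as above, is to classify from the returned final states, since state_fw.h and state_bw.h have each consumed the whole sequence in their own direction:

  # Hypothetical replacement for the tail of BiRNN, assuming the same
  # cells, inputs, and weight shapes as in the listing above:
  outputs, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
      lstm_fw_cell, lstm_bw_cell, _X, dtype=tf.float32)
  # state_*.h is each LSTMStateTuple's hidden state after its full pass
  last = tf.concat([state_fw.h, state_bw.h], 1)  # [batch, 2*n_hidden]
  return tf.matmul(last, _weights['out']) + _biases['out']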


Result:

  Iter 93440, Minibatch Loss= 0.107338, Training Accuracy= 0.96875
  Iter 94720, Minibatch Loss= 0.043600, Training Accuracy= 0.98438
  Iter 96000, Minibatch Loss= 0.085884, Training Accuracy= 0.97656
  Iter 97280, Minibatch Loss= 0.113962, Training Accuracy= 0.96094
  Iter 98560, Minibatch Loss= 0.135330, Training Accuracy= 0.95312
  Iter 99840, Minibatch Loss= 0.130067, Training Accuracy= 0.97656
  Optimization Finished!
  Testing Accuracy: 0.984375