自学笔记:LSTM理论联系实际的TENSORFLOW代码研究, state和ouput的数据结构

来源:互联网 发布:淘宝网长袖连衣裙 编辑:程序博客网 时间:2024/05/21 14:04

看了一些讲述LSTM的原理,基本上搞清楚了,不过需要理解代码还有一段路要走。

网上找了一个最简单的示例,不过无法在tensorflow1.3运行,花了一些时间,解决了兼容问题,下面的代码可以运行。

#运行版本,tensorflow1.4 
#源代码从网上COPY的示例改编,原来的示例无法在tensorflow1.3以上运行
import tensorflow as tf;  
import numpy as np;  
  
units_num = 10  #隐藏层节点
batch_size = 2    #训练批次


X = tf.random_normal(shape=[batch_size,5,7], dtype=tf.float32)  
#X = tf.reshape(X, [-1, 5, 6])  


def cell():  #一定要定义成函数,否则出错, tensorflow1.X不兼容
    return tf.nn.rnn_cell.BasicLSTMCell(units_num) #10代表10个节点
    #也可以换成别的,比如GRUCell,BasicRNNCell等等
  
lstm_multi = tf.nn.rnn_cell.MultiRNNCell([cell() for _ in range(2)],  state_is_tuple=True)


#mlstm_cell = rnn.MultiRNNCell([lstm_cell() for _ in range(layer_num)],state_is_tuple=True)
  
state = lstm_multi.zero_state(batch_size, tf.float32)  
output, state = tf.nn.dynamic_rnn(lstm_multi, X, initial_state=state, time_major=False)  
finaloutput = output[-1]   #取最后一个output
statec= state[-1].c             #取最后一个state
stateh = state[-1].h
with tf.Session() as sess:  
    sess.run(tf.initialize_all_variables())
    print('----------------output--------------------')
    print(sess.run(output))         #2*5*10
    print('----------------finaloutput--------------------')    
    print(sess.run(finaloutput))  #5*10
    print('----------------state--------------------')    
    print(sess.run(state))              #2*4*10
    print('-----------------statec-------------------')
    print(sess.run(statec))           #2*10
    print('----------------stateh--------------------')    
    print(sess.run(stateh))          #2*10
    
    '''
    WARNING:tensorflow:From C:\Program Files\Python36\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
----------------output--------------------
[[[-0.00637375 -0.02218931 -0.00364614  0.00871621  0.01000606  0.00688382
    0.0173985  -0.01417144 -0.00514521  0.00809157]
  [-0.02739093  0.00731545 -0.01414015  0.00776639  0.01201528 -0.00530205
   -0.00373759 -0.00243396  0.01967375  0.0176452 ]
  [-0.05032041  0.01292216 -0.01505091  0.00758813  0.00881143  0.0026364
    0.00501611 -0.01001035  0.01516667  0.03455975]
  [-0.03608039 -0.01159106 -0.00147149  0.01253749  0.01457058  0.00761829
    0.0266578  -0.0391868  -0.01949262  0.03059825]
  [-0.01878959 -0.0512354  -0.00916542  0.0181381   0.01850957 -0.00772178
    0.05093804 -0.04717548 -0.03797766  0.00213585]]


 [[-0.00301182 -0.04609392 -0.04139089 -0.00270795 -0.01418001 -0.02598612
    0.03196315  0.01515245  0.00444449 -0.02480585]
  [ 0.0216604  -0.07657862 -0.0436976  -0.00264414 -0.01624844 -0.0388252
    0.04456366 -0.00641088 -0.00966516 -0.02996463]
  [ 0.03529613 -0.06516693 -0.038265   -0.0112583  -0.01810452 -0.03949273
    0.0325823  -0.00466524 -0.00729331 -0.03314594]
  [ 0.05150301 -0.08670978 -0.06885949 -0.02202794 -0.05140235 -0.05352144
    0.03970223  0.02086095 -0.00696997 -0.06665371]
  [ 0.05955734 -0.06664775 -0.06081596  0.00371452 -0.03970711 -0.02188079
    0.03479016  0.03526397  0.00916961 -0.06597719]]]
----------------finaloutput--------------------
[[-0.04641263  0.01628485 -0.03593837 -0.01126753 -0.02915012 -0.00858944
  -0.00375725  0.03282307  0.01455755 -0.00436677]
 [-0.07043284  0.02731676 -0.04114634 -0.02013678 -0.04268811 -0.02129938
   0.00629975  0.02324007  0.00693215  0.01216679]
 [-0.07693886  0.01834853 -0.04679859 -0.03363486 -0.0439391  -0.06869462
   0.00808731  0.0087936   0.00530997  0.02094404]
 [-0.05964042 -0.02323203 -0.01288006 -0.02311298 -0.00924513 -0.1003315
   0.02314057 -0.02797622 -0.01521164  0.03499829]
 [-0.05991254 -0.04570793 -0.02727396 -0.02612418 -0.00901282 -0.12065098
   0.02534271 -0.00855778 -0.00444907  0.01586689]]
----------------state--------------------
(LSTMStateTuple(c=array([[-0.66914773,  0.07839034,  0.33994618,  0.8456924 ,  0.10953232,
        -0.39234626,  0.13399592, -0.15174234,  0.00130345, -0.30228698],
       [-0.38686785, -0.13705038, -0.1424486 ,  0.0899242 ,  0.24239531,
         0.08536939, -0.27032876, -0.24645516, -0.00084634,  0.18780154]], dtype=float32), h=array([[ -2.64760315e-01,   4.65154909e-02,   1.34974703e-01,
          3.55697572e-01,   3.73032168e-02,  -2.32693255e-01,
          7.82607347e-02,  -6.97162896e-02,   3.86156840e-04,
         -1.59545243e-01],
       [ -1.43306926e-01,  -7.28544220e-02,  -6.16334565e-02,
          4.14225310e-02,   1.12995833e-01,   4.02112342e-02,
         -1.64336443e-01,  -1.15731291e-01,  -2.68079224e-04,
          8.61014202e-02]], dtype=float32)), LSTMStateTuple(c=array([[-0.12525879, -0.02339883, -0.15596688,  0.09201161, -0.13064483,
         0.1607396 ,  0.14479131,  0.06554902, -0.0195029 , -0.08419077],
       [-0.02751132,  0.08955061,  0.07464057,  0.07965946,  0.10505884,
         0.16445892, -0.04261293, -0.04892696,  0.03463147,  0.07745198]], dtype=float32), h=array([[-0.06212955, -0.01064036, -0.08202581,  0.04435679, -0.06457365,
         0.07770626,  0.0736967 ,  0.03266109, -0.00938224, -0.04352244],
       [-0.01347394,  0.04540392,  0.03609779,  0.04107576,  0.05197879,
         0.08114899, -0.02111933, -0.02465794,  0.01692747,  0.03863694]], dtype=float32)))
-----------------statec-------------------
[[-0.14902619  0.21614444 -0.05753845 -0.07061018 -0.17564972 -0.08847831
  -0.07436903  0.09744829  0.07963906  0.06430003]
 [ 0.09381064 -0.0832063   0.00594144  0.05869648  0.03069258  0.02718895
   0.00898496 -0.03545038  0.04609191  0.00290151]]
----------------stateh--------------------
[[-0.00620752  0.03671843  0.07395449  0.00736501  0.0685369   0.03242806
  -0.03287118 -0.08238971 -0.01176278  0.08281735]
 [ 0.04372304 -0.07892949 -0.01827243 -0.02095079 -0.01729526 -0.05310233
   0.02035648 -0.05144129 -0.0138404  -0.01711897]]


    '''

总结:

输出维度是10维,输入的维度2*5*7 输出变成2*5*10

output是所有batch的数据2*5*10

state也是所有batch的数据2*10 批次乘以输出节点

state数据分为state_c和state_h,分别对应的含义如图的Ht和Ht的数据:



state_c和state_h分别是2*10的数组,为什么是2*10而不是1*10,如图猜测是state_h一份传给输出output,一份传给下一层的状态state_h, state_c一份传给下一层的state_c,一份传给下一层的state_h

原创粉丝点击