第十一课 tensorflow RNN原理及解析

来源:互联网 发布:淘宝详情页是什么意思 编辑:程序博客网 时间:2024/05/21 17:45

RNN

原理解析

RNN的一层结构图如下:

rnn cell link

  • W: state 权重
  • U: 输入权重
  • V: 输出权重
  • xt: t个step的输入
  • ht: rnn cell 隐层, 也有的叫st (状态层)
  • ot: 最终的输出
  • yt: 经过soft max之后的分类结果

数学关系:
rnn 公式
关于维度:
* ht维度: 是隐层的数量,也是自定义的, shape=[hn✖️1]
* W维度: shape=[hn, hn],才能保证W✖️ht-1的维度与ht的维度一样[hn✖️1]
* xt维度: 是embeding时候自定义的,shape=[xn✖️1]
* U维度: shape=[hn, xn],保证U✖️xt 与 ht的维度是一样的 [hn, 1]
* ot: 与xt的输出维度是一样的, 所以V必须是[xn, hn]维度.理论上也可以不一样,对于多层的RNN来说,只要保证最后一层ot与xt的输出是一样的就可以了。所以V其实可以是任意的[vn, hn]这样就可以了,在最后一层变成[xn, hn]就可以了,保证与xt一样的维度就可以了.

rnn cell的实现

通过上面的原理介绍,如果实现一个rnn cell,函数描述如下:

名称 类别 描述 对应变量 xt 输入参数 当前时刻的embeding变量 xt ht-1 输入参数 前一个时刻的隐层变量输出 ht-1 ot 返回值 返回的target, 与x的维度是一样的 ot ht 返回值 当前时刻产生的新的隐层变量 ht

在tensorflow对应的类是: tf.nn.rnn_cell.BasicRNNCell.

BasicRNNCell

init参数描述:

  • num_units: 隐层单元的数量,这是自己定义的。也就是前面维度中所说的hn, 隐层的单元数是自己定义的.
  • activation: 激活函数,双曲正切
  • reuse: True, 表示共享变量,多个cell 是共享权重的.

call参数描述:

  • inputs: 也就是xt
  • state: 也就是前一个ht-1, 注意ht-1的维度要与num_units吻合因为二者本来就是一个.
  • ht, ht: 返回的都是ht,与前面描述 中返回的是ot, ht,因为ot也是ht变换得来的ot=C+V✖️ht,所以这里返回是两个ht.

demo:

import tensorflow as tfimport numpy as npimport loggingfrom tensorflow.python.ops import array_opslogging.basicConfig(        level=logging.INFO,        format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s [%(filename)s:%(lineno)d]"    )
def test_rnn_cell():    num_units = 2    state_size = num_units # state_size也就是ht的size一定要与num_units是一样的    batch_size = 1    input_size = 4    x = array_ops.zeros([batch_size, input_size])    m = array_ops.zeros([batch_size, state_size])    logging.info('x type: ' + str(type(x)) + ': ' + str(x.shape))    logging.info('m type: ' + str(type(m)) + ': ' + str(m.shape))    # m = (array_ops.zeros([batch_size]), array_ops.zeros([batch_size]))    with tf.Session() as sess:        g, out_m = tf.nn.rnn_cell.BasicRNNCell(num_units)(x, m)        sess.run([tf.global_variables_initializer()])        # g_result == out_m_result 二者是同一个        g_result, out_m_result = sess.run([g, out_m],                                          {x.name: 1 * np.ones([batch_size, input_size]),                                           m.name: 0.1 * np.ones([batch_size, state_size])})        logging.info('g_result: ' + str(g_result))        logging.info('out_m_result: ' + str(out_m_result))
test_rnn_cell()
[2017-10-13 15:37:41,889] root:INFO: x type: <class 'tensorflow.python.framework.ops.Tensor'>: (1, 4) [<ipython-input-4-0e1f199b1a5b>:10][2017-10-13 15:37:41,890] root:INFO: m type: <class 'tensorflow.python.framework.ops.Tensor'>: (1, 2) [<ipython-input-4-0e1f199b1a5b>:11][2017-10-13 15:37:42,042] root:INFO: g_result: [[-0.7603032   0.54075867]] [<ipython-input-4-0e1f199b1a5b>:22][2017-10-13 15:37:42,043] root:INFO: out_m_result: [[-0.7603032   0.54075867]] [<ipython-input-4-0e1f199b1a5b>:23]
原创粉丝点击