第十一课 tensorflow RNN原理及解析
来源:互联网 发布:淘宝详情页是什么意思 编辑:程序博客网 时间:2024/05/21 17:45
RNN
原理解析
RNN的一层结构图如下:
- W: state 权重
- U: 输入权重
- V: 输出权重
- xt: t个step的输入
- ht: rnn cell 隐层, 也有的叫st (状态层)
- ot: 最终的输出
- yt: 经过soft max之后的分类结果
数学关系:
关于维度:
* ht维度: 是隐层的数量,也是自定义的, shape=[hn✖️1]
* W维度: shape=[hn, hn],才能保证W✖️ht-1的维度与ht的维度一样[hn✖️1]
* xt维度: 是embeding时候自定义的,shape=[xn✖️1]
* U维度: shape=[hn, xn],保证U✖️xt 与 ht的维度是一样的 [hn, 1]
* ot: 与xt的输出维度是一样的, 所以V必须是[xn, hn]维度.理论上也可以不一样,对于多层的RNN来说,只要保证最后一层ot与xt的输出是一样的就可以了。所以V其实可以是任意的[vn, hn]这样就可以了,在最后一层变成[xn, hn]就可以了,保证与xt一样的维度就可以了.
rnn cell的实现
通过上面的原理介绍,如果实现一个rnn cell,函数描述如下:
在tensorflow对应的类是: tf.nn.rnn_cell.BasicRNNCell.
BasicRNNCell
init参数描述:
- num_units: 隐层单元的数量,这是自己定义的。也就是前面维度中所说的hn, 隐层的单元数是自己定义的.
- activation: 激活函数,双曲正切
- reuse: True, 表示共享变量,多个cell 是共享权重的.
call参数描述:
- inputs: 也就是xt
- state: 也就是前一个ht-1, 注意ht-1的维度要与num_units吻合因为二者本来就是一个.
- ht, ht: 返回的都是ht,与前面描述 中返回的是ot, ht,因为ot也是ht变换得来的ot=C+V✖️ht,所以这里返回是两个ht.
demo:
import tensorflow as tfimport numpy as npimport loggingfrom tensorflow.python.ops import array_opslogging.basicConfig( level=logging.INFO, format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s [%(filename)s:%(lineno)d]" )
def test_rnn_cell(): num_units = 2 state_size = num_units # state_size也就是ht的size一定要与num_units是一样的 batch_size = 1 input_size = 4 x = array_ops.zeros([batch_size, input_size]) m = array_ops.zeros([batch_size, state_size]) logging.info('x type: ' + str(type(x)) + ': ' + str(x.shape)) logging.info('m type: ' + str(type(m)) + ': ' + str(m.shape)) # m = (array_ops.zeros([batch_size]), array_ops.zeros([batch_size])) with tf.Session() as sess: g, out_m = tf.nn.rnn_cell.BasicRNNCell(num_units)(x, m) sess.run([tf.global_variables_initializer()]) # g_result == out_m_result 二者是同一个 g_result, out_m_result = sess.run([g, out_m], {x.name: 1 * np.ones([batch_size, input_size]), m.name: 0.1 * np.ones([batch_size, state_size])}) logging.info('g_result: ' + str(g_result)) logging.info('out_m_result: ' + str(out_m_result))
test_rnn_cell()
[2017-10-13 15:37:41,889] root:INFO: x type: <class 'tensorflow.python.framework.ops.Tensor'>: (1, 4) [<ipython-input-4-0e1f199b1a5b>:10][2017-10-13 15:37:41,890] root:INFO: m type: <class 'tensorflow.python.framework.ops.Tensor'>: (1, 2) [<ipython-input-4-0e1f199b1a5b>:11][2017-10-13 15:37:42,042] root:INFO: g_result: [[-0.7603032 0.54075867]] [<ipython-input-4-0e1f199b1a5b>:22][2017-10-13 15:37:42,043] root:INFO: out_m_result: [[-0.7603032 0.54075867]] [<ipython-input-4-0e1f199b1a5b>:23]
阅读全文
0 0
- 第十一课 tensorflow RNN原理及解析
- Tensorflow源码解析系列--RNN
- Tensorflow RNN源代码解析笔记2:RNN的基本实现
- RNN的原理与TensorFlow代码实现
- Tensorflow RNN源代码解析笔记1:RNNCell的基本实现
- Tensorflow学习笔记--RNN精要及代码实现
- RNN入门详解及TensorFlow源码实现--深度学习笔记
- 解读tensorflow之rnn
- 解读tensorflow之rnn
- tensorflow 循环神经网络RNN
- tensorflow之RNN
- tensorflow RNN实例
- TensorFlow MNIST RNN LSTM
- tensorflow rnn阅读笔记
- tensorflow-rnn代码解读
- tensorflow 实现rnn
- tensorflow 循环神经网络RNN
- Tensorflow-rnn(mnist分类)
- java.ulti中的接口和抽象类梳理与分析
- 探讨margin-top的bug
- 常用的git命令
- arse Error at line 58 column 17: The content of element type "struts-config" must match "(display-na
- springmvc中如何配置控制台输出日志
- 第十一课 tensorflow RNN原理及解析
- 面试记录第二十节——(MVP讲解)
- 最新版管家婆辉煌版普及版II TOP+ V12.71单机、网络、门店破解
- C# 导出 Excel 和相关打印设置
- git 操作的一些方式
- 使用Apache的ab工具进行压力测试
- Communications link failure
- redis的持久化--快照持久化(SNAPSHOTTING)
- sort