RNN简介
来源:互联网 发布:2017网络市场专项行动 编辑:程序博客网 时间:2024/06/05 06:54
1. RNN的提出背景
对于序列类型的数据,比如翻译问题,时间序列预测问题,由于前后样本之间有顺序之间的先后关系,而传统的神经网络无法体现这一关系,因此出现了RNN(Recurrent Neural Network)。RNN会记住之前的输入值,因此对于RNN而言,即时后面有相同的输入,输出值也会不同,这是与传统的NN不同的地方。序列数据的类型有:
- one to one: Vanilla Neural Networks,最简单的BP网络
- one to many: Image Captioning,可将图像转换成文字
- many to one:Sentiment Classification,情感分类
- many to many-1: Machine Translation,机器翻译
- many to many-2:Video classification,视频分类
2. RNN的基本结构
对于RNN神经元,其输出结果考虑的东西除了输入值之外,还会考虑存在内存当中的其他值,这些值是先前计算的结果:
首先在初始的时候,会给Memory单元的
例如: 令所有的权值都是1,
因此输出样本的顺序会对结果产生影响。因此下一时刻的输出值会受到前面所有时刻的输入值的影响:
如果令
例如对于词汇来将,每个单词都是一个输入的x,对于其他时间序列数据,每个时间点的输入都是一个输入样本:
更一般的情况,会将输出值存入memory中:
3. Bidirectional RNN(双向RNN)
将输入的数字,从前到后训练一个RNN,从后往前训练一个RNN。对于单向的RNN,由于只能看到前面已经输入的样本数据,因此无法看到全局。而对于双向RNN,既能看到全面的样本数据,又能看到后面的样本数据,准确性更好。
4. RNN中的梯度爆炸(gradient explode)和梯度消失(gradient vanishing)
先看下面一个简单的例子:
假设
假设
即每一时刻的损失之和。RNN的训练过程其实就是对
在
因此对于
因此任意时刻损失对于
对于
因此:
如果激活函数为
如果激活函数为
由上图可以看出
5 LSTM的基本结构
每个LSTM神经元有4个
对于传统的RNN神经元,
)
Z是外界的input,其中哦那个的3个gate分别为input gate, forget gate, output gate,控制参数分别为
LSTM与传统的神经网络相同,只需将传统的网络中的神经元换成
因此对于一个序列
通常情况下LSTM的神经网络图为
)
假设该网络有
单个Cell的计算过程为:
整个LSTM的计算过程为:
以上只是最简单的LSTM,实际使用过程中会将
多层的LSTM结构为:
6. 为何LSTM能解决Gradient Vanishing问题
实际工作过程当中,RNN的训练Loss变化为:
即图中绿线所示,蓝线为理想状态下Loss变化图。
训练过程中,Loss与参数变化示意图为:
可以看出,图中的Loss会剧烈震荡,因此在通过Gradient Descent参数参数的过程当中,由于learning rate的原因,导致参数跳动太大。因此一般采用clip的方法,即当Gradient大于某个阈值的时候,就以该阈值作为此时的Gradient。
RNN能产生Gradient Vanishing主要是由于RNN本身的特征产生的。对于传统的RNN来讲,假设有以下的一个简单的例子:
如果输入序列为
而对于LSTM来讲,在处理
另一方面,RNN无法解决长期记忆问题,假设我们试着去预测“I grew up in France… I speak fluent French”最后的词。当前的信息建议下一个词可能是一种语言的名字,但是如果我们需要弄清楚是什么语言,我们是需要先前提到的离当前位置很远的 France 的上下文的。这说明相关信息和当前预测位置之间的间隔就肯定变得相当的大。不幸的是,在这个间隔不断增大时,RNN 会丧失学习到连接如此远的信息的能力。但是LSTM因为拥有
其他的解释LSTM解决
1. LSTM如何来避免梯度弥散和梯度爆炸?
2. 为什么相比于RNN,LSTM在梯度消失上表现更好?
3. LSTM如何解决梯度消失问题
7. LSTM解决问题举例:
7.1. 机器翻译
但此时翻译会一直继续下去,因此需要定义一个终止符号:
其余都是通过
7.2. 语音辨识:
训练结果为:
7.3. 聊天机器人:
通过不断收集对话的
7.4. visual question answer
7.4. 英语听力测试:
6 LSTM的TensorFlow简单实现
import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datafrom tensorflow.contrib import rnnmnist = input_data.read_data_sets("/tmp/data/", one_hot = True)hm_epochs = 3n_classes = 10batch_size = 128chunk_size = 28n_chunks = 28rnn_size = 128x = tf.placeholder('float', [None, n_chunks,chunk_size])y = tf.placeholder('float')def recurrent_neural_network(x): layer = {'weights':tf.Variable(tf.random_normal([rnn_size,n_classes])), 'biases':tf.Variable(tf.random_normal([n_classes]))} x = tf.transpose(x, [1,0,2]) x = tf.reshape(x, [-1, chunk_size]) x = tf.split(x, axis=0, num_or_size_splits = n_chunks) lstm_cell = rnn.BasicLSTMCell(rnn_size,state_is_tuple=True) outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32) output = tf.matmul(outputs[-1],layer['weights']) + layer['biases'] return outputdef train_neural_network(x): prediction = recurrent_neural_network(x) cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y)) optimizer = tf.train.AdamOptimizer().minimize(cost) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for epoch in range(hm_epochs): epoch_loss = 0 for _ in range(int(mnist.train.num_examples / batch_size)): epoch_x, epoch_y = mnist.train.next_batch(batch_size) epoch_x = epoch_x.reshape((batch_size, n_chunks, chunk_size)) _, c = sess.run([optimizer, cost], feed_dict={x: epoch_x, y: epoch_y}) epoch_loss += c print('Epoch', epoch, 'completed out of', hm_epochs, 'loss:', epoch_loss) correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1)) accuracy = tf.reduce_mean(tf.cast(correct, 'float')) print('Accuracy:', accuracy.eval({x: mnist.test.images.reshape((-1, n_chunks, chunk_size)), y: mnist.test.labels}))train_neural_network(x)
- RNN简介
- RNN简介
- RNN简介
- RNN 简介
- 递归神经网络(RNN)简介
- 递归神经网络(RNN)简介
- RNN 之LSTM简介
- 循环神经网络RNN简介
- 递归神经网络(RNN)简介
- RNN, LSTM简介
- RNN(recurrent neural networks)简介
- 循环神经网络教程第一部分-RNN简介
- 递归(循环)神经网络(RNN)简介
- RNN
- rnn
- RNN
- RNN
- RNN
- 面向对象编程
- 古文觀止卷八_送董邵南序_韓愈
- 数据结构实验之查找五:平方之哈希表
- web开发编码问题
- android studio教程-创建第一个项目Hello World
- RNN简介
- 微信公众平台开发视频公开课第1讲-基础入门
- SpringMVC相关
- 《微信公众平台应用开发:方法、技巧与案例》火热预售中...
- 时间插件测试
- 即将陆续推出微信公众平台开发视频教程
- 剑指offer(十)矩形覆盖
- 参与CSDN社区问答活动“基于Java的微信公众平台开发”赢签名赠书
- 《趣学算法》目录及签名版