使用多层 LSTM API(4/7)
来源:互联网 发布:网络教育花钱能包过吗 编辑:程序博客网 时间:2024/06/14 20:30
这一次我们会让架构层次更深,使用LSTM多层结构
需要注意的是,在网络的每一层,我们都需要一个hidden state和一个cell state,
特别的是,输入到下一个LSTM层的输入,是那一个特定层的前一个状态,
隐藏的前一层的激活层也是
[这尼玛说的是啥?]
我们要 把每一层的states保存起来,将会有很多个LSTMTuples
,
为了方便,我们会用一个整的状态来代替之前的_current_cell_state
和 _current_hidden_state
_current_state = np.zeros((num_layers, 2, batch_size, state_size))
这里num_layers=3
2代表了2个states,cell和hidden
现在修改之前的代码
_total_loss, _train_step, _current_state, _predictions_series = sess.run( [total_loss, train_step, current_state, predictions_series], feed_dict={ batchX_placeholder: batchX, batchY_placeholder: batchY, # 这里 init_state: _current_state })
然后替换这里,换成一个 tensor
#cell_state = tf.placeholder(tf.float32, [batch_size, state_size])#hidden_state = tf.placeholder(tf.float32, [batch_size, state_size])#init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)init_state = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
既然TensorFlow的多层api接受 作为LSTMTuples的state, 我们需要对state的数据结构动些手脚
对于状态里的每一层,我们创建一个LSTMTuple
,然后把它们放到一个元组(tuple
)里,在init_state
后面加上:
state_per_layer_list = tf.unstack(init_state, axis=0)rnn_tuple_state = tuple( [tf.nn.rnn_cell.LSTMStateTuple(state_per_layer_list[idx][0], state_per_layer_list[idx][1]) for idx in range(num_layers)])
然后forward pass的部分修改成
cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)states_series, current_state = tf.contrib.rnn.static_rnn(cell, inputs_series, initial_state=rnn_tuple_state)
[cell]* num_layers
是对cell进行复制num次
多层LSTM一开始被用来创建一个单个的LSMTCell
然后在一个数组里复制这个cell
把它提供给MultiRNNCell的api调用
TensorFlow1.2的api调整
不能写成[cell]* num_layers
来复制成多层cell了,
会提示ValueError: Trying to share variable rnn ...
每个cell都要单独生成,在放在同一个list里
def lstm_cell(): return tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)stacked_lstm_cell = [lstm_cell() for _ in range(num_layers)]cell = tf.nn.rnn_cell.MultiRNNCell(stacked_lstm_cell, state_is_tuple=True)
阅读全文
0 0
- 使用多层 LSTM API(4/7)
- 使用TensorFlow的LSTM API(3/7)
- tensorflow笔记:多层LSTM代码分析
- tensorflow笔记:多层LSTM代码分析
- tensorflow笔记:多层LSTM代码分析
- TensorFlow入门(五)多层 LSTM 通俗易懂版
- TensorFlow 笔记(三):多层 LSTM代码详细介绍
- TensorFlow入门(五)多层 LSTM 通俗易懂版
- TensorFlow 实现多层 LSTM 的 MNIST 分类 + 可视化
- CNTK API文档翻译(10)——使用LSTM预测时间序列数据
- Caffe:LSTM使用
- 训练LSTM模型进行情感分类在IMDB数据集上,使用Keras API(Trains an LSTM model on the IMDB sentiment classification)
- VB中使用API创建深层目录(建立多层文件夹)
- CNTK API文档翻译(6)——对MNIST数据使用多层感知机
- tensorflow API学习——LSTM实现
- CNTK API文档翻译(11)——使用LSTM预测时间序列数据(物联网数据)
- Theano(4) LSTM
- 使用Torch nngraph实现LSTM
- 自定义进度条,可在进度条中添加节点
- 简单的静态通讯录的实现。
- 计算字符串最后一个单词的长度,单词以空格隔开
- poj 1182 食物链 【带权并查集】
- Maven学习之插件—maven-assembly-plugin
- 使用多层 LSTM API(4/7)
- git revert
- Redis 客户端运行
- 高斯模版生成代码
- textarea如何实现高度自适应?
- Android音频系统之音频框架
- 浅谈前端移动端页面开发(布局篇)
- 38-数字在排序数组中出现的次数
- 数据结构之二叉树