TensorFlow搭建RNN(2/7) 使用TensorFlow的RNN API
来源:互联网 发布:新网域名过户到阿里云 编辑:程序博客网 时间:2024/06/05 08:47
这一篇文章是TensorFlow搭建RNN(1/7) 简单案例的后续文章,
前一篇文章里,我们从零建立了一个RNN,手动建立计算图,现在我们用TensorFlow原生API来简化我们的代码。
计算图的简单创建
inputs_series = tf.unstack(batchX_placeholder, axis=1) labels_series = tf.unstack(batchY_placeholder, axis=1) for current_input in inputs_series: current_input = tf.reshape(current_input, [batch_size, 1]) input_and_state_concatenated = tf.concat([current_input, current_state],1) # Increasing number of columns next_state = tf.tanh(tf.matmul(input_and_state_concatenated, W) + b) states_series.append(next_state) current_state = next_state
把之前的代码(上面)换成下面的,
inputs_series = tf.split(batchX_placeholder, truncated_backprop_length, 1)labels_series = tf.unstack(batchY_placeholder, axis=1)cell = tf.nn.rnn_cell.BasicRNNCell(state_size)states_series, current_state = tf.contrib.rnn.static_rnn(cell, inputs_series, initial_state = init_state)
还有,你可以把之前的权重和偏置矩阵W和b的声明部分也移除了,
这些都隐藏在RNN的api里面的了
看看这次的变化:
cell
cell = tf.nn.rnn_cell.BasicRNNCell(state_size)
看看之前W、b的定义:
W = tf.Variable(np.random.rand(state_size+1, state_size), dtype=tf.float32)b = tf.Variable(np.zeros((1,state_size)), dtype=tf.float32)
观察,W和b都只有一个可变参数,就是state_size
,
现在把 W和b放进了cell里面,传入state_size
x_inputs
inputs_series = tf.split(batchX_placeholder, truncated_backprop_length, 1)
用split
代替了unstack
, split
沿着axis=1将tensor分解成更小的tensors,
这里inputs_series的shape是(5,1)
而在之前的代码中,unstack把最后一个维度移除了,shape为(5,)
,
所以我们才又在for循环中reshape一次,把(5,)
转成(5,1)
tf.contrib.rnn.static_rnn
原文tf.nn.rnn
,现在替换为tf.contrib.rnn.static_rnn
代替了for循环,
把inputs和cell结合生成了states的序列,
返回的数据也和之前的一样:states_series, current_state
下一步
下一篇我们将用LSTM的架构来完善RNN。
虽然这个案例比较简单,但我们的目的是学习TensorFlow。
原文来自medium:
https://medium.com/@erikhallstrm/tensorflow-rnn-api-2bb31821b185
本文根据tensorflow 1.2的api修改了代码
阅读全文
0 0
- TensorFlow搭建RNN(2/7) 使用TensorFlow的RNN API
- TensorFlow搭建RNN(1/7) 简单案例
- Tensorflow RNN源代码解析笔记2:RNN的基本实现
- 解读tensorflow之rnn
- 解读tensorflow之rnn
- tensorflow 循环神经网络RNN
- tensorflow之RNN
- tensorflow RNN实例
- TensorFlow MNIST RNN LSTM
- tensorflow rnn阅读笔记
- tensorflow-rnn代码解读
- tensorflow 实现rnn
- tensorflow 循环神经网络RNN
- Tensorflow-rnn(mnist分类)
- tensorflow example rnn
- Tensorflow-LSTM RNN 例子
- 解读tensorflow之rnn
- Tensorflow 实践RNN(一)
- ARP协议详解
- 标准作息表
- 【贪心】Stripies POJ 1862
- 【HNOI2016模拟4.14】A
- Linux--dd命令
- TensorFlow搭建RNN(2/7) 使用TensorFlow的RNN API
- Qt创建桌面快捷方式和删除桌面快捷方式
- 正则表达式
- Redis--持久化
- 总结
- 【模板】迪杰斯特拉的优先队列优化
- 第一篇博客
- 洛谷P1278 单词游戏
- Spring-Data-Redis集群配置和RedisTemplate用法