TensorFlow RNN 相关类与方法
来源:互联网 发布:centos 搜索文件内容 编辑:程序博客网 时间:2024/06/06 05:56
RNN又包括LSTM, GRU等.
1. RNN
1.1 cell
tensorflow.python.layers.base.Layer
类, 表示网络中的一层.tensorflow.python.ops.rnn_cell_impl.RNNCell(base_layer.Layer)
抽象类, 表示RNN中的一层. 叫作cell更形象.RNNCell#zero_state(self, batch_size, dtype)
返回 zero-filled state tensor(s).
1.2 MultiRNNCell
tf.nn.rnn_cell.MultiRNNCell(RNNCell)
类.
RNN cell composed sequentially of multiple simple cells.__init__(self, cells, state_is_tuple=True)
构造函数.cells
: list of RNNCells that will be composed in this order.
1.3 dynamic_rnn
tf.nn.dynamic_rnn(cell, inputs, ...)
Creates a recurrent neural network specified by RNNCell cell. Performs fully dynamic unrolling ofinputs
.
返回的是(outputs, state), output_tensor 的 shape是(BATCH_SIZE, TIME_STEPS, NUM_UNITS)
. 像预测 sin x 之类的问题, 关注的是 最后一个时刻的输出结果, 那么可以:
""" before: output_rnn.shape = (5,10,8), namely (BATCH_SIZE, TIME_STEPS, NUM_UNITS) after: output_rnn.shape = (5,8) , namely (BATCH_SIZE , NUM_UNITS)""" output_rnn=output_rnn[:,-1,:]
tf.nn.rnn_cell.DropoutWrapper(RNNCell)
类. 在 cell 的输入与输出中添加 dropout 操作.__init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0, state_keep_prob=1.0, ...)
构造函数.
keep_prob 即1- dropout_prob
.tf.nn.dropout(x, keep_prob, ...)
x
: a tensor.
2. tensorflow 处理input与output
涉及到 TIME_STEP 的地方, tf中一般这么处理:
X = tf.placeholder(tf.float32, shape=[None, TIME_STEPS, INPUT_SIZE])Y = tf.placeholder(tf.float32, shape=[None, OUTPUT_SIZE])def lstm_model(input_tensor,out_tensor): lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(NUM_UNITS, state_is_tuple=True) # 需要将tensor转成2维进行计算,计算后的结果作为隐藏层的输入 input_tensor=tf.reshape(input_tensor, [-1, INPUT_SIZE]) w_in=tf.get_variable('weight_input',shape=[INPUT_SIZE,NUM_UNITS],initializer=tf.truncated_normal_initializer()) # 注意 shape=[NUM_UNITS,] 与 shape=[NUM_UNITS,1] 的严重不同! b_in=tf.get_variable('bias_input',shape=[NUM_UNITS, ],initializer=tf.zeros_initializer()) input_rnn = tf.matmul(input_tensor, w_in) + b_in # 将tensor转成3维,作为lstm cell的输入 input_rnn = tf.reshape(input_rnn, [-1, TIME_STEPS, NUM_UNITS]) # state 的大小一般也是 BATCH_SIZE 的大小 init_state = cell.zero_state(BATCH_SIZE, dtype=tf.float32) # output_rnn.shape=(BATCH_SIZE, TIME_STEPS, NUM_UNITS) output_rnn, final_states = tf.nn.dynamic_rnn(lstm_cell , input_rnn,initial_state=init_state, dtype=tf.float32) # 在本问题中只关注最后一个时刻的输出结果,output_rnn.shape=(BATCH_SIZE , NUM_UNITS) output_rnn=output_rnn[:,-1,:] # 创建一个全连接层,因为是回归问题, 输出的维度为1,None指的是不使用激活函数 predictions = tf.contrib.layers.fully_connected(output_rnn, 1, None)
3. LSTM
tensorflow.python.ops.rnn_cell_impl.BasicLSTMCell(RNNCell)
类. 通常通过tf.nn.rnn_cell.BasicLSTMCell
引用.__init__(self, num_units, forget_bias=1.0, state_is_tuple=True, activation=None, reuse=None)
构造函数.
4.GRU
tensorflow.python.ops.rnn_cell_impl.GRUCell(RNNCell)
类.
阅读全文
0 0
- TensorFlow RNN 相关类与方法
- TensorFlow CNN 相关类与方法
- RNN的原理与TensorFlow代码实现
- rnn 相关
- tensorflow中RNN时间步内置的方法
- TensorFlow 常用类与方法
- RNN聊天机器人与Beam Search [Tensorflow Seq2Seq]
- tensorflow中的RNN与LSTM函数异同点分析
- 解读tensorflow之rnn
- 解读tensorflow之rnn
- tensorflow 循环神经网络RNN
- tensorflow之RNN
- tensorflow RNN实例
- TensorFlow MNIST RNN LSTM
- tensorflow rnn阅读笔记
- tensorflow-rnn代码解读
- tensorflow 实现rnn
- tensorflow 循环神经网络RNN
- RabbitMQ核心概念篇
- git 将本地仓库推送到github仓库
- secureCRT 常用命令
- 1808:最长字符串
- cef3接口介绍
- TensorFlow RNN 相关类与方法
- MyBatis缓存技术
- Hibernate事务与并发处理
- c++ fill函数,fill与memset函数的区别
- UI18-使用NSJSONSerialization方法解析JSON
- 【IOS】IOS常用第三方库总结
- cacti监控部署——网络流量监控
- 欢迎使用CSDN-markdown编辑器
- Python之路【第二十三篇】爬虫