Tensorflow RNN源代码解析笔记1:RNNCell的基本实现
来源:互联网 发布:mac安装windows10失败 编辑:程序博客网 时间:2024/06/06 09:06
前言
本系列主要主要是记录下Tensorflow在RNN实现这一块的相关代码,不做详细解释,主要是翻译加笔记。
RNNCell
在Tensorflow中,定义了一个RNNCell的抽象类,具体的所有不同类型的RNN Cell都是基于这个类的,所以就首先讲一下这个,下面是基本的代码:
class RNNCell(object): def __call__(self, inputs, state, scope=None): raise NotImplementedError("Abstract method") @property def state_size(self): raise NotImplementedError("Abstract method") @property def output_size(self): raise NotImplementedError("Abstract method") def zero_state(self, batch_size, dtype): state_size = self.state_size if nest.is_sequence(state_size): state_size_flat = nest.flatten(state_size) zeros_flat = [ array_ops.zeros( array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])), dtype=dtype) for s in state_size_flat] for s, z in zip(state_size_flat, zeros_flat): z.set_shape(_state_size_with_prefix(s, prefix=[None])) zeros = nest.pack_sequence_as(structure=state_size, flat_sequence=zeros_flat) else: zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size]) zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype) zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None])) return zeros
在Tensorflow中,Cell的定义不同于其他资料当中的定义,在其他的文档中Cell(下文指代为L-Cell)被看做是一个能够产生Single Scalar输出的对象,而在这里则是一个包含一系列L-Cell的水平数组。
具体到RNNCell,RNNCell是一个包含一个State(状态)并且能够执行一些处理输入矩阵的对象。RNNCell将输入的矩阵(Input Matrix),运算输出一个包含”self.output”列的输出矩阵(Ouput Matrix)。如果定义了“self.state_size”这个属性,并且取值为一个整数,那么RNNCell则会同时输出一个状态矩阵(State Matrix),包含“self.state_size”列。而如果“self.state_size”定义为一个整数的Tuple,,那么则是输出对应长度的状态矩阵的Tuple,Tuple中的每一个状态矩阵长度还是和之前的一样,包含“self.state_size”列。
在Tensorflow中,将会基于整个RNNCell实现一系列常用的RNNCell,比如LSTM和GRU,并且将会支持包含Dropout等在内的特性,同时也支持构建多层的RNN网络。
RNNCell基本结构
RNNCell有一些基本的属性需要设置:
state_size: 说明这个Cell使用的State的大小output_size: 这个RNNCell最后生成结果的大小
对于每一个RNNCell的具体实现类,都必须要实现__call__这个方法:
每一个具体的RNN类必须实现的方法:def __call__(self, inputs, state, scope=None):
这个方法是RNNCell的核心方法,其需要的属性有:
inputs: 这个需要输入一个形状为[batch_size,input_size]的2D Tensor,batch_size是你训练中指定的batch的大小,而input_size则是输入数据的维度state: state就是你rnn网络中rnn cell的状态,比如说如果你的rnn定义包含了N个单元(也就是你的self.state_size是个整数N),那么在你每次执行RNN网络时就应该给一个[batch_size,self.state_size]形状的2D Tensor来表示当前RNN网络的状态,而如果你的self.state_size是一个元祖,那么给定的状态也应该是一个Tuple,每个Tuple里的状态表示和之前的方式一样,只要注意好不同的self.state_size取值就好
而RNN Cell经过一系列的工作后,将会输出如下的东西:
output:对应的你的batch的大小和输出大小的结果,形状是[batch_size x self.output_size]state:根据你的self.state_size的不同,输出一个更新后的RNN状态,或者一个Tuple的状态,格式对应输入的state
同时RNNCell还定义了一个非抽象的方法,那就是生成初始化状态的方法,比较简单就不多说了:
def zero_state(self, batch_size, dtype):
BasicRNNCell
下面介绍完了RNNCell的定义,我们来看一个最原始的RNN的实现,就是不涉及到LSTM,GRU的那种。这种RNNCell被称作BasicRNNCell,代码很简短:
class BasicRNNCell(RNNCell): """The most basic RNN cell.""" def __init__(self, num_units, input_size=None, activation=tanh): if input_size is not None: logging.warn("%s: The input_size parameter is deprecated.", self) self._num_units = num_units self._activation = activation @property def state_size(self): return self._num_units @property def output_size(self): return self._num_units def __call__(self, inputs, state, scope=None): """Most basic RNN: output = new_state = activation(W * input + U * state + B).""" with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell" output = self._activation(_linear([inputs, state], self._num_units, True)) return output, output
在最基本的RNN实现当中,RNN在时间t的输出,就是其在时间t的状态
output = new_state = activation(W * input + U * state + B)
这个计算就直接在__call__中计算完成了,这个函数比较简单,但是他具体如何计算则调用了一个方法,不在类中,那么我们看看这个函数先:
_linear([inputs, state], self._num_units, True)对应函数介绍,_liner的功能就是你给了一个或一系列的Tensor(A,B,C.....),他给你计算一个W1*A+W2*B.....+Bias的结果存在,比如输入[input,state],那么这个方法就是计算W * input + U * state:def _linear(args, output_size, bias, bias_start=0.0, scope=None): """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. Args: args: a 2D Tensor or a list of 2D, batch x n, Tensors. output_size: int, second dimension of W[i]. bias: boolean, whether to add a bias term or not. bias_start: starting value to initialize the bias; 0 by default. scope: VariableScope for the created subgraph; defaults to "Linear". Returns: A 2D Tensor with shape [batch x output_size] equal to sum_i(args[i] * W[i]), where W[i]s are newly created matrices. Raises: ValueError: if some of the arguments has unspecified or wrong shape.
到此,关于Tensorflow里面RNNCell的基本结构,以及BasicRNNCell的源码分析结束。
以上,MebiuW
- Tensorflow RNN源代码解析笔记1:RNNCell的基本实现
- Tensorflow RNN源代码解析笔记2:RNN的基本实现
- tensorflow中RNNcell源码分析以及自定义RNNCell的方法
- Tensorflow: RNNCell
- 基于tensorflow的RNN-LSTM(一)实现RNN
- RNN的原理与TensorFlow代码实现
- tensorflow rnn阅读笔记
- tensorflow 实现rnn
- Tensorflow源码解析系列--RNN
- Tensorflow学习笔记--RNN精要及代码实现
- RNN入门详解及TensorFlow源码实现--深度学习笔记
- 学习笔记TF050:TensorFlow源代码解析
- tensorflow Examples:<4>实现RNN
- TensorFlow中RNN网络的实现和关键参数选择
- TensorFlow中RNN实现的正确打开方式
- Resnet的Tensorflow实现源代码
- [TensorFlow学习笔记1]TensorFLow的基本概念和基本使用
- 使用TensorFlow实现RNN模型入门篇1
- mysql5.7解压版的安装与配置
- 实践与wiki教程对比学习ROS(catkin/package.xml)
- L1-003. 个位数统计
- 05:输出保留12位小数的浮点数
- 判断TextView是否有内容省略
- Tensorflow RNN源代码解析笔记1:RNNCell的基本实现
- java字符串转换成时间Unparseable date错误的解决方案
- C++抽象编程——递归简介(2)——阶乘函数的执行分析
- tcp请求建立连接,结束连接握手过程
- Spring源码解读前篇--Spring容器的设计
- 读书笔记 effective c++ Item 25 实现一个不抛出异常的swap
- jstl获取list的长度大小
- Darwin Stream server(DSS服务器)的Relay(中继/转发)设置
- 生成排列 Generating Permutations