Source-code analysis of RNNCell in TensorFlow and how to define a custom RNNCell
When reproducing papers we often run into models that modify the RNN or LSTM slightly, or that define an entirely new recurrent structure — for example the memory networks papers introduced earlier — and such models usually require building the RNN according to our own definition. This post therefore summarizes how RNNCell is used and how to define your own RNNCell to fit your needs.
How RNNCell is used in tf
Let's look directly at the source to see how tf implements the RNNCell definition. The code is as follows:
class RNNCell(base_layer.Layer):

  def __call__(self, inputs, state, scope=None):
    if scope is not None:
      with vs.variable_scope(scope,
                             custom_getter=self._rnn_get_variable) as scope:
        return super(RNNCell, self).__call__(inputs, state, scope=scope)
    else:
      with vs.variable_scope(vs.get_variable_scope(),
                             custom_getter=self._rnn_get_variable):
        return super(RNNCell, self).__call__(inputs, state)

  def _rnn_get_variable(self, getter, *args, **kwargs):
    variable = getter(*args, **kwargs)
    trainable = (variable in tf_variables.trainable_variables() or
                 (isinstance(variable, tf_variables.PartitionedVariable) and
                  list(variable)[0] in tf_variables.trainable_variables()))
    if trainable and variable not in self._trainable_weights:
      self._trainable_weights.append(variable)
    elif not trainable and variable not in self._non_trainable_weights:
      self._non_trainable_weights.append(variable)
    return variable

  @property
  def state_size(self):
    raise NotImplementedError("Abstract method")

  @property
  def output_size(self):
    raise NotImplementedError("Abstract method")

  def build(self, _):
    pass

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      state_size = self.state_size
      return _zero_state_tensors(state_size, batch_size, dtype)
RNNCell is an abstract base class: every concrete RNN cell inherits from it and implements its call() function. From the definition above we can see two properties, state_size and output_size, which give the dimensionality of the hidden state and of the output respectively, plus two functions: zero_state(), which initializes the initial state h0 to an all-zero vector, and call(), which defines the actual per-step operation of the cell (a single activation for a vanilla RNN, two gates for GRU, three gates for LSTM, and so on — the differences between RNN variants live almost entirely in this function). With this abstract class in place, let's look at how the three cells BasicRNNCell, GRUCell and BasicLSTMCell are defined in tf, and how the different RNN variants differ in their implementations.
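Before going through them, here is a minimal usage sketch of this interface from the caller's side. It is an illustrative addition (not part of the original source walk-through), assuming the TensorFlow 1.x API; the placeholder shapes are arbitrary:

import tensorflow as tf

batch_size, input_dim, num_units = 32, 50, 128
x_t = tf.placeholder(tf.float32, [batch_size, input_dim])   # input at one time step

cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
h0 = cell.zero_state(batch_size, tf.float32)   # all-zero initial state, shape [batch_size, num_units]
output, h1 = cell(x_t, h0)                     # one step; both tensors have shape [batch_size, num_units]

With that calling convention in mind, BasicRNNCell is defined as follows: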
class BasicRNNCell(RNNCell):

  def __init__(self, num_units, activation=None, reuse=None):
    super(BasicRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._activation = activation or math_ops.tanh

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    output = self._activation(_linear([inputs, state], self._num_units, True))
    return output, output
BasicRNNCell implements the simplest RNN structure. Its state_size and output_size are the same, and h_t and the output are also the same (the call function returns output twice, i.e. it does not define a separate output transformation). The whole computation sits in the first line of call: the input and the previous state pass through one linear map followed by one activation — the plainest possible RNN. In other words, output = new_state = f(W * input + U * state + B).
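To make the formula concrete, here is a small NumPy sketch of that single step (an illustrative addition; the shapes are arbitrary). Note that _linear concatenates [inputs, state] and multiplies by a single weight matrix, which is equivalent to the separate W and U used below:

import numpy as np

batch_size, input_dim, num_units = 4, 10, 8
x = np.random.randn(batch_size, input_dim)
h = np.zeros((batch_size, num_units))

W = np.random.randn(input_dim, num_units)   # input-to-hidden weights
U = np.random.randn(num_units, num_units)   # hidden-to-hidden weights
b = np.zeros(num_units)

new_h = np.tanh(x @ W + h @ U + b)   # in BasicRNNCell, output == new_state
output = new_h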
Next, let's look at how GRU is defined:
class GRUCell(RNNCell):

  def __init__(self, num_units, activation=None, reuse=None,
               kernel_initializer=None, bias_initializer=None):
    super(GRUCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._activation = activation or math_ops.tanh
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    with vs.variable_scope("gates"):
      # Reset gate and update gate.
      # We start with bias of 1.0 to not reset and not update.
      bias_ones = self._bias_initializer
      if self._bias_initializer is None:
        dtype = [a.dtype for a in [inputs, state]][0]
        bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
      value = math_ops.sigmoid(
          _linear([inputs, state], 2 * self._num_units, True, bias_ones,
                  self._kernel_initializer))
      r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
    with vs.variable_scope("candidate"):
      c = self._activation(
          _linear([inputs, r * state], self._num_units, True,
                  self._bias_initializer, self._kernel_initializer))
    new_h = u * state + (1 - u) * c
    return new_h, new_h
Compared with BasicRNNCell, only the call function changes: a reset gate and an update gate are added, denoted r and u respectively, and c is the candidate state to be written. The corresponding equations are:
r = f(W1 * input + U1 * state + B1)
u = f(W2 * input + U2 * state + B2)
c = f(W3 * input + U3 * r * state + B3)
h_new = u * state + (1 - u) * c
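The same equations as a NumPy sketch (an illustrative addition, not TensorFlow source; sigmoid for the gates, tanh for the candidate to match the default activation, with arbitrary shapes):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

batch_size, input_dim, num_units = 4, 10, 8
x = np.random.randn(batch_size, input_dim)
h = np.random.randn(batch_size, num_units)

W1, W2, W3 = (np.random.randn(input_dim, num_units) for _ in range(3))
U1, U2, U3 = (np.random.randn(num_units, num_units) for _ in range(3))
B1 = B2 = np.ones(num_units)     # the code above initializes the gate biases to 1.0
B3 = np.zeros(num_units)

r = sigmoid(x @ W1 + h @ U1 + B1)          # reset gate
u = sigmoid(x @ W2 + h @ U2 + B2)          # update gate
c = np.tanh(x @ W3 + (r * h) @ U3 + B3)    # candidate state
new_h = u * h + (1 - u) * c                # new state == output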
Next, let's look at the implementation of BasicLSTMCell. Compared with GRU, the LSTM adds an output gate as well as an extra internal state c, and it returns h and c together as a tuple that serves as the LSTM's state. The implementation is as follows:
class BasicLSTMCell(RNNCell):

  def __init__(self, num_units, forget_bias=1.0,
               state_is_tuple=True, activation=None, reuse=None):
    super(BasicLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation or math_ops.tanh

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    sigmoid = math_ops.sigmoid
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

    concat = _linear([inputs, h], 4 * self._num_units, True)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

    new_c = (
        c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state
From the code above we can see that, once again, the differences from BasicRNNCell and GRUCell are confined to the call() function — all of the cell-specific behavior lives there. You will also have noticed a helper that all three classes use heavily, _linear(). What does it do? Let's look at its source first:
def _linear(args, output_size, bias, bias_initializer=None,
            kernel_initializer=None):
  if args is None or (nest.is_sequence(args) and not args):
    raise ValueError("`args` must be specified")
  if not nest.is_sequence(args):
    args = [args]

  # Calculate the total size of arguments on dimension 1.
  total_arg_size = 0
  shapes = [a.get_shape() for a in args]
  for shape in shapes:
    if shape.ndims != 2:
      raise ValueError("linear is expecting 2D arguments: %s" % shapes)
    if shape[1].value is None:
      raise ValueError("linear expects shape[1] to be provided for shape %s, "
                       "but saw %s" % (shape, shape[1]))
    else:
      total_arg_size += shape[1].value

  dtype = [a.dtype for a in args][0]

  # Now the computation.
  scope = vs.get_variable_scope()
  with vs.variable_scope(scope) as outer_scope:
    weights = vs.get_variable(
        _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
        dtype=dtype,
        initializer=kernel_initializer)
    if len(args) == 1:
      res = math_ops.matmul(args[0], weights)
    else:
      res = math_ops.matmul(array_ops.concat(args, 1), weights)
    if not bias:
      return res
    with vs.variable_scope(outer_scope) as inner_scope:
      inner_scope.set_partitioner(None)
      if bias_initializer is None:
        bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
      biases = vs.get_variable(
          _BIAS_VARIABLE_NAME, [output_size],
          dtype=dtype,
          initializer=bias_initializer)
    return nn_ops.bias_add(res, biases)
The input args of this function is [input, state], and output_size is the size of the output. In BasicRNNCell, output_size is _num_units; in GRUCell the gate computation uses 2 * _num_units (and the candidate another _num_units); in BasicLSTMCell it is 4 * _num_units. The reason is that _linear implements the Wx + Uh + B part of the recurrence equations, but different RNN variants need a different number of such linear maps. LSTM, for instance, needs four, so instead of calling _linear four times, the code sets output_size to 4 * _num_units once and then splits the result into four tensors.
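For intuition, here is a rough sketch (an illustrative addition, TensorFlow 1.x API, assumed shapes) of what _linear amounts to in the LSTM case: one fused matmul over the concatenated [inputs, h], then one split into the four gate pre-activations:

import tensorflow as tf

batch_size, input_dim, num_units = 32, 50, 128
inputs = tf.placeholder(tf.float32, [batch_size, input_dim])
h = tf.placeholder(tf.float32, [batch_size, num_units])

weights = tf.get_variable("kernel", [input_dim + num_units, 4 * num_units])
biases = tf.get_variable("bias", [4 * num_units], initializer=tf.zeros_initializer())

# One linear map of size 4 * num_units instead of four separate Wx + Uh + B maps.
concat = tf.nn.bias_add(tf.matmul(tf.concat([inputs, h], 1), weights), biases)
i, j, f, o = tf.split(concat, num_or_size_splits=4, axis=1)   # each [batch_size, num_units]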
That covers a quick analysis of how the different RNN cells are implemented in tensorflow. Next, let's see how to implement the RNNCell our own model needs.
How to define a custom RNNCell in tf
Recurrent Entity Networks
Having seen how the GRU and LSTM cells are implemented, the recipe for a custom RNNCell is easy to guess: inherit from the abstract class RNNCell and implement the four members __init__, state_size, output_size and call, putting the functionality you need inside call. Let's illustrate this with the code used in Recurrent Entity Networks, a paper reproduced earlier. In that model each cell holds m memory slots, i.e. m memories, each a mem_sz-dimensional vector, and each slot has a key used to address it. Its update equations (referred to as Eq. 1 and Eq. 2 in the comments of the code below) are implemented as follows:
class DynamicMemory(tf.contrib.rnn.RNNCell):
    def __init__(self, memory_slots, memory_size, keys, activation=prelu,
                 initializer=tf.random_normal_initializer(stddev=0.1)):
        """
        Instantiate a DynamicMemory Cell, with the given number of memory slots, and key vectors.

        :param memory_slots: Number of memory slots to initialize.
        :param memory_size: Dimensionality of memories => tied to embedding size.
        :param keys: List of keys to seed the Dynamic Memory with (can be random).
        :param initializer: Variable Initializer for Cell Parameters.
        """
        self.m, self.mem_sz, self.keys = memory_slots, memory_size, keys
        self.activation, self.init = activation, initializer

        # The three parameter matrices of Eq. 2, shared across all RNN cells.
        self.U = tf.get_variable("U", [self.mem_sz, self.mem_sz], initializer=self.init)
        self.V = tf.get_variable("V", [self.mem_sz, self.mem_sz], initializer=self.init)
        self.W = tf.get_variable("W", [self.mem_sz, self.mem_sz], initializer=self.init)

    @property
    def state_size(self):
        return [self.mem_sz for _ in range(self.m)]

    @property
    def output_size(self):
        return [self.mem_sz for _ in range(self.m)]

    def zero_state(self, batch_size, dtype):
        return [tf.tile(tf.expand_dims(key, 0), [batch_size, 1]) for key in self.keys]

    def __call__(self, inputs, state, scope=None):
        """
        Run the Dynamic Memory Cell on the inputs, updating the memories with each new time step.

        :param inputs: 2D Tensor of shape [bsz, mem_sz] representing a story sentence.
        :param states: List of length M, each with 2D Tensor [bsz, mem_sz] => h_j (starts as key).
        """
        new_states = []
        # Loop over the m memory slots and apply the same update to each of them.
        for block_id, h in enumerate(state):
            # The next three lines implement Eq. 1, the gate g.
            content_g = tf.reduce_sum(tf.multiply(inputs, h), axis=[1])  # Shape: [bsz]
            address_g = tf.reduce_sum(tf.multiply(inputs,
                                                  tf.expand_dims(self.keys[block_id], 0)), axis=[1])  # Shape: [bsz]
            g = sigmoid(content_g + address_g)

            # The next four lines implement Eq. 2: compute the candidate memory h_
            # from the input s and the current memory h.
            h_component = tf.matmul(h, self.U)                                     # Shape: [bsz, mem_sz]
            w_component = tf.matmul(tf.expand_dims(self.keys[block_id], 0), self.V)  # Shape: [1, mem_sz]
            s_component = tf.matmul(inputs, self.W)                                # Shape: [bsz, mem_sz]
            candidate = self.activation(h_component + w_component + s_component)   # Shape: [bsz, mem_sz]

            # Gate the candidate memory h_ with g and add it to the original memory h.
            new_h = h + tf.multiply(tf.expand_dims(g, -1), candidate)              # Shape: [bsz, mem_sz]

            # Normalize the memory.
            new_h_norm = tf.nn.l2_normalize(new_h, -1)                             # Shape: [bsz, mem_sz]
            new_states.append(new_h_norm)

        return new_states, new_states
A cell defined this way can be passed straight to tf.nn.dynamic_rnn(), which unrolls it and builds the model for you.
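A rough sketch of that wiring might look like the following. This is an illustrative assumption, not the original model code: the hyperparameters, the story placeholder, the key variables, and the choice of activation are all made up for the example (the original uses a prelu helper as the default activation).

import tensorflow as tf

batch_size, num_sentences, mem_sz, m = 32, 10, 100, 20

# One encoded sentence per time step.
stories = tf.placeholder(tf.float32, [batch_size, num_sentences, mem_sz])
keys = [tf.get_variable("key_%d" % j, [mem_sz]) for j in range(m)]

cell = DynamicMemory(m, mem_sz, keys, activation=tf.nn.relu)
initial_state = cell.zero_state(batch_size, tf.float32)

# dynamic_rnn unrolls the cell over the num_sentences axis; the outputs and final
# memories come back as a list of m tensors, one per memory slot.
outputs, final_state = tf.nn.dynamic_rnn(cell, stories, initial_state=initial_state)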
Neural Turing Machines
Beyond that, we can also write a cell completely from scratch without inheriting from RNNCell. Looking at the description of RNNCell given in the official documentation, all that is really required is a call function:
Every RNNCell must have the properties below and implement call with the signature (output, next_state) = call(input, state). The optional third input argument, scope, is allowed for backwards compatibility purposes; but should be left off for new subclasses.
Sometimes we have requirements that go further. In that case we can skip inheriting from RNNCell and simply define our own class, with the caveat that tf.nn.dynamic_rnn can then no longer build the graph for us automatically — we have to write our own loop that calls the cell step by step to get the unrolling effect. Let's walk through this with the NTM code.
cell = ntm_cell.NTMCell(args.rnn_size, args.memory_size, args.memory_vector_dim,
                        1, 1, addressing_mode='content_and_location',
                        reuse=reuse, output_dim=args.vector_dim)

state = cell.zero_state(args.batch_size, tf.float32)
self.state_list = [state]
for t in range(seq_length):
    output, state = cell(tf.concat([self.x[:, t, :], np.zeros([args.batch_size, 1])], axis=1), state)
    self.state_list.append(state)
output, state = cell(eof, state)
self.state_list.append(state)
These lines first create an NTMCell object, then obtain its initial state via zero_state, and finally call the cell in a loop, saving the intermediate states as they go. The NTMCell itself is defined as shown below — it does not inherit from RNNCell and is implemented entirely by hand.
class NTMCell():

    def __init__(self, rnn_size, memory_size, memory_vector_dim, read_head_num, write_head_num,
                 addressing_mode='content_and_loaction', shift_range=1, reuse=False, output_dim=None):
        self.rnn_size = rnn_size
        self.memory_size = memory_size
        self.memory_vector_dim = memory_vector_dim
        self.read_head_num = read_head_num
        self.write_head_num = write_head_num
        self.addressing_mode = addressing_mode
        self.reuse = reuse
        self.controller = tf.nn.rnn_cell.BasicRNNCell(self.rnn_size)
        self.step = 0
        self.output_dim = output_dim
        self.shift_range = shift_range

    def __call__(self, x, prev_state):
        prev_read_vector_list = prev_state['read_vector_list']   # read vector in Sec 3.1 (the content that is
                                                                 # read out, length = memory_vector_dim)
        prev_controller_state = prev_state['controller_state']   # state of controller (LSTM hidden state)

        # x + prev_read_vector -> controller (RNN) -> controller_output
        controller_input = tf.concat([x] + prev_read_vector_list, axis=1)
        with tf.variable_scope('controller', reuse=self.reuse):
            controller_output, controller_state = self.controller(controller_input, prev_controller_state)

        num_parameters_per_head = self.memory_vector_dim + 1 + 1 + (self.shift_range * 2 + 1) + 1
        num_heads = self.read_head_num + self.write_head_num
        total_parameter_num = num_parameters_per_head * num_heads + self.memory_vector_dim * 2 * self.write_head_num
        with tf.variable_scope("o2p", reuse=(self.step > 0) or self.reuse):
            o2p_w = tf.get_variable('o2p_w', [controller_output.get_shape()[1], total_parameter_num],
                                    initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))
            o2p_b = tf.get_variable('o2p_b', [total_parameter_num],
                                    initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))
            parameters = tf.nn.xw_plus_b(controller_output, o2p_w, o2p_b)
        head_parameter_list = tf.split(parameters[:, :num_parameters_per_head * num_heads], num_heads, axis=1)
        erase_add_list = tf.split(parameters[:, num_parameters_per_head * num_heads:], 2 * self.write_head_num, axis=1)

        # k, beta, g, s, gamma -> w
        prev_w_list = prev_state['w_list']   # vector of weightings (blurred address) over locations
        prev_M = prev_state['M']
        w_list = []
        p_list = []
        for i, head_parameter in enumerate(head_parameter_list):
            # Some functions to constrain the result in specific range
            # exp(x)              -> x > 0
            # sigmoid(x)          -> x \in (0, 1)
            # softmax(x)          -> sum_i x_i = 1
            # log(exp(x) + 1) + 1 -> x > 1
            k = tf.tanh(head_parameter[:, 0:self.memory_vector_dim])
            beta = tf.sigmoid(head_parameter[:, self.memory_vector_dim]) * 10   # do not use exp, it will explode!
            g = tf.sigmoid(head_parameter[:, self.memory_vector_dim + 1])
            s = tf.nn.softmax(
                head_parameter[:, self.memory_vector_dim + 2:self.memory_vector_dim + 2 + (self.shift_range * 2 + 1)]
            )
            gamma = tf.log(tf.exp(head_parameter[:, -1]) + 1) + 1
            with tf.variable_scope('addressing_head_%d' % i):
                w = self.addressing(k, beta, g, s, gamma, prev_M, prev_w_list[i])   # Figure 2
            w_list.append(w)
            p_list.append({'k': k, 'beta': beta, 'g': g, 's': s, 'gamma': gamma})

        # Reading (Sec 3.1)
        read_w_list = w_list[:self.read_head_num]
        read_vector_list = []
        for i in range(self.read_head_num):
            read_vector = tf.reduce_sum(tf.expand_dims(read_w_list[i], dim=2) * prev_M, axis=1)
            read_vector_list.append(read_vector)

        # Writing (Sec 3.2)
        write_w_list = w_list[self.read_head_num:]
        M = prev_M
        for i in range(self.write_head_num):
            w = tf.expand_dims(write_w_list[i], axis=2)
            erase_vector = tf.expand_dims(tf.sigmoid(erase_add_list[i * 2]), axis=1)
            add_vector = tf.expand_dims(tf.tanh(erase_add_list[i * 2 + 1]), axis=1)
            M = M * (tf.ones(M.get_shape()) - tf.matmul(w, erase_vector)) + tf.matmul(w, add_vector)

        # controller_output -> NTM output
        if not self.output_dim:
            output_dim = x.get_shape()[1]
        else:
            output_dim = self.output_dim
        with tf.variable_scope("o2o", reuse=(self.step > 0) or self.reuse):
            o2o_w = tf.get_variable('o2o_w', [controller_output.get_shape()[1], output_dim],
                                    initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))
            o2o_b = tf.get_variable('o2o_b', [output_dim],
                                    initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))
            NTM_output = tf.nn.xw_plus_b(controller_output, o2o_w, o2o_b)

        state = {
            'controller_state': controller_state,
            'read_vector_list': read_vector_list,
            'w_list': w_list,
            'p_list': p_list,
            'M': M
        }

        self.step += 1
        return NTM_output, state

    def addressing(self, k, beta, g, s, gamma, prev_M, prev_w):

        # Sec 3.3.1 Focusing by Content

        # Cosine Similarity
        k = tf.expand_dims(k, axis=2)
        inner_product = tf.matmul(prev_M, k)
        k_norm = tf.sqrt(tf.reduce_sum(tf.square(k), axis=1, keep_dims=True))
        M_norm = tf.sqrt(tf.reduce_sum(tf.square(prev_M), axis=2, keep_dims=True))
        norm_product = M_norm * k_norm
        K = tf.squeeze(inner_product / (norm_product + 1e-8))                   # eq (6)

        # Calculating w^c
        K_amplified = tf.exp(tf.expand_dims(beta, axis=1) * K)
        w_c = K_amplified / tf.reduce_sum(K_amplified, axis=1, keep_dims=True)  # eq (5)

        if self.addressing_mode == 'content':   # Only focus on content
            return w_c

        # Sec 3.3.2 Focusing by Location

        g = tf.expand_dims(g, axis=1)
        w_g = g * w_c + (1 - g) * prev_w                                        # eq (7)

        s = tf.concat([s[:, :self.shift_range + 1],
                       tf.zeros([s.get_shape()[0], self.memory_size - (self.shift_range * 2 + 1)]),
                       s[:, -self.shift_range:]], axis=1)
        t = tf.concat([tf.reverse(s, axis=[1]), tf.reverse(s, axis=[1])], axis=1)
        s_matrix = tf.stack(
            [t[:, self.memory_size - i - 1:self.memory_size * 2 - i - 1] for i in range(self.memory_size)],
            axis=1
        )
        w_ = tf.reduce_sum(tf.expand_dims(w_g, axis=1) * s_matrix, axis=2)      # eq (8)
        w_sharpen = tf.pow(w_, tf.expand_dims(gamma, axis=1))
        w = w_sharpen / tf.reduce_sum(w_sharpen, axis=1, keep_dims=True)        # eq (9)

        return w

    def zero_state(self, batch_size, dtype):
        def expand(x, dim, N):
            return tf.concat([tf.expand_dims(x, dim) for _ in range(N)], axis=dim)

        with tf.variable_scope('init', reuse=self.reuse):
            state = {
                'controller_state': expand(
                    tf.tanh(tf.get_variable('init_state', self.rnn_size,
                                            initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))),
                    dim=0, N=batch_size),
                'read_vector_list': [
                    expand(tf.nn.softmax(tf.get_variable('init_r_%d' % i, [self.memory_vector_dim],
                                                         initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))),
                           dim=0, N=batch_size)
                    for i in range(self.read_head_num)],
                'w_list': [
                    expand(tf.nn.softmax(tf.get_variable('init_w_%d' % i, [self.memory_size],
                                                         initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))),
                           dim=0, N=batch_size)
                    if self.addressing_mode == 'content_and_loaction'
                    else tf.zeros([batch_size, self.memory_size])
                    for i in range(self.read_head_num + self.write_head_num)],
                'M': expand(
                    tf.tanh(tf.get_variable('init_M', [self.memory_size, self.memory_vector_dim],
                                            initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))),
                    dim=0, N=batch_size)
            }
            return state
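To summarize the pattern without the NTM-specific machinery, here is a stripped-down sketch (illustrative names and shapes, not taken from the NTM repository) of a fully self-defined cell that exposes only __call__ and zero_state and is unrolled by a hand-written loop:

import tensorflow as tf

class MyCell(object):
    def __init__(self, num_units):
        self.num_units = num_units

    def zero_state(self, batch_size, dtype):
        # The state can be any structure you like, e.g. a dict of tensors.
        return {'h': tf.zeros([batch_size, self.num_units], dtype=dtype)}

    def __call__(self, x, prev_state):
        # Reuse the same variables at every time step.
        with tf.variable_scope('my_cell', reuse=tf.AUTO_REUSE):
            new_h = tf.layers.dense(tf.concat([x, prev_state['h']], axis=1),
                                    self.num_units, activation=tf.tanh)
        return new_h, {'h': new_h}

# Manual unrolling: since this class is not an RNNCell, we loop over time ourselves.
batch_size, seq_length, input_dim, num_units = 32, 20, 50, 128
x = tf.placeholder(tf.float32, [batch_size, seq_length, input_dim])

cell = MyCell(num_units)
state = cell.zero_state(batch_size, tf.float32)
outputs = []
for t in range(seq_length):
    output, state = cell(x[:, t, :], state)
    outputs.append(output)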
With these two examples we have now covered both ways of defining a custom RNNCell in tensorflow. Hopefully this helps when you write your own models with tf.