给RNN cell加dropout

来源:互联网 发布:苹果手机数据恢复公司 编辑:程序博客网 时间:2024/05/17 10:29
class SwitchableDropoutWrapper(DropoutWrapper):
    """DropoutWrapper whose dropout can be switched on or off inside the graph.

    Each call builds two paths: the dropout path (through the parent
    ``DropoutWrapper``) and the plain path (the wrapped cell called directly).
    ``tf.cond`` on ``is_train`` then selects which outputs and state to return.
    """

    def __init__(self, cell, is_train, input_keep_prob=1.0, output_keep_prob=1.0,
                 seed=None):
        """Wrap ``cell`` and remember the switch used to select the dropout path.

        Args:
            cell: The cell to wrap; forwarded to ``DropoutWrapper.__init__``.
            is_train: Predicate handed to ``tf.cond``; presumably a scalar
                boolean tensor that is true while training (confirm with caller).
            input_keep_prob: Forwarded to ``DropoutWrapper.__init__``.
            output_keep_prob: Forwarded to ``DropoutWrapper.__init__``.
            seed: Forwarded to ``DropoutWrapper.__init__``.
        """
        super(SwitchableDropoutWrapper, self).__init__(
            cell,
            input_keep_prob=input_keep_prob,
            output_keep_prob=output_keep_prob,
            seed=seed)
        self.is_train = is_train

    def __call__(self, inputs, state, scope="dropout_rnn"):
        """Run the cell with and without dropout and pick one via ``is_train``.

        Args:
            inputs: Input passed to both the dropout wrapper and the raw cell.
            state: Previous state passed to both paths.
            scope: Scope argument forwarded to both paths.

        Returns:
            An ``(outputs, new_state)`` pair whose values come from the dropout
            path when ``is_train`` is true and from the plain path otherwise.
        """
        outputs_do, new_state_do = super(SwitchableDropoutWrapper, self).__call__(
            inputs, state, scope=scope)
        # The second call must share the variables created by the first one.
        # Reuse is limited to this block so it does not stay enabled on the
        # caller's scope, where later variable creation would fail.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            outputs, new_state = self._cell(inputs, state, scope)
        outputs = tf.cond(self.is_train, lambda: outputs_do, lambda: outputs)
        new_state = self._select_by_mode(new_state_do, new_state)
        return outputs, new_state

    def _select_by_mode(self, with_dropout, without_dropout):
        """Pick between two parallel state structures using ``is_train``.

        Tuples are merged element by element (recursively); anything else is
        treated as a leaf and selected with ``tf.cond``.
        """
        if not isinstance(with_dropout, tuple):
            return tf.cond(self.is_train,
                           lambda: with_dropout,
                           lambda: without_dropout)
        merged = [self._select_by_mode(do_i, plain_i)
                  for do_i, plain_i in zip(with_dropout, without_dropout)]
        state_type = type(with_dropout)
        if state_type is tuple:
            # A plain tuple takes one iterable, not one argument per element.
            return tuple(merged)
        # Namedtuple-style states take one positional argument per field.
        return state_type(*merged)

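注意:对当前变量作用域开启变量复用(例如直接调用 `tf.get_variable_scope().reuse_variables()`)之后,如果在同一作用域内再创建新的变量,就会引发 "Variable XXX does not exist, or was not created with tf.get_variable()" 的错误。请按需调整,例如把复用限制在 `with tf.variable_scope(tf.get_variable_scope(), reuse=True):` 代码块之内。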

原创粉丝点击