TensorLayer 如何重复使用 variable

来源:互联网 发布:学校机房还原软件 编辑:程序博客网 时间:2024/04/29 20:41

用 TensorFlow 比较多的同学,会发现 reuse variable 来建立模型 (graph) 有时候是必须的,比如建立RNN模型时,num_steps 在 training 和 testing 的时候往往是不同的。


以 dropout 为例: 在 testing 的时候,应该是关闭的。而在 training的时候是启用的。

TensorLayer 有2个简单的方法解决实现这个

方法1  这下面连接的例子中,Layer 内部通过 placeholder 来设置dropout keeping probabilities,当testing的时候,probabilities被设为 1

TensorLayer simple example :    http://tensorlayercn.readthedocs.io/zh/latest/user/tutorial.html#tensorlayer

# 训练网络
tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
            acc=acc, batch_size=500, n_epoch=500, print_freq=5,
            X_val=X_val, y_val=y_val, eval_train=False)

# 测试网络
tl.utils.test(sess, network, acc, X_test, y_test, x, y_, batch_size=None, cost=cost)


方法2 上面的代码使用了TensorLayer 提供的傻瓜式函数,和keras、tflearn差不多。但 TensorLayer 作者鼓励大家使用 TensorFlow 的原生方法。

TensorLayer MNIST examples :    https://github.com/zsdonghao/tensorlayer/blob/master/tutorial_mnist.py

# 训练时启动dropout            

feed_dict = {x: X_train_a, y_: y_train_a}
            feed_dict.update( network.all_drop )    # enable dropout or dropconnect layers
            sess.run(train_op, feed_dict=feed_dict)

# 测试时关闭dropout,把probabilities 全设为1,放入feed_dict

dp_dict = tl.utils.dict_to_one( network.all_drop )
            feed_dict = {x: X_val, y_: y_val}
            feed_dict.update(dp_dict)


比如重复使用 variable 的情况:RNN为例,除了dropout 关闭外,还需要使用不同的 num_steps,那就一定要建立不同的 computation graph 了。

graph reuse variables。可以这样实现:

TensorLayer PTB tutorial:   https://github.com/zsdonghao/tensorlayer/blob/master/tutorial_ptb_lstm.py

这个代码中,最关键的代码是:

def inference(x, is_training, num_steps, reuse=None):

with tf.variable_scope("model", reuse=reuse):   # reuse=True时,则让TensorFlow 知道 reuse variable
            tl.layers.set_name_reuse(reuse)            # reuse = True时,TensorLayer allows reuse the same Layer name

network = tl.layers.EmbeddingInputlayer(.....

.....

......


这样,你就可以如下建立多个 graph with the same variables了。

# Inference for Training
network, lstm1, lstm2 = inference(input_data,  is_training=True, num_steps=num_steps, reuse=None)


# Inference for Validating
network_val, lstm1_val, lstm2_val = inference(input_data,  is_training=False, num_steps=num_steps, reuse=True)


# Inference for Testing (Evaluation)
network_test, lstm1_test, lstm2_test = inference(input_data_test, is_training=False, num_steps=1, reuse=True)



0 0
原创粉丝点击