merge_all引发的血案
来源:互联网 发布:网红的淘宝店知乎 编辑:程序博客网 时间:2024/04/30 10:31
merge_all引发的血案
- 在训练深度神经网络的时候,我们经常会使用Dropout,然而在
test
的时候,需要把dropout
撤掉.为了应对这种问题,我们通常要建立两个模型,让他们共享变量。详情. - 为了使用
Tensorboard
来可视化我们的数据,我们会经常使用Summary
,最终都会用一个简单的merge_all
函数来管理我们的Summary
错误示例
当这两种情况相遇时,bug
就产生了,看代码:
import tensorflow as tfimport numpy as npclass Model(object): def __init__(self): self.graph() self.merged_summary = tf.summary.merge_all()# 引起血案的地方 def graph(self): self.x = tf.placeholder(dtype=tf.float32,shape=[None,1]) self.label = tf.placeholder(dtype=tf.float32, shape=[None,1]) w = tf.get_variable("w",shape=[1,1]) self.predict = tf.matmul(self.x,w) self.loss = tf.reduce_mean(tf.reduce_sum(tf.square(self.label-self.predict),axis=1)) self.train_op = tf.train.GradientDescentOptimizer(0.01).minimize(self.loss) tf.summary.scalar("loss",self.loss)def run_epoch(session, model): x = np.random.rand(1000).reshape(-1,1) label = x*3 feed_dic = {model.x.name:x, model.label:label} su = session.run([model.merged_summary], feed_dic)def main(): with tf.Graph().as_default(): with tf.name_scope("train"): with tf.variable_scope("var1",dtype=tf.float32): model1 = Model() with tf.name_scope("test"): with tf.variable_scope("var1",reuse=True,dtype=tf.float32): model2 = Model() with tf.Session() as sess: tf.global_variables_initializer().run() run_epoch(sess,model1) run_epoch(sess,model2)if __name__ == "__main__": main()
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
运行情况是这样的: 执行run_epoch(sess,model1)
时候,程序并不会报错,一旦执行到run_epoch(sess,model1)
,就会报错(错误信息见文章最后)。
错误原因
看代码片段:
class Model(object): def __init__(self): self.graph() self.merged_summary = tf.summary.merge_all()# 引起血案的地方...with tf.name_scope("train"): with tf.variable_scope("var1",dtype=tf.float32): model1 = Model() # 这里的merge_all只是管理了自己的summarywith tf.name_scope("test"): with tf.variable_scope("var1",reuse=True,dtype=tf.float32): model2 = Model()# 这里的merge_all管理了自己的summary和上边模型的Summary
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
由于Summary
的计算是需要feed
数据的,所以会报错。
解决方法
我们只需要替换掉merge_all
就可以解决这个问题。看代码
class Model(object): def __init__(self,scope): self.graph() self.merged_summary = tf.summary.merge( tf.get_collection(tf.GraphKeys.SUMMARIES,scope) )...with tf.Graph().as_default(): with tf.name_scope("train") as train_scope: with tf.variable_scope("var1",dtype=tf.float32): model1 = Model(train_scope) with tf.name_scope("test") as test_scope: with tf.variable_scope("var1",reuse=True,dtype=tf.float32): model2 = Model(test_scope)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
关于tf.get_collection
地址
当有多个模型时,出现类似错误,应该考虑使用的方法是不是涉及到了其他的模型
error
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor ‘train/var1/Placeholder’ with dtype float
[Node: train/var1/Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device=”/job:localhost/replica:0/task:0/gpu:0”]]
阅读全文
0 0
- merge_all引发的血案
- tensorflow学习笔记(二十九):merge_all引发的血案
- ActiveX引发的“血案”
- size_t引发的血案
- 一个 * 引发的血案
- gets引发的血案
- Print 引发的“血案”
- lease引发的血案
- 一个“-”引发的血案
- MD5引发的血案
- 一个"/"引发的血案
- wrap_content引发的血案
- PersistableBundle引发的血案
- 看球引发的血案
- 一个松果引发的血案
- 一个memset引发的血案
- 一条语句引发的血案
- 一条短信引发的血案
- window.opener用法
- 【BigHereo 23】---L1---C++对象
- STC12手册通读
- 错误 -7
- 在公有类中使用访问方法而非公有域。
- merge_all引发的血案
- 配置yum源教程
- vector用法
- MIT18.06线性代数课程笔记6:vector space,subspace,column space,null space
- 软件测试职业规划
- LeetCode 316. Remove Duplicate Letters--贪心算法
- ubuntu上安装gnome桌面
- 用JAVA 实现图像化的模式串匹配并于文本区显示颜色
- WebLogic11g在startWebLogic.cmd文件中配置jar包启动