tf.get_collection获取训练变量等效用法

来源:互联网 发布:学校三级公共卫生网络 编辑:程序博客网 时间:2024/06/05 20:56
#    train_vars=tf.trainable_variables()#    g_vars=[var for var in train_vars if var.name.startswith('generator')]#    d_vars=[var for var in train_vars if var.name.startswith('discriminator')]    g_vars=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')    d_vars=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
原创粉丝点击