Tensorflow 学习笔记之 共享变量(Sharing Variables)
来源:互联网 发布:淘宝优惠券微信群起名 编辑:程序博客网 时间:2024/05/23 13:26
Tensorflow 学习笔记之 共享变量(Sharing Variables)
最近两年,谷歌撑腰的深度学习框架Tensorflow发展地如日中天,虽然17年pytorch的出现略微“打压”了一些TF的势头,但TF在深度学习界的地位还是难以撼动的,github上TF的收藏量一直稳在深度学习中前二的位置。个人在4月份开始接触TF,写分类、超分辨网络不亦乐乎。然而,最近从越来越多的TF github项目中看到了人们都在使用一个叫“共享变量”的机制管理变量,已经基本学会简单TF语法的我,今天决定好好研究一下这个功能。
变量管理的问题
设想你要写一个分类网络,结构是“卷积->ReLU->Pooling->卷积->ReLU->Pooling->展平->全连接->ReLU->全连接->Softmax”。由于网络实在太简单了,写起来完全不需要过多的思考。可能你是这么写的(例子出自TF官网:http://tensorflow.org/tutorials/mnist/pros/index.html):
def weight_variable(shape): return tf.Variable(tf.truncated_normal(shape, stddev=0.1))def bias_variable(shape): return tf.Variable(tf.constant(0.1, shape=shape))W_conv1 = weight_variable([5, 5, 3, 32])b_conv1 = bias_variable([32])h_conv1 = tf.nn.relu(tf.nn.conv2d(...))h_pool1 = tf.nn.max_pool(h_conv1,...)W_conv2 = weight_variable([5, 5, 32, 64])b_conv2 = bias_variable([64])h_conv2 = tf.nn.relu(tf.nn.conv2d(...))h_pool2 = tf.nn.max_pool(h_conv2,...)W_fc1 = weight_variable([7 * 7 * 64, 1024])b_fc1 = bias_variable([1024])h_flat = tf.reshape(...)h_fc1 = tf.nn.relu(tf.matmul(...))
从中应该可以看出来,如果需要添加卷积层或全连接层,需要额外定义相应的权重w和偏置b。因此就有了[W_conv1,b_conv1,W_conv2,b_conv2,…]这一串变量信息。
那么问题来了,如果让你写一个19层的VGG网络,甚至是上百层的Resnet呢?这种定义方法显然是行不通的,等手动把[W_conv1,b_conv1,W_conv2,b_conv2,…]这些东西输完,估计也对TF丧失兴趣了。你可能会想到这种循环的方法:
def layer(shape, ...): w = tf.Variable(tf.truncated_normal(shape, stddev=0.1)) b = tf.Variable(tf.constant(0.1, shape=shape)) return tf.nn.relu(tf.nn.conv2d(...))for i in range(19): ... x = layer(shape, ...) ...
这样的确就不用一个个写[w1,w2,w3,w4,….]这些变量了,从某种程度上来看确实解放了双手。但是,如果我现在想读取第8个卷积层中w和b的数值,有没有什么简单的方法呢?再或者我想把这个网络中的参数转移到另一个完全相同的网络中使用呢?虽然你可以再定义一组列表var,在每次新定义变量后var.append(w),但从管理变量和传输变量的角度来看依旧不是很方便。
万幸的是,TF早就想到了这一点,并且提供Variable Scope机制来帮助管理变量。有了这个工具,就再也不用为变量的定义和共享伤脑筋了。
常用函数
tf.get_variable():和tf.Variable类似,该函数也是为了创建一个变量。参数有:
- name:变量名称
- initializer:初始化值
- trainable:是否可训练
tf.variable_scope():创建一个变量域,相当于在变量空间下打开一个文件夹。一般和tf.get_variable()组合使用,一种常用的用法如下:
import tensorflow as tfwith tf.variable_scope('cnn'): with tf.variable_scope('conv1'): w = tf.get_variable( initializer = tf.truncated_normal([3,3,3,32], stddev=0.1), trainable=True, name = 'w') b = tf.get_variable( initializer = tf.zeros([32]), trainable=True, name = 'b')print(w.name)print(b.name)
结果为
cnn/conv1/w:0cnn/conv1/b:0
实例
本章以简单的MNIST识别为例,来看看tf.get_variable()和tf.variable_scope()在训练时能带给大家怎样的方便。
模型文件
首先,创建一个py文件,专门存放生成模型的代码,叫做“cnnmodel.py”。其中定义一下权重和偏置的初始化函数:
import tensorflow as tfdef weight_variable(shape): initial = tf.truncated_normal(shape, stddev=0.1) return tf.get_variable(initializer = initial, trainable=True, name = 'w')def bias_variable(shape): initial = tf.zeros(shape) return tf.get_variable(initializer = initial, trainable=True, name = 'b')
接着,进一步定义卷积层、全连接层等操作,这样可以省去很多重复的字符:
def conv2d(x, W_shape): W = weight_variable(W_shape) B = bias_variable(W_shape[-1]) return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') + Bdef ann(x, W_shape): W = weight_variable(W_shape) B = bias_variable(W_shape[-1]) return tf.matmul(x, W) + Bdef max_pool_2x2(x): return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
个人还是比较喜欢在函数里面定义变量。每次调用函数,创建的都是局部变量,即使同名也不会有冲突。
然后就是创建CNN模型了:
def cnnmodel(inp, keep_prob): with tf.variable_scope('cnn'): with tf.variable_scope('conv1'): h_conv1 = tf.nn.relu(conv2d(inp, [5, 5, 1, 32])) h_pool1 = max_pool_2x2(h_conv1) with tf.variable_scope('conv2'): h_conv2 = tf.nn.relu(conv2d(h_pool1, [5, 5, 32, 64])) h_pool2 = max_pool_2x2(h_conv2) with tf.variable_scope('fc1'): h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) h_fc1 = tf.nn.relu(ann(h_pool2_flat, [7 * 7 * 64, 1024])) h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) with tf.variable_scope('fc2'): y_conv = tf.nn.softmax(ann(h_fc1_drop, [1024, 10])) var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='cnn') return y_conv, var
这样写的结果就是:在变量空间中有一个总文件夹叫“cnn”,下面有许多子文件夹“conv1”、“conv2”、“fc1”、“fc2”。每个子文件夹下都有“w”、“b”两个变量。最后的tf.get_collection()就是为了把在“cnn”目录下的变量集合起来,一步到位。是不是比自己定义列表一个一个.append()方便多了。
训练文件
此文件基本照搬Tensorflow官方教程的文档,变动不大,唯一有区别的就是在调用cnnmodel时额外输出了网络变量。
import input_datamnist = input_data.read_data_sets('MNIST_data', one_hot=True)import tensorflow as tffrom cnnmodel import cnnmodelsess = tf.InteractiveSession()x = tf.placeholder("float", shape=[None, 784])y_ = tf.placeholder("float", shape=[None, 10])keep_prob = tf.placeholder("float")x_image = tf.reshape(x, [-1,28,28,1])y_conv, var = cnnmodel(x_image, keep_prob)cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy, var_list = var)correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))sess.run(tf.global_variables_initializer())for i in range(2000): batch = mnist.train.next_batch(50) train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) if i % 500 == 0 or i == 1999: train_accuracy = accuracy.eval(feed_dict={ x:batch[0], y_: batch[1], keep_prob: 1.0}) print("step %d, training accuracy %g"%(i, train_accuracy)) saver = tf.train.Saver() saver.save(sess, 'backup/latest')
导入数据
无更改,直接调用即可
from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport gzipimport osimport tempfileimport numpyfrom six.moves import urllibfrom six.moves import xrange # pylint: disable=redefined-builtinimport tensorflow as tffrom tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
重新训练 / 测试文件
在训练文件中,最后几行saver.save()把当前训练的模型导出到backup文件夹下了。该文件可以导入最后训练的网络参数,方便做一些测试。可以用这种方法导入全局的参数:
if tf.train.get_checkpoint_state('backup/'): print('\nfound\n') saver = tf.train.Saver() saver.restore(sess, 'backup/latest')
如果你只想导入某一层参数的话,之前的变量管理就帮上忙了:
var_ = tf.global_variables()net_var = [var for var in var_ if "conv1" in var.name]if tf.train.get_checkpoint_state('backup/'): print('\nfound\n') saver = tf.train.Saver(net_var) saver.restore(sess, 'backup/latest')
如果不信,可以用print(sess.run(net_var[1]))这种方法打印出参数值,来看看导入的参数和训练文件中导出的是不是一样。
要提醒一点,导入参数一定要放在sess.run(tf.global_variables_initializer())之后,否则你刚把变量值导好,一个变量初始化过来又变成预设好的初始值了。
完整文件如下:
import input_datamnist = input_data.read_data_sets('MNIST_data', one_hot=True)import tensorflow as tffrom cnnmodel import cnnmodelsess = tf.InteractiveSession()x = tf.placeholder("float", shape=[None, 784])y_ = tf.placeholder("float", shape=[None, 10])keep_prob = tf.placeholder("float")x_image = tf.reshape(x, [-1,28,28,1])y_conv, var = cnnmodel(x_image, keep_prob)var_ = tf.global_variables()net_var = [var for var in var_ if "cnn" in var.name]sess.run(tf.global_variables_initializer())if tf.train.get_checkpoint_state('backup/'): print('\nfound\n') saver = tf.train.Saver() saver.restore(sess, 'backup/latest')
- Tensorflow 学习笔记之 共享变量(Sharing Variables)
- tensorflow之变量共享
- [Tensorflow]Sharing Variables 共享权值【tf.get_variable 和 tf.variable_scope】
- tensorflow学习笔记(五):TensorFlow变量共享和数据读取
- OpenMP Tutorial学习笔记(6)OpenMP指令之组合共享工作构造(Combined Work-Sharing)
- TensorFlow基础知识点(三)变量/Variables
- TensorFlow学习笔记3——变量共享
- TensorFlow学习-- 变量Variables/ Fetch/ Feed
- mysql源码学习笔记:系统变量variables
- OpenMP Tutorial学习笔记(5)OpenMP指令之共享工作构造(Work-Sharing)
- TensorFlow学习笔记(六)Variable变量
- tensorflow学习笔记--三(Variables: 创建,初始化,保存,和恢复)
- Spark2.1 共享变量(Broadcast Variables&Accumulators)分析。
- TensorFlow学习笔记(二):变量
- Tensorflow学习笔记-变量管理
- stylus之变量(Variables)
- tensorflow 学习笔记之 变量的一些操作
- TensorFlow笔记之变量管理
- Linux学习笔记 --邮件服务
- 设计模式之回调模式
- Linux下动态库和静态库的制作与使用
- 全程尬聊之阿里2017实习生面试
- 人机大战|深度拆解AlphaGo套路
- Tensorflow 学习笔记之 共享变量(Sharing Variables)
- opengl opengles 版本对应的时间
- 程序猿学习第七天,超链接样式和div
- 线程安全的时间类
- 字符串排列-dfs算法
- WIFI模块ESP8266的使用指南(客户端和服务器两种模式建立)
- Nginx搭建反向代理服务器过程详解
- 网页存储Web Storage
- leetcode 7. Reverse Integer