TF-Slim简介
来源:互联网 发布:js selected 编辑:程序博客网 时间:2024/04/27 15:18
slim作为一种轻量级的tensorflow库,使得模型的构建,训练,测试都变得更加简单。
使用方法:
import tensorflow.contrib.slim as slim
组成部分:
arg_scope: 使得用户可以在同一个arg_scope中使用默认的参数
data,evaluation,layers,learning,losses,metrics,nets,queues,regularizers,variables
定义模型
在slim中,组合使用variables, layers和scopes可以简洁的定义模型。
(1)variables: 定义于variables.py。生成一个weight变量
, 用truncated normal初始化它, 并使用l2正则化,并将其放置于
CPU上
, 只需下面的代码即可:
weights = slim.variable('weights', shape=[10, 10, 3 , 3],
initializer=tf.truncated_normal_initializer(stddev=0.1),
regularizer=slim.l2_regularizer(0.05),
device='/CPU:0')
原生tensorflow包含两类变量:普通变量和局部变量。大部分变量都是普通变量,它们一旦生成就可以通过使用saver存入硬盘,局部变量只在session中存在,不会保存。
slim进一步的区分了变量类型,定义了model variables,这种变量代表了模型的参数。模型变量通过训练活着微调而得到学习,或者在评测或前向传播中可以从ckpt文件中载入。
非模型参数在实际前向传播中不需要的参数,比如global_step。同样的,移动平均反应了模型参数,但它本身不是模型参数。例子见下:
# Model Variables weights = slim.model_variable('weights', shape=[10, 10, 3 , 3], initializer=tf.truncated_normal_initializer(stddev=0.1), regularizer=slim.l2_regularizer(0.05), device='/CPU:0') model_variables = slim.get_model_variables() # model_variables包含weights # Regular variables my_var = slim.variable('my_var', shape=[20, 1], initializer=tf.zeros_initializer()) regular_variables_and_model_variables = slim.get_variables()
#get_variables()得到模型参数和常规参数
当我们通过slim的layers或着直接使用slim.model_variable创建变量时,tf会将此变量加入tf.GraphKeys.MODEL_VARIABLES这个集合中,当你需要构建自己的变量时,可以通过以下代码
将其加入模型参数。
my_model_variable = CreateViaCustomCode() # Letting TF-Slim know about the additional variable. slim.add_model_variable(my_model_variable)
(2)layers:抽象并封装了常用的层,并且提供了repeat和stack操作,使得定义网络更加方便。
下面的代码利用repeat实现了三个卷积层的堆叠
net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
repeat不仅只实现了相同操作相同参数的重复,它还将scope进行了展开,例子中的scope被展开为 'conv3/conv3_1', 'conv3/conv3_2' and 'conv3/conv3_3'。
slim.stack
操作使得我们可以重复的讲同一个操作以不同参数一次作用于一些层,这些层的输入输出时串联起来的。比如:
slim.stack(x, slim.fully_connected, [32, 64, 128], scope='fc')
slim.stack(x, slim.conv2d, [(32, [3, 3]), (32, [1, 1]), (64, [3, 3]), (64, [1, 1])], scope='core')
(3)scopes:除了tensorflow中的name_scope和variable_scope, tf.slim新增了arg_scope操作,这一操作符可以让定义在这一scope中的操作共享参数,即如不制定参数的话,则使用默认参数。且参数可以被局部覆盖。使得代码更加简洁,如下:
with slim.arg_scope([slim.conv2d], padding='SAME', weights_initializer=tf.truncated_normal_initializer(stddev=0.01) weights_regularizer=slim.l2_regularizer(0.0005)): net = slim.conv2d(inputs, 64, [11, 11], scope='conv1') net = slim.conv2d(net, 128, [11, 11], padding='VALID', scope='conv2') net = slim.conv2d(net, 256, [11, 11], scope='conv3')
而且,我们也可以嵌套多个arg_scope在其中使用多个操作。
训练模型
Tensorflow的模型训练需要模型,损失函数,梯度计算,以及根据loss的梯度迭代更新参数。
(1)losses
使用现有的loss:
loss = slim.losses.softmax_cross_entropy(predictions, labels)
对于多任务学习的loss,可以使用:
# Define the loss functions and get the total loss.classification_loss = slim.losses.softmax_cross_entropy(scene_predictions, scene_labels)sum_of_squares_loss = slim.losses.sum_of_squares(depth_predictions, depth_labels)# The following two lines have the same effect:total_loss = classification_loss + sum_of_squares_losstotal_loss = slim.losses.get_total_loss(add_regularization_losses=False)
如果使用了自己定义的loss,而又想使用slim的loss管理机制,可以使用:
pose_loss = MyCustomLossFunction(pose_predictions, pose_labels)slim.losses.add_loss(pose_loss)
total_loss = slim.losses.get_total_loss()
#total_loss中包涵了pose_loss
(2) 训练循环
slim在learning.py中提供了一个简单而有用的训练模型的工具。我们只需调用slim.learning.create_train_op
和slim.learning.train就可以完成优化过程。
g = tf.Graph()# Create the model and specify the losses......total_loss = slim.losses.get_total_loss()optimizer = tf.train.GradientDescentOptimizer(learning_rate)# create_train_op ensures that each time we ask for the loss, the update_ops# are run and the gradients being computed are applied too.train_op = slim.learning.create_train_op(total_loss, optimizer)logdir = ... # Where checkpoints are stored.slim.learning.train( train_op, logdir, number_of_steps=1000,#迭代次数 save_summaries_secs=300,#存summary间隔秒数 save_interval_secs=600)#存模型建个秒数
(3)训练的例子:
import tensorflow as tfslim = tf.contrib.slimvgg = tf.contrib.slim.nets.vgg...train_log_dir = ...if not tf.gfile.Exists(train_log_dir): tf.gfile.MakeDirs(train_log_dir)with tf.Graph().as_default(): # Set up the data loading: images, labels = ... # Define the model: predictions = vgg.vgg16(images, is_training=True) # Specify the loss function: slim.losses.softmax_cross_entropy(predictions, labels) total_loss = slim.losses.get_total_loss() tf.summary.scalar('losses/total_loss', total_loss) # Specify the optimization scheme: optimizer = tf.train.GradientDescentOptimizer(learning_rate=.001) # create_train_op that ensures that when we evaluate it to get the loss, # the update_ops are done and the gradient updates are computed. train_tensor = slim.learning.create_train_op(total_loss, optimizer) # Actually runs training. slim.learning.train(train_tensor, train_log_dir)
根据已有模型进行微调
(1)利用tf.train.Saver()从checkpoint恢复模型
# Create some variables.v1 = tf.Variable(..., name="v1")v2 = tf.Variable(..., name="v2")...# Add ops to restore all the variables.restorer = tf.train.Saver()# Add ops to restore some variables.restorer = tf.train.Saver([v1, v2])# Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.with tf.Session() as sess: # Restore variables from disk. restorer.restore(sess, "/tmp/model.ckpt") print("Model restored.") # Do some work with the model ...
(2)部分恢复模型参数
# Create some variables.v1 = slim.variable(name="v1", ...)v2 = slim.variable(name="nested/v2", ...)...# Get list of variables to restore (which contains only 'v2'). These are all# equivalent methods:variables_to_restore = slim.get_variables_by_name("v2")# orvariables_to_restore = slim.get_variables_by_suffix("2")# orvariables_to_restore = slim.get_variables(scope="nested")# orvariables_to_restore = slim.get_variables_to_restore(include=["nested"])# orvariables_to_restore = slim.get_variables_to_restore(exclude=["v1"])# Create the saver which will be used to restore the variables.restorer = tf.train.Saver(variables_to_restore)with tf.Session() as sess: # Restore variables from disk. restorer.restore(sess, "/tmp/model.ckpt") print("Model restored.") # Do some work with the model ...
(3)当图的变量名与checkpoint中的变量名不同时,恢复模型参数
当从checkpoint文件中恢复变量时,Saver在checkpoint文件中定位到变量名,并且把它们映射到当前图中的变量中。之前的例子中,我们创建了Saver,并为其提供了变量列表作为参数。这时,在checkpoint文件中定位的变量名,是隐含地从每个作为参数给出的变量的var.op.name而获得的。这一方式在图与checkpoint文件中变量名字相同时,可以很好的工作。而当名字不同时,必须给Saver提供一个将checkpoint文件中的变量名映射到图中的每个变量的字典,例子见下:
# Assuming that 'conv1/weights' should be restored from 'vgg16/conv1/weights'def name_in_checkpoint(var): return 'vgg16/' + var.op.name# Assuming that 'conv1/weights' and 'conv1/bias' should be restored from 'conv1/params1' and 'conv1/params2'def name_in_checkpoint(var): if "weights" in var.op.name: return var.op.name.replace("weights", "params1") if "bias" in var.op.name: return var.op.name.replace("bias", "params2")variables_to_restore = slim.get_model_variables()variables_to_restore = {name_in_checkpoint(var):var for var in variables_to_restore}restorer = tf.train.Saver(variables_to_restore)with tf.Session() as sess: # Restore variables from disk. restorer.restore(sess, "/tmp/model.ckpt")
(4)在一个不同的任务上对网络进行微调
比如我们要将1000类的imagenet分类任务应用于20类的Pascal VOC分类任务中,我们只导入部分层,见下例:
image, label = MyPascalVocDataLoader(...)images, labels = tf.train.batch([image, label], batch_size=32)# Create the modelpredictions = vgg.vgg_16(images)train_op = slim.learning.create_train_op(...)# Specify where the Model, trained on ImageNet, was saved.model_path = '/path/to/pre_trained_on_imagenet.checkpoint'# Specify where the new model will live:log_dir = '/path/to/my_pascal_model_dir/'# Restore only the convolutional layers:variables_to_restore = slim.get_variables_to_restore(exclude=['fc6', 'fc7', 'fc8'])init_fn = assign_from_checkpoint_fn(model_path, variables_to_restore)# Start training.slim.learning.train(train_op, log_dir, init_fn=init_fn)
- TF-Slim简介
- TF-slim
- tensorflow slim【TF-Slim】
- TF-slim学习
- tf.contrib.slim
- tf.contrib.slim
- Tensorflow之TF-Slim
- tf.slim使用方法
- TF-slim DCGAN
- TF-Slim学习(1)
- TF-slim快速搭建cnn
- TF-slim快速搭建DNN
- 【Tensorflow】辅助工具篇——tensorflow slim(TF-Slim)介绍
- 解决使用tf.slim找不到slim.utils函数问题
- Tensorflow辅助工具篇——tensorflow slim(TF-Slim)介绍
- 使用 TF-Slim 设计复杂网络
- TF-slim download_and_convert_flowers.py代码解析
- 应用TF-Slim快速实现迁移学习
- svn 命令行创建和删除 分支和tags
- 一、JDBC常用的接口
- 无法访问 private 成员(在“std::basic_ios<_Elem,_Traits>”类中声明
- appium genymotion URLError: <urlopen error [Errno 10061] >
- CentOS常用命令
- TF-Slim简介
- 分子量UVa1586
- Yocto项目(imx6dl处理器)的使用
- python入门之初体验-----A Byte of Python3(观后感)
- Spring注解
- Android开发者如何搭建服务器
- iperf交叉编译运行出现illegal instruction问题
- java 实现 Hive 导入到 mysq
- vim 设置显示行号