使用Tensorflow的slim库进行迁移学习

来源:互联网 发布:笔记本分享wifi软件 编辑:程序博客网 时间:2024/06/06 05:38

由于slim库不是tf的核心库,因此需要到github下载相关代码,这里假设我的工作目录为:/home/hiptonese/MigrationLearning

  • 1 下载代码:https://github.com/tensorflow/models
  • 2 将下载好的代码放到工作目录下
  • 3 下载你所需要的模型的checkpoint文件(该文件存放了模型预训练的变量值),这里列出了各个常用模型的ckpt文件:https://github.com/tensorflow/models/tree/master/research/slim#Pretrained
  • 4 加载代码和图片文件,这里给出例子:
'''@Date  : 2017-11-21 19:18@Author: yangyang Deng@Email : yangydeng@163.com'''import osimport tensorflow as tffrom models.research.slim.datasets import imagenetfrom models.research.slim.preprocessing import inception_preprocessingimport numpy as np# 工程的根目录,同时也是ckpt所在的目录checkpoints_dir = '/home/hiptonese/MigrationLearning/'slim = tf.contrib.slimimage_size = 299with tf.Graph().as_default():    with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope()):        # 加载一张图片        imgPath = 'ship.jpeg'        testImage_string = tf.gfile.FastGFile(imgPath, 'rb').read()        testImage = tf.image.decode_jpeg(testImage_string, channels=3)        processed_image = inception_preprocessing.preprocess_image(testImage, image_size, image_size, is_training=False)        processed_images = tf.expand_dims(processed_image, 0)        # 这里如果我们设置num_classes=None,则可以得到restnet输出的瓶颈层,num_classes默认为10001,是用作imagenet的输出层。同样,我们也可以根据需要修改num_classes为其他的值来满足我们的训练要求。        final_point, endpoints = inception_resnet_v2.inception_resnet_v2(processed_images, num_classes=None, is_training=False)        init_fn = slim.assign_from_checkpoint_fn(os.path.join(checkpoints_dir, 'inception_resnet_v2_2016_08_30.ckpt'),slim.get_model_variables('InceptionResnetV2'))        with tf.Session() as sess:            init_fn(sess)            final_point_eval = np.array(sess.run(final_point))            print(final_point_eval.shape)

× 最后解释一下“瓶颈层”(bottle neck layer)的含义:
瓶颈层一般指网络结束卷基层,将要进入全连层的输入。由于网络中的变量已经做了预训练,因此瓶颈层的输出可以看做是对原始图片的进一步特征提取。因此这里如果将瓶颈层作为输入,后面只需要自己加入FC全连层,则可以不在参数调整和训练上花太多时间,快速达到较好的效果。

原创粉丝点击