Faster-RCNN Tensorflow版本源码解析(二)train_net.py所用到的函数
来源:互联网 发布:杭州淘宝拍摄基地 收费 编辑:程序博客网 时间:2024/05/29 03:30
这里将要解析的是Faster-RCNN Tensorflow版本,fork自githubFaster-RCNN_TF。
1. 背景交代
Faster-RCNN_TF中,网络的训练文件是 Faster-RCNN_TF/tools/train_net.py。我们已经在Faster-RCNN Tensorflow版本源码解析(一)网络训练部分中对该文件进行了源码解析,现在来解析一下该文件中用到的函数。
Faster-RCNN_TF/tools/train_net.py中用到的函数有以下几个:
def parse_args():解析输入参数。 这里的参数指的是,在运行train_net.py这个文件时,需要的输入参数。该函数的定义就在Faster-RCNN_TF/tools/train_net.py中。
get_imdb():加载训练数据。函数get_imdb在Faster-RCNN/lib/datasetes/factory.py中被定义。
get_training_roidb():将训练数据变成minibatch的形式。函数get_training_roidb在Faster-RCNN/lib/fast_rcnn/train.py中被定义。
get_output_dir():设置保存(训练好的模型)的目录。如果该目录没有,会自动新建一个。函数get_output_dir在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义。
get_network():按照args.network_name获取网络。选择train网络或者test网络。为什么参数args.network_name的值有固定的格式,看函数get_network就知道了。函数get_network在Faster-RCNN_TF/lib/networks/factory.py中被定义。
train_net():启动Faster-RCNN网络训练。函数train_net在Faster-RCNN_TF/lib/fast_rcnn/train.py中被定义
2. 下面来分析一下上面每个函数的源码
2.1. def parse_args():
该函数在Faster-RCNN Tensorflow版本源码解析(一)网络训练部分中已经进行了解析。
2.2. get_imdb():
作用:加载训练数据。在Faster-RCNN/lib/datasetes/factory.py中被定义。
文件factory.py是个工厂类,用类生成imdb类并且返回数据库供网络训练和测试使用
factory.py 源码如下:
# coding=utf-8 #有中文注释的时候,记得加上这个# --------------------------------------------------------# Faster R-CNN# Copyright (c) 2015 Microsoft# Licensed under The MIT License [see LICENSE for details]# Written by Ross Girshick# --------------------------------------------------------"""Factory method for easily getting imdbs by name."""__sets = {}import datasets.pascal_vocimport datasets.imagenet3dimport datasets.kittiimport datasets.kitti_trackingimport numpy as npdef _selective_search_IJCV_top_k(split, year, top_k): """Return an imdb that uses the top k proposals from the selective search IJCV code. """ imdb = datasets.pascal_voc(split, year) imdb.roidb_handler = imdb.selective_search_IJCV_roidb imdb.config['top_k'] = top_k return imdb# Set up voc_<year>_<split> using selective search "fast" modefor year in ['2007', '2012']: for split in ['train', 'val', 'trainval', 'test']: name = 'voc_{}_{}'.format(year, split) __sets[name] = (lambda split=split, year=year: datasets.pascal_voc(split, year))"""# Set up voc_<year>_<split>_top_<k> using selective search "quality" mode# but only returning the first k boxesfor top_k in np.arange(1000, 11000, 1000): for year in ['2007', '2012']: for split in ['train', 'val', 'trainval', 'test']: name = 'voc_{}_{}_top_{:d}'.format(year, split, top_k) __sets[name] = (lambda split=split, year=year, top_k=top_k: _selective_search_IJCV_top_k(split, year, top_k))"""# Set up voc_<year>_<split> using selective search "fast" mode'''主要解析一下这部分,其他类似。该部分用到的数据库是pascal_voc 2007数据库,该数据库由几个部分组成,名称name分别是voc_2007_train、voc_2007_val、voc_2007_trainval、voc_2007_test,看你的任务是训练还是测试,选择相对应的数据库名称。这个数据库名称对应的就是(网络训练文件Faster-RCNN_TF/tools/train_net.py)中的参数--imdb的值,'''for year in ['2007']: for split in ['train', 'val', 'trainval', 'test']: name = 'voc_{}_{}'.format(year, split) print name __sets[name] = (lambda split=split, year=year: datasets.pascal_voc(split, year)) #这是一个lambda函数。所用的函数是datasets.pascal_voc。 #pascal_voc是一个类,在Faster-RCNN_TF/lib/datasets/pascal_voc.py中被定义 #(文件pascal_voc.py就是数据库voc_2007_train的数据读写接口)。 #datasets.pascal_voc的作用就是加载voc_2007_train数据库 #lambda函数也叫匿名函数,即,函数没有具体的名称,而用def创建的方法是有名称的。 #lambda允许用户快速定义单行函数,当然用户也可以按照典型的函数定义完成函数。 #lambda的目的就是简化用户定义使用函数的过程。# KITTI datasetfor split in ['train', 'val', 'trainval', 'test']: name = 'kitti_{}'.format(split) print name __sets[name] = (lambda split=split: datasets.kitti(split))# Set up coco_2014_<split>for year in ['2014']: for split in ['train', 'val', 'minival', 'valminusminival']: name = 'coco_{}_{}'.format(year, split) __sets[name] = (lambda split=split, year=year: coco(split, year))# Set up coco_2015_<split>for year in ['2015']: for split in ['test', 'test-dev']: name = 'coco_{}_{}'.format(year, split) __sets[name] = (lambda split=split, year=year: coco(split, year))# NTHU datasetfor split in ['71', '370']: name = 'nthu_{}'.format(split) print name __sets[name] = (lambda split=split: datasets.nthu(split))def get_imdb(name): #加载训练数据。 """Get an imdb (image database) by name.""" ''' 在Faster-RCNN_TF/tools/train_net.py中被用到。 传进来的形参name的值就是train_net.py中的args.imdb_name, 也就是train_net.py中的参数--imdb的值 参数--imdb的值,代表的是训练数据库的名字 ''' if not __sets.has_key(name): #如果没有该训练数据库的名字 raise KeyError('Unknown dataset: {}'.format(name)) #报错 return __sets[name]() #如果有该训练数据库的名字,执行__sets[name](),该函数是在本文件中(在上面)定义的def list_imdbs(): """List all registered imdbs.""" return __sets.keys()
2.3. get_training_roidb():
get_training_roidb():将训练数据变成minibatch的形式,该函数在Faster-RCNN/lib/fast_rcnn/train.py中被定义。
def get_training_roidb(imdb): """Returns a roidb (Region of Interest database) for use in training.""" if cfg.TRAIN.USE_FLIPPED: print 'Appending horizontally-flipped training examples...' imdb.append_flipped_images() print 'done' print 'Preparing training data...' if cfg.TRAIN.HAS_RPN:#如果使用RPN(参数cfg.TRAIN.HAS_RPN在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义) if cfg.IS_MULTISCALE: gdl_roidb.prepare_roidb(imdb) else: rdl_roidb.prepare_roidb(imdb) # rdl_roidb.prepare_roidb()在Faster-RCNN_TF/lib/roi_data_layer/roidb.py中 else: rdl_roidb.prepare_roidb(imdb) print 'done' return imdb.roidb
- Faster-RCNN Tensorflow版本源码解析(二)train_net.py所用到的函数
- Faster-RCNN Tensorflow版本源码解析(一):网络训练部分train_net.py
- py-faster-rcnn 使用的caffe sync 到最新版本
- faster rcnn 源码解析之anchor_target_layer.py
- Faster RCNN 源码解析(3.3) -- proposal_layer.py
- 用自己的数据训练Faster-RCNN,tensorflow版本(二)
- py-faster-rcnn源码解读系列(二)——pascal_voc.py
- Faster-Rcnn demo.py解析
- Faster RCNN 源码解析(3.2) -- Anchor 生成(generate_anchors.py)
- py-faster-rcnn demo.py解析
- py-faster-rcnn demo.py解析
- Tensorflow开源的object detection API中的源码解析(一):FASTER RCNN with Inception架构图
- cudnn 5.1版本下跑通 py-faster-rcnn的demo
- 配置py-faster-rcnn配到的问题
- faster rcnn源码解读(三)train_faster_rcnn_alt_opt.py
- faster rcnn源码解读(三)train_faster_rcnn_alt_opt.py
- faster-rcnn (1):unbantu下安装 anaconda +tensorflow版本的 faster-rcnn
- py-faster-rcnn+CPU训练自己的数据集(二)
- c++ 命名空间
- 康神建议之重学《c++ primer》(2)(文件操作)
- IE8兼容性- 条件注释
- mysql存储过程while循环按时间分组查询每天总数前10
- 15 个 Android 通用流行框架大全(这篇文章好像我自己总结过的一样,发现了,于是转载了)
- Faster-RCNN Tensorflow版本源码解析(二)train_net.py所用到的函数
- 【Angular】路由跳转问题;
- Linux中强大的说明书“man”命令
- ruby 中的类方法和实例方法
- 在matlab下计算信源熵
- KNN k近邻法tensorflow实现
- Junit使用总结
- copyonwritelist源码理解
- GKConstantNoiseSource