Faster-RCNN Tensorflow版本源码解析(二)train_net.py所用到的函数

来源:互联网 发布:杭州淘宝拍摄基地 收费 编辑:程序博客网 时间:2024/05/29 03:30

这里将要解析的是Faster-RCNN Tensorflow版本,fork自githubFaster-RCNN_TF。

1. 背景交代

Faster-RCNN_TF中,网络的训练文件是 Faster-RCNN_TF/tools/train_net.py。我们已经在Faster-RCNN Tensorflow版本源码解析(一)网络训练部分中对该文件进行了源码解析,现在来解析一下该文件中用到的函数。

Faster-RCNN_TF/tools/train_net.py中用到的函数有以下几个:

def parse_args():解析输入参数。 这里的参数指的是,在运行train_net.py这个文件时,需要的输入参数。该函数的定义就在Faster-RCNN_TF/tools/train_net.py中。
get_imdb():加载训练数据。函数get_imdb在Faster-RCNN/lib/datasetes/factory.py中被定义。
get_training_roidb():将训练数据变成minibatch的形式。函数get_training_roidb在Faster-RCNN/lib/fast_rcnn/train.py中被定义。
get_output_dir():设置保存(训练好的模型)的目录。如果该目录没有,会自动新建一个。函数get_output_dir在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义。
get_network():按照args.network_name获取网络。选择train网络或者test网络。为什么参数args.network_name的值有固定的格式,看函数get_network就知道了。函数get_network在Faster-RCNN_TF/lib/networks/factory.py中被定义。
train_net():启动Faster-RCNN网络训练。函数train_net在Faster-RCNN_TF/lib/fast_rcnn/train.py中被定义

2. 下面来分析一下上面每个函数的源码

2.1. def parse_args():

该函数在Faster-RCNN Tensorflow版本源码解析(一)网络训练部分中已经进行了解析。

2.2. get_imdb():

作用:加载训练数据。在Faster-RCNN/lib/datasetes/factory.py中被定义。
文件factory.py是个工厂类,用类生成imdb类并且返回数据库供网络训练和测试使用
factory.py 源码如下:

# coding=utf-8 #有中文注释的时候,记得加上这个# --------------------------------------------------------# Faster R-CNN# Copyright (c) 2015 Microsoft# Licensed under The MIT License [see LICENSE for details]# Written by Ross Girshick# --------------------------------------------------------"""Factory method for easily getting imdbs by name."""__sets = {}import datasets.pascal_vocimport datasets.imagenet3dimport datasets.kittiimport datasets.kitti_trackingimport numpy as npdef _selective_search_IJCV_top_k(split, year, top_k):    """Return an imdb that uses the top k proposals from the selective search    IJCV code.    """    imdb = datasets.pascal_voc(split, year)    imdb.roidb_handler = imdb.selective_search_IJCV_roidb    imdb.config['top_k'] = top_k    return imdb# Set up voc_<year>_<split> using selective search "fast" modefor year in ['2007', '2012']:    for split in ['train', 'val', 'trainval', 'test']:        name = 'voc_{}_{}'.format(year, split)        __sets[name] = (lambda split=split, year=year:                datasets.pascal_voc(split, year))"""# Set up voc_<year>_<split>_top_<k> using selective search "quality" mode# but only returning the first k boxesfor top_k in np.arange(1000, 11000, 1000):    for year in ['2007', '2012']:        for split in ['train', 'val', 'trainval', 'test']:            name = 'voc_{}_{}_top_{:d}'.format(year, split, top_k)            __sets[name] = (lambda split=split, year=year, top_k=top_k:                    _selective_search_IJCV_top_k(split, year, top_k))"""# Set up voc_<year>_<split> using selective search "fast" mode'''主要解析一下这部分,其他类似。该部分用到的数据库是pascal_voc 2007数据库,该数据库由几个部分组成,名称name分别是voc_2007_train、voc_2007_val、voc_2007_trainval、voc_2007_test,看你的任务是训练还是测试,选择相对应的数据库名称。这个数据库名称对应的就是(网络训练文件Faster-RCNN_TF/tools/train_net.py)中的参数--imdb的值,'''for year in ['2007']:    for split in ['train', 'val', 'trainval', 'test']:        name = 'voc_{}_{}'.format(year, split)        print name        __sets[name] = (lambda split=split, year=year:                datasets.pascal_voc(split, year)) #这是一个lambda函数。所用的函数是datasets.pascal_voc。        #pascal_voc是一个类,在Faster-RCNN_TF/lib/datasets/pascal_voc.py中被定义        #(文件pascal_voc.py就是数据库voc_2007_train的数据读写接口)。        #datasets.pascal_voc的作用就是加载voc_2007_train数据库    #lambda函数也叫匿名函数,即,函数没有具体的名称,而用def创建的方法是有名称的。    #lambda允许用户快速定义单行函数,当然用户也可以按照典型的函数定义完成函数。    #lambda的目的就是简化用户定义使用函数的过程。# KITTI datasetfor split in ['train', 'val', 'trainval', 'test']:    name = 'kitti_{}'.format(split)    print name    __sets[name] = (lambda split=split:            datasets.kitti(split))# Set up coco_2014_<split>for year in ['2014']:    for split in ['train', 'val', 'minival', 'valminusminival']:        name = 'coco_{}_{}'.format(year, split)        __sets[name] = (lambda split=split, year=year: coco(split, year))# Set up coco_2015_<split>for year in ['2015']:    for split in ['test', 'test-dev']:        name = 'coco_{}_{}'.format(year, split)        __sets[name] = (lambda split=split, year=year: coco(split, year))# NTHU datasetfor split in ['71', '370']:    name = 'nthu_{}'.format(split)    print name    __sets[name] = (lambda split=split:            datasets.nthu(split))def get_imdb(name):  #加载训练数据。    """Get an imdb (image database) by name."""    '''    在Faster-RCNN_TF/tools/train_net.py中被用到。    传进来的形参name的值就是train_net.py中的args.imdb_name,    也就是train_net.py中的参数--imdb的值    参数--imdb的值,代表的是训练数据库的名字    '''    if not __sets.has_key(name): #如果没有该训练数据库的名字        raise KeyError('Unknown dataset: {}'.format(name)) #报错    return __sets[name]() #如果有该训练数据库的名字,执行__sets[name](),该函数是在本文件中(在上面)定义的def list_imdbs():    """List all registered imdbs."""    return __sets.keys()

2.3. get_training_roidb():

get_training_roidb():将训练数据变成minibatch的形式,该函数在Faster-RCNN/lib/fast_rcnn/train.py中被定义。

def get_training_roidb(imdb):    """Returns a roidb (Region of Interest database) for use in training."""    if cfg.TRAIN.USE_FLIPPED:        print 'Appending horizontally-flipped training examples...'        imdb.append_flipped_images()        print 'done'    print 'Preparing training data...'    if cfg.TRAIN.HAS_RPN:#如果使用RPN(参数cfg.TRAIN.HAS_RPN在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义)        if cfg.IS_MULTISCALE:            gdl_roidb.prepare_roidb(imdb)        else:            rdl_roidb.prepare_roidb(imdb) # rdl_roidb.prepare_roidb()在Faster-RCNN_TF/lib/roi_data_layer/roidb.py中    else:        rdl_roidb.prepare_roidb(imdb)    print 'done'    return imdb.roidb
阅读全文
0 0
原创粉丝点击