Faster-RCNN Tensorflow版本源码解析(一):网络训练部分train_net.py
来源:互联网 发布:voc0712数据集 编辑:程序博客网 时间:2024/06/05 04:56
这里将要解析的是Faster-RCNN Tensorflow版本,fork自githubFaster-RCNN_TF。
网络训练部分
Faster-RCNN_TF中,网络的训练文件是 Faster-RCNN_TF/tools/train_net.py。
1. 启动训练的方法
我们在启动faster-RCNN网络训练的时候,要在目录Faster-RCNN_TF/下,在终端输入:
python ./tools/train_net.py —device gpu —device_id 3 —solver VGG_CNN_M_1024 —weight ./data/pretrain_model/VGG_imagenet.npy —imdb voc_2007_train —network voc2007_train
上述例子中,所使用的数据库是voc_2007_train,我们就以voc_2007_train为例进行说明。
参数解释:(注意:所有的—都是两个小横杠,不是一个)
train_net.py: 是网络的训练文件
—device :代表选用cpu还是gpu
—device_id: 代表机器上的cpu或者gpu的编号,根据自己的机器可自行修改
—solver: 模型的配置文件,这个参数就不要进行修改了,固定就是VGG_CNN_M_1024
—weight: 初始化的权重文件,这里用的是Imagenet上预训练好的模型VGG_imagenet.npy,存放在目录Faster-RCNN_TF/data/pretrain_model下。如果没有该目录,自己手动创建一个;VGG_imagenet.npy如果没有,自行下载,原工程提供了下载链接
—imdb: 训练的数据库名称
—network: 代表选择训练网络还是测试网络,这个参数的值的形式是固定的,必须是voc2007_train的形式,前半部分voc2007可以随便(但是不能有下划线),后半部分必须是_train
训练完成之后的模型默认保存在了目录Faster-RCNN_TF/output/default/voc_2007_train/下(该目录如果没有,程序会自动创建一个,所以不用自己手动创建)。
2. Faster-RCNN_TF/tools/train_net.py 源码解析
# coding=utf-8 #有中文注释,记得加上这个#!/usr/bin/env python# --------------------------------------------------------# Faster R-CNN# Copyright (c) 2015 Microsoft# Licensed under The MIT License [see LICENSE for details]# Written by Ross Girshick# --------------------------------------------------------"""Train a Faster R-CNN network on a region of interest database."""#首先,加载进来需要的各种模块。#python中,每个py文件被称之为模块,,每个具有__init__.py文件的目录被称为包import _init_paths #_init_paths是一个.py文件,用来设置Faster-RCNN的路径from fast_rcnn.train import get_training_roidb, train_netfrom fast_rcnn.config import cfg,cfg_from_file, cfg_from_list, get_output_dirfrom datasets.factory import get_imdbfrom networks.factory import get_networkimport argparseimport pprintimport numpy as npimport sysimport osdef parse_args(): """ Parse input arguments解析输入参数 这里的参数指的是,在运行train_net.py这个文件时,需要的输入参数 """ parser = argparse.ArgumentParser(description='Train a Faster R-CNN network') parser.add_argument('--device', dest='device', help='device to use', default='cpu', type=str) #--device代表选用cpu还是gpu,默认cpu parser.add_argument('--device_id', dest='device_id', help='device id to use', default=0, type=int) #--device_id代表机器上的cpu或者gpu的编号 parser.add_argument('--solver', dest='solver', help='solver prototxt', default=None, type=str) # --solver代表模型的配置文件,这个参数的值就固定是VGG_CNN_M_1024 parser.add_argument('--iters', dest='max_iters', help='number of iterations to train', default=70000, type=int) #--iters代表训练时的最大迭代步数,默认是70000步 parser.add_argument('--weights', dest='pretrained_model', help='initialize with pretrained model weights', default=None, type=str) #--weights代表权重文件,也就是预训练好的模型。 #这里用的是Imagenet上预训练好的模型VGG_imagenet.npy, #存放在目录Faster-RCNN_TF/data/pretrain_model下 parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default=None, type=str) parser.add_argument('--imdb', dest='imdb_name', help='dataset to train on', default='kitti_train', type=str) #--imdb代表训练数据库的名称,默认是kitti_train。 #该工程中,提供了5种数据库来训练网络,并分别给出了各自的数据读写接口, #5种数据库分别是pascal_voc,coco,kitti,nissan,nthu #(工程中,说是提供了5种数据库,但是也就只给出了各自的数据库读写接口,并没有给出实际的数据库,所以得需要自己另行下载,工程中没有提供)。 #另外,这个数据库名称是固定的,该名称在Faster-RCNN_TF/lib/datasets/factory.py中 #被定义了具体的格式:以pascal_voc数据库为例,参数--imdb的值应为voc_2007_train。 #文件factory.py会在下一篇做更进一步的解释 parser.add_argument('--rand', dest='randomize', help='randomize (do not use a fixed seed)', action='store_true') parser.add_argument('--network', dest='network_name', help='name of the network', default='kitti_train', type=str) #--network代表选择训练网络还是测试网络, #这个参数的值的形式是固定的,必须是kitti_train的形式, #前半部分kitti可以随便定义(但是不能有下划线),后半部分必须是_train parser.add_argument('--set', dest='set_cfgs', help='set config keys', default=None, nargs=argparse.REMAINDER) if len(sys.argv) == 1: parser.print_help() sys.exit(1) args = parser.parse_args() return argsif __name__ == '__main__': #主函数 args = parse_args() print('Called with args:') print(args) if args.cfg_file is not None: cfg_from_file(args.cfg_file) if args.set_cfgs is not None: cfg_from_list(args.set_cfgs) print('Using config:') pprint.pprint(cfg) #cfg就是Faster-RCNN_TF/lib/fast_rcnn/config.py, #是网络训练的参数文件。这里的参数指的是网络在训练过程需要用到的各种参数。 if not args.randomize: # fix the random seeds (numpy and caffe) for reproducibility np.random.seed(cfg.RNG_SEED) #加载训练数据。函数get_imdb在Faster-RCNN/lib/datasetes/factory.py中被定义。 imdb = get_imdb(args.imdb_name) print 'Loaded dataset `{:s}` for training'.format(imdb.name) #将训练数据变成minibatch的形式。 #函数get_training_roidb在Faster-RCNN/lib/fast_rcnn/train.py中被定义 roidb = get_training_roidb(imdb) #设置保存(训练好的模型)的目录。如果该目录没有,会自动新建一个。 #函数get_output_dir在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义 output_dir = get_output_dir(imdb, None) print 'Output will be saved to `{:s}`'.format(output_dir) #设置GPU或者CPU的 id os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device_id) device_name = '/{}:{:d}'.format(args.device,args.device_id) print device_name #按照args.network_name获取网络。选择train网络或者test网络。 #为什么参数args.network_name的值有固定的格式,看函数get_network就知道了。 #函数get_network在Faster-RCNN_TF/lib/networks/factory.py中被定义。 network = get_network(args.network_name) print 'Use network `{:s}` in training'.format(args.network_name) #启动Faster-RCNN网络训练。 #函数train_net在Faster-RCNN_TF/lib/fast_rcnn/train.py中被定义 train_net(network, imdb, roidb, output_dir, pretrained_model=args.pretrained_model, max_iters=args.max_iters)
上述网络的训练文件Faster-RCNN_TF/tools/train_net.py中,用到的函数有以下几个:
def parse_args():解析输入参数。 这里的参数指的是,在运行train_net.py这个文件时,需要的输入参数。该函数的定义就在Faster-RCNN_TF/tools/train_net.py中。
get_imdb():加载训练数据。函数get_imdb在Faster-RCNN/lib/datasetes/factory.py中被定义。
get_training_roidb():将训练数据变成minibatch的形式。函数get_training_roidb在Faster-RCNN/lib/fast_rcnn/train.py中被定义。
get_output_dir():设置保存(训练好的模型)的目录。如果该目录没有,会自动新建一个。函数get_output_dir在Faster-RCNN_TF/lib/fast_rcnn/config.py中被定义。
get_network():按照args.network_name获取网络。选择train网络或者test网络。为什么参数args.network_name的值有固定的格式,看函数get_network就知道了。函数get_network在Faster-RCNN_TF/lib/networks/factory.py中被定义。
train_net():启动Faster-RCNN网络训练。函数train_net在Faster-RCNN_TF/lib/fast_rcnn/train.py中被定义
下一篇来分析一下上面每个函数的源码
- Faster-RCNN Tensorflow版本源码解析(一):网络训练部分train_net.py
- Faster-RCNN Tensorflow版本源码解析(二)train_net.py所用到的函数
- 训练py-faster-rcnn
- 用自己的数据训练Faster-RCNN,tensorflow版本(一)
- faster rcnn 源码解析之anchor_target_layer.py
- Faster RCNN 源码解析(3.3) -- proposal_layer.py
- 训练py-faster-rcnn(caffe)
- py-faster-rcnn训练教程
- faster-rcnn 之训练脚本解析:./tools/train_faster_rcnn_alt_opt.py
- py-faster-rcnn源码解读系列(一)——train_faster_rcnn_alt_opt.py
- py-faster-rcnn流程(2)——训练RPN网络一阶段
- py-faster-rcnn流程(4)——训练FastRCNN网络一阶段
- py-faster-rcnn流程(5)——训练RPN网络二阶段
- py-faster-rcnn流程(6)——训练Fastrcnn网络二阶段
- Faster RCNN训练(Matlab版本)结果
- py-faster-rcnn+CPU训练自己的数据集(一)
- Faster-Rcnn demo.py解析
- Faster RCNN 源码解析(3.2) -- Anchor 生成(generate_anchors.py)
- HTML中DOM解析篇1--nodeType\nodeValue\nodeName
- oracle delete、truncate、drop语句区别
- makefile编写
- No active profile set, falling back to default profiles: default
- PCF8575扩展SOC端口很方便
- Faster-RCNN Tensorflow版本源码解析(一):网络训练部分train_net.py
- 利用history对象实现地址栏更新,页面局部刷新
- java 读取 txt文件 特定行
- 关于linux删除文件夹命令方法
- maven私服的搭建
- INFO [Timer-282] org.apache.catalina.loader.WebappClassLoaderBase.checkStateForResourceLoading.....
- 中间件的实现原理
- Java面向对象-String类作业一字符串转数组
- Nginx为什么比Apache Httpd高效:原理篇