tf-faster-rcnn代码理解之trianval_net.py
来源:互联网 发布:里欧万塔 知乎 编辑:程序博客网 时间:2024/06/06 15:10
原始工程代码是通过tf-faster-rcnn\experiments\scripts目录下的train_faster_rcnn.sh调用tf-faster-rcnn\tools\trainval_net.py进行模型训练。为了方便使用pycharm对整个训练工程进行调试,故修改trianval_net.py使之不需要shell脚本引导,可以直接运行。修改之后的代码如下:
# --------------------------------------------------------# Tensorflow Faster R-CNN# Licensed under The MIT License [see LICENSE for details]# Written by Zheqi He, Xinlei Chen, based on code from Ross Girshick# --------------------------------------------------------from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport _init_pathsfrom model.train_val import get_training_roidb, train_netfrom model.config import cfg, cfg_from_file, cfg_from_list, get_output_dir, get_output_tb_dirfrom datasets.factory import get_imdbimport datasets.imdbimport argparseimport pprintimport numpy as npimport sysimport tensorflow as tffrom nets.vgg16 import vgg16from nets.resnet_v1 import resnetv1class args: """ Parse input arguments """ cfg_file = '/home/whao/tf-faster-rcnn/experiments/cfgs/vgg16.yml' weight = '/home/whao/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' imdb_name = 'voc_2007_trainval' imdbval_name = 'voc_2007_test' max_iters = 100000 tag = None net = 'vgg16' set_cfgs = ['ANCHOR_SCALES', '[8,16,32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'TRAIN.STEPSIZE', '50000']def combined_roidb(imdb_names): """ Combine multiple roidbs """ def get_roidb(imdb_name): imdb = get_imdb(imdb_name) print('Loaded dataset `{:s}` for training'.format(imdb.name)) imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD) print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)) roidb = get_training_roidb(imdb) return roidb roidbs = [get_roidb(s) for s in imdb_names.split('+')] roidb = roidbs[0] if len(roidbs) > 1: for r in roidbs[1:]: roidb.extend(r) tmp = get_imdb(imdb_names.split('+')[1]) imdb = datasets.imdb.imdb(imdb_names, tmp.classes) else: imdb = get_imdb(imdb_names) return imdb, roidbif __name__ == '__main__':# args = parse_args() print('Called with args:') if args.cfg_file is not None: cfg_from_file(args.cfg_file) if args.set_cfgs is not None: cfg_from_list(args.set_cfgs) print('Using config:') pprint.pprint(cfg) np.random.seed(cfg.RNG_SEED) # train set imdb, roidb = combined_roidb(args.imdb_name) print('{:d} roidb entries'.format(len(roidb))) # output directory where the models are saved output_dir = get_output_dir(imdb, args.tag) print('Output will be saved to `{:s}`'.format(output_dir)) # tensorboard directory where the summaries are saved during training tb_dir = get_output_tb_dir(imdb, args.tag) print('TensorFlow summaries will be saved to `{:s}`'.format(tb_dir)) # also add the validation set, but with no flipping images orgflip = cfg.TRAIN.USE_FLIPPED cfg.TRAIN.USE_FLIPPED = False _, valroidb = combined_roidb(args.imdbval_name) print('{:d} validation roidb entries'.format(len(valroidb))) cfg.TRAIN.USE_FLIPPED = orgflip # load network if args.net == 'vgg16': net = vgg16(batch_size=cfg.TRAIN.IMS_PER_BATCH) else: raise NotImplementedError train_net(net, imdb, roidb, valroidb, output_dir, tb_dir, pretrained_model=args.weight, max_iters=args.max_iters)以上代码中,定义了args类代替shell传参。首先需要把训练集按照pascal voc的格式处理好,包括文件名与标签个事和Main中的txt的文件。
代码的执行流程是先读取cfg_file所指定的yml文件来配置部分超参量。执行函数为cfg_from_file(args.cfg_file),它把yml中的超参数合并到config.py中定义的__C对象中,它是类EasyDict的对象。
然后,通过cfg_from_list(args.set_cfgs)配置__C对象中的变量。
接下来,开始处理训练集,通过combined_roidb(args.imdb_name)收集训练集,它通过调用lib/datasets/factory.py中的get_imdb()获得数据集,获得类pascal_voc的对象imdb,再设置区域推荐的方式,默认为gt,通过lib/model/train_val.py中的函数get_training_roidb()获得roidb,即每张图片中的区域推荐样本,其为实际为imdb中的一个变量。打印出区域推荐样本的数量
接下来设置训练好的模型和tensorboard文件的存储路径,再获取验证集的数据,前面的训练的数据是经过数据增强的,每张图片都经过旋转,验证集不进行数据增强。
接下来,配置vgg16网络的batch数量,默认是设置为1。
最后调用train_val.py中的train_net()函数开启训练。
未完待续。
- tf-faster-rcnn代码理解之trianval_net.py
- tf-faster-rcnn代码理解
- tf-faster-rcnn代码理解之获取数据集对象imdb,roidb,valroidb
- 跑py-faster-rcnn代码
- 【py-faster-rcnn】【RPN】通过代码理解faster-RCNN中的RPN
- Faster R-CNN:tf-faster-rcnn代码结构
- faster rcnn 源码解析之anchor_target_layer.py
- py-faster-rcnn代码roidb.py的解读
- Faster RCNN roidb.py
- Faster RCNN train_faster_rcnn_alt_opt.py
- Faster RCNN layer.py
- Faster RCNN train.py
- Faster RCNN generate.py
- Faster RCNN blob.py
- Faster RCNN minibatch.py
- Faster RCNN pascal_voc.py
- Faster RCNN imdb.py
- Faster RCNN anchor_target_layer.py
- virtualbox rc=-101 问题解决
- Python机器学习之XGBoost从入门到实战(代码实现)
- 奇异值分解
- 机房收费系统之实时错误426
- leetcode Unique Paths && Unique Paths
- tf-faster-rcnn代码理解之trianval_net.py
- Dropout浅层理解与实现
- python爬虫scrapy运行ImportError: No module named win32api错误
- linux系统之apache基本配置及语言支持及虚拟主机 访问控制
- STK 11.3 新特性(New Features)
- 拦截器初级入门
- a = a + 1与a + = 1的区别
- 分布式-微服务-集群的区别
- ul,ol