Tensorflow object detection API source code reading notes: architecture
Source: Internet | Editor: 程序博客网 | Time: 2024/05/07 10:32
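A previous post showed how to run inference with the pretrained models that TF provides, which is very simple. Here we dig into the source code to understand the code architecture of the detection API; a deeper reading of each part is left for later posts.
First, the official documentation is fairly comprehensive and worth reading through in full. The document most relevant to the core model is: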
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/defining_your_own_model.md
Another somewhat cumbersome point is that parameter configuration is managed through protobuf files; see:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/configuring_jobs.md
The interface for building your own model is the abstract base class DetectionModel, which has 5 abstract methods to implement.

```python
# object_detection/core/model.py (abridged: method bodies omitted)
from abc import abstractmethod


class DetectionModel(object):

  def groundtruth_lists(self, field):
    """Access list of groundtruth tensors."""

  def groundtruth_has_field(self, field):
    """Determines whether the groundtruth includes the given field."""

  def provide_groundtruth(self, groundtruth_boxes_list, groundtruth_classes_list,
                          groundtruth_masks_list=None,
                          groundtruth_keypoints_list=None):
    """Provide groundtruth tensors."""

  @abstractmethod
  def preprocess(self, inputs):
    ...

  @abstractmethod
  def predict(self, preprocessed_inputs):
    ...

  @abstractmethod
  def postprocess(self, prediction_dict, **params):
    ...

  @abstractmethod
  def loss(self, prediction_dict):
    ...

  @abstractmethod
  def restore_map(self, from_detection_checkpoint=True):
    ...
```
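Per the "defining your own model" doc, training chains preprocess → predict → loss (with groundtruth supplied through provide_groundtruth first), and evaluation chains preprocess → predict → postprocess. The framework-free sketch below shows that call order. It assumes plain Python lists in place of tensors; `DetectionModelSketch`, `ToyDetectionModel` and their trivial logic are hypothetical, and only the method names mirror the interface above.

```python
from abc import ABC, abstractmethod


class DetectionModelSketch(ABC):
    """Hypothetical stand-in for DetectionModel with the same five abstract methods."""

    def __init__(self):
        self._groundtruth = {}

    def provide_groundtruth(self, groundtruth_boxes_list, groundtruth_classes_list):
        # Groundtruth is stored on the model and read back later by loss().
        self._groundtruth["boxes"] = groundtruth_boxes_list
        self._groundtruth["classes"] = groundtruth_classes_list

    def groundtruth_lists(self, field):
        return self._groundtruth[field]

    def groundtruth_has_field(self, field):
        return field in self._groundtruth

    @abstractmethod
    def preprocess(self, inputs): ...

    @abstractmethod
    def predict(self, preprocessed_inputs): ...

    @abstractmethod
    def postprocess(self, prediction_dict, **params): ...

    @abstractmethod
    def loss(self, prediction_dict): ...

    @abstractmethod
    def restore_map(self, from_detection_checkpoint=True): ...


class ToyDetectionModel(DetectionModelSketch):
    """Hypothetical model: one fixed box per image, score = mean scaled pixel."""

    def preprocess(self, inputs):
        # Scale pixel values from [0, 255] to [-1, 1].
        return [[p / 127.5 - 1.0 for p in image] for image in inputs]

    def predict(self, preprocessed_inputs):
        # One box [ymin, xmin, ymax, xmax] per image covering the whole image.
        boxes = [[0.0, 0.0, 1.0, 1.0] for _ in preprocessed_inputs]
        scores = [sum(img) / len(img) for img in preprocessed_inputs]
        return {"boxes": boxes, "scores": scores}

    def postprocess(self, prediction_dict, **params):
        # Keep only detections whose score reaches the threshold.
        threshold = params.get("score_threshold", 0.0)
        keep = [i for i, s in enumerate(prediction_dict["scores"]) if s >= threshold]
        return {"detection_boxes": [prediction_dict["boxes"][i] for i in keep],
                "num_detections": len(keep)}

    def loss(self, prediction_dict):
        # L1 distance between predicted and groundtruth boxes, averaged over images.
        gt_boxes = self.groundtruth_lists("boxes")
        total = 0.0
        for pred, true in zip(prediction_dict["boxes"], gt_boxes):
            total += sum(abs(a - b) for a, b in zip(pred, true))
        return {"localization_loss": total / len(gt_boxes)}

    def restore_map(self, from_detection_checkpoint=True):
        # The toy model has no variables, so there is nothing to restore.
        return {}


# Evaluation path: preprocess -> predict -> postprocess.
model = ToyDetectionModel()
preds = model.predict(model.preprocess([[255, 255], [0, 0]]))
detections = model.postprocess(preds, score_threshold=0.0)  # num_detections == 1
```

The abstract base cannot be instantiated directly, which is what forces a concrete model to implement all five methods.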
```python
# object_detection/meta_architectures/faster_rcnn_meta_arch.py (abridged)
from abc import abstractmethod

from object_detection.core import model


class FasterRCNNFeatureExtractor(object):
  """Faster R-CNN Feature Extractor definition."""

  def __init__(self, is_training, first_stage_features_stride,
               batch_norm_trainable=False, reuse_weights=None,
               weight_decay=0.0):
    ...

  @abstractmethod
  def preprocess(self, resized_inputs):
    """Feature-extractor specific preprocessing (minus image resizing)."""

  def extract_proposal_features(self, preprocessed_inputs, scope):
    """Extracts first stage RPN features."""

  @abstractmethod
  def _extract_proposal_features(self, preprocessed_inputs, scope):
    ...

  def extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features."""

  @abstractmethod
  def _extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features, to be overridden."""

  def restore_from_classification_checkpoint_fn(
      self, first_stage_feature_extractor_scope,
      second_stage_feature_extractor_scope):
    """Returns a map of variables to load from a foreign checkpoint."""


class FasterRCNNMetaArch(model.DetectionModel):
  """Faster R-CNN Meta-architecture definition."""
  # Note: for now, focus on where feature_extractor (a FasterRCNNFeatureExtractor
  # object) is called. Swapping in a different CNN is fairly simple: you only
  # need to write a new faster_rcnn_new_cnn_feature_extractor. The detection
  # model that is finally built is an instance of this class.

  def preprocess(self, inputs):
    """For Faster R-CNN, we perform image resizing in the base class --- each
    class subclassing FasterRCNNMetaArch is responsible for any additional
    preprocessing (e.g., scaling pixel values to be in [-1, 1]).
    """
    # Note: see the preprocess function implemented in the next code block.
```
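The pairing of a public `extract_proposal_features` with an abstract `_extract_proposal_features` is a template-method pattern: the base class owns the wrapper (as far as I can tell, that is where the real code handles the variable scope) and a subclass only fills in the network. That is why swapping the CNN is cheap. The sketch below is hypothetical throughout: backbone "stages" are plain functions on numbers, `NewCNNFeatureExtractor` stands in for the `faster_rcnn_new_cnn_feature_extractor` mentioned above, the scope strings are illustrative, and `TwoStageMetaArch` only mimics the meta-architecture calling the extractor once per stage. The 3 + 1 stage split mirrors the ResNet extractor's block3 (RPN) / block4 (box classifier) split.

```python
from abc import ABC, abstractmethod


class FeatureExtractorSketch(ABC):
    """Hypothetical stand-in for FasterRCNNFeatureExtractor's two-stage interface."""

    def extract_proposal_features(self, preprocessed_inputs, scope):
        # Public wrapper owned by the base class; subclasses never override it.
        self.last_scope = scope
        return self._extract_proposal_features(preprocessed_inputs, scope)

    def extract_box_classifier_features(self, proposal_feature_maps, scope):
        self.last_scope = scope
        return self._extract_box_classifier_features(proposal_feature_maps, scope)

    @abstractmethod
    def _extract_proposal_features(self, preprocessed_inputs, scope): ...

    @abstractmethod
    def _extract_box_classifier_features(self, proposal_feature_maps, scope): ...


class NewCNNFeatureExtractor(FeatureExtractorSketch):
    """Hypothetical extractor: a four-stage backbone split into 3 + 1 stages."""

    def __init__(self, stages):
        self._first_stage = stages[:3]   # like ResNet block1..block3 -> RPN features
        self._second_stage = stages[3:]  # like ResNet block4 -> box classifier features

    def _extract_proposal_features(self, preprocessed_inputs, scope):
        x = preprocessed_inputs
        for stage in self._first_stage:
            x = stage(x)
        return x

    def _extract_box_classifier_features(self, proposal_feature_maps, scope):
        outputs = []
        for fmap in proposal_feature_maps:
            for stage in self._second_stage:
                fmap = stage(fmap)
            outputs.append(fmap)
        return outputs


class TwoStageMetaArch:
    """Hypothetical meta-architecture that only talks to the extractor interface."""

    def __init__(self, feature_extractor):
        self._feature_extractor = feature_extractor

    def predict(self, preprocessed_inputs):
        rpn_features = self._feature_extractor.extract_proposal_features(
            preprocessed_inputs, scope="FirstStageFeatureExtractor")
        # Stand-in for cropping proposals out of the feature map.
        proposals = [rpn_features, rpn_features + 1]
        box_features = self._feature_extractor.extract_box_classifier_features(
            proposals, scope="SecondStageFeatureExtractor")
        return {"rpn_features": rpn_features, "box_classifier_features": box_features}


stages = [lambda x: x + 1, lambda x: x * 2, lambda x: x - 3, lambda x: x * 10]
outputs = TwoStageMetaArch(NewCNNFeatureExtractor(stages)).predict(5)
```

Replacing the backbone means passing a different extractor; `TwoStageMetaArch` itself does not change.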
```python
# object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py (abridged)
# Note: this part is tightly coupled with slim, so let's look at it carefully.
from nets import resnet_v1  # slim's nets package

from object_detection.meta_architectures import faster_rcnn_meta_arch


class FasterRCNNResnetV1FeatureExtractor(
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
  """Faster R-CNN Resnet V1 feature extractor implementation."""

  def __init__(self, architecture, resnet_model, is_training,
               first_stage_features_stride, batch_norm_trainable=False,
               reuse_weights=None, weight_decay=0.0):
    ...

  def preprocess(self, resized_inputs):
    """Faster R-CNN Resnet V1 preprocessing."""
    channel_means = [123.68, 116.779, 103.939]
    return resized_inputs - [[channel_means]]

  def _extract_proposal_features(self, preprocessed_inputs, scope):
    """Extracts first stage RPN features."""
    # Note: uses end_points to output the value of ResNet block3.
    ...

  def _extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features."""
    # Note: splits off ResNet's block4. Pay attention to how variable_scope
    # and arg_scope are used.
    ...


class FasterRCNNResnet152FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
  """Faster R-CNN Resnet 152 feature extractor implementation."""

  def __init__(self, is_training, first_stage_features_stride,
               batch_norm_trainable=False, reuse_weights=None,
               weight_decay=0.0):
    """Constructor.

    Args:
      is_training: See base class.
      first_stage_features_stride: See base class.
      batch_norm_trainable: See base class.
      reuse_weights: See base class.
      weight_decay: See base class.

    Raises:
      ValueError: If `first_stage_features_stride` is not 8 or 16,
        or if `architecture` is not supported.
    """
    super(FasterRCNNResnet152FeatureExtractor, self).__init__(
        'resnet_v1_152', resnet_v1.resnet_v1_152, is_training,
        first_stage_features_stride, batch_norm_trainable,
        reuse_weights, weight_decay)
    # Note: tracing back through each class's __init__, 'resnet_v1_152' and
    # resnet_v1.resnet_v1_152 are used only in FasterRCNNResnetV1FeatureExtractor.
```
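The one-line `preprocess` above relies on broadcasting: `[[channel_means]]` has shape [1, 1, 3], so it is subtracted from the last (RGB) axis of every pixel in a [batch, height, width, 3] input. The three numbers are the usual per-channel ImageNet RGB means. The stdlib-only sketch below spells out that broadcast with explicit loops; it assumes NHWC nested lists and exists only to show the arithmetic, not to replace the tensor op.

```python
CHANNEL_MEANS = [123.68, 116.779, 103.939]  # R, G, B means from the code above


def subtract_channel_means(batch, channel_means=CHANNEL_MEANS):
    """Loop version of `batch - [[channel_means]]` for an NHWC nested list."""
    out = []
    for image in batch:
        out_image = []
        for row in image:
            out_row = []
            for pixel in row:
                # Broadcasting would fail on a channel-count mismatch; so do we.
                if len(pixel) != len(channel_means):
                    raise ValueError("last axis must have one value per channel")
                out_row.append([v - m for v, m in zip(pixel, channel_means)])
            out_image.append(out_row)
        out.append(out_image)
    return out


# One 1x2 image: a white pixel and a pixel equal to the means.
batch = [[[[255.0, 255.0, 255.0], [123.68, 116.779, 103.939]]]]
centered = subtract_channel_means(batch)  # second pixel becomes [0.0, 0.0, 0.0]
```

Note there is no division by a standard deviation here, unlike the [-1, 1] scaling mentioned in the meta-architecture's docstring; each feature extractor decides its own normalization.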
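Again, I recommend running the test scripts. You will come across the following files; read them one by one in the order they appear in the test, together with their corresponding test scripts.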
```python
# "Builder function to construct tf-slim arg_scope for convolution, fc ops."
# Reading this script's test makes it easy to see how hyperparameter configs
# are read; it is similar to a dict in OpenFOAM.
# Message type: object_detection.protos.hyperparams_pb2.Hyperparams.
from object_detection.builders import hyperparams_builder

# "Contains routines for printing protocol messages in text format."
# Also from the test above; for now it is mainly used as
#   conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
#   text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
# where conv_hyperparams_text_proto is a string holding the parameter config,
# and conv_hyperparams_proto is a hyperparams.proto object, the first argument
# to hyperparams_builder.build.
from google.protobuf import text_format

# "Function to build box predictor from configuration.
#  Box predictors are classes that take a high level image feature map as input
#  and produce two predictions, (1) a tensor encoding box locations, and
#  (2) a tensor encoding classes for each box."
# object_detection/core/box_predictor.py is left for later study. Note that
# conv_hyperparams_text_proto is put inside box_predictor_text_proto and then
# passed along with it to class ConvolutionalBoxPredictor(BoxPredictor).
from object_detection.builders import box_predictor_builder

# "Generates grid anchors on the fly as used in Faster RCNN."
# To be examined in detail next time.
from object_detection.anchor_generators import grid_anchor_generator

# "Builder function for post processing operations."
from object_detection.builders import post_processing_builder

# "Classification and regression loss functions for object detection."
from object_detection.core import losses

# proto files: next time, study how to write and read these files together
# with the corresponding core and builder modules.
from object_detection.protos import box_predictor_pb2
from object_detection.protos import hyperparams_pb2
from object_detection.protos import post_processing_pb2
```
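The test's pattern is: write the config as a text-format string, `text_format.Merge` it into a `Hyperparams` message, then hand that message to `hyperparams_builder.build`. To show the idea without protobuf installed, here is a deliberately tiny, hypothetical parser for a subset of the text format (nested `name { ... }` blocks and `key: value` scalars only; no repeated fields, comments, or strings containing spaces, braces or colons), plus a hypothetical `build_hyperparams` that reads the parsed dict. The sample follows the style of the test's conv hyperparams string but is illustrative; the real schema lives in `object_detection/protos/hyperparams.proto`.

```python
import re

# Tokens: a brace, a colon, or any run of other non-space characters.
_TOKEN = re.compile(r"[{}:]|[^\s{}:]+")


def _convert(token):
    """Turn a scalar token into an int, a float, or an enum/string name."""
    for cast in (int, float):
        try:
            return cast(token)
        except ValueError:
            pass
    return token.strip("'\"")


def parse_text_proto(text):
    """Parse a small subset of protobuf text format into nested dicts."""
    tokens = _TOKEN.findall(text)
    pos = 0

    def parse_block():
        nonlocal pos
        result = {}
        while pos < len(tokens) and tokens[pos] != "}":
            name = tokens[pos]
            if name in ("{", ":"):
                raise ValueError("expected a field name, got %r" % name)
            pos += 1
            has_colon = pos < len(tokens) and tokens[pos] == ":"
            if has_colon:
                pos += 1
            if pos >= len(tokens):
                raise ValueError("unexpected end of input after %r" % name)
            if tokens[pos] == "{":          # nested message
                pos += 1
                result[name] = parse_block()
                if pos >= len(tokens):
                    raise ValueError("missing '}'")
                pos += 1                    # consume the closing '}'
            elif has_colon and tokens[pos] not in ("}", ":"):
                result[name] = _convert(tokens[pos])  # scalar field
                pos += 1
            else:
                raise ValueError("expected ':' or '{' after %r" % name)
        return result

    parsed = parse_block()
    if pos != len(tokens):
        raise ValueError("unexpected '}'")
    return parsed


def build_hyperparams(config):
    """Hypothetical builder: flatten the parsed dict into layer settings."""
    settings = {"activation": config.get("activation")}
    regularizer = config.get("regularizer", {})
    if "l2_regularizer" in regularizer:
        settings["l2_weight"] = regularizer["l2_regularizer"]["weight"]
    initializer = config.get("initializer", {})
    if "truncated_normal_initializer" in initializer:
        settings["init_stddev"] = initializer["truncated_normal_initializer"]["stddev"]
    return settings


CONV_HYPERPARAMS_TEXT = """
  regularizer { l2_regularizer { weight: 0.0004 } }
  initializer { truncated_normal_initializer { stddev: 0.03 } }
  activation: RELU_6
"""
settings = build_hyperparams(parse_text_proto(CONV_HYPERPARAMS_TEXT))
```

Splitting "parse the text" from "build objects from the parsed message" is the same separation the real code keeps between `text_format` and the `*_builder` modules.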
object_detection/builders/model_builder.py: "A function to build a DetectionModel from configuration." Much of it is already tested in faster_rcnn_meta_arch_test_lib.py.
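As far as I can tell, `model_builder` is the glue between the proto config and the classes above: it reads the feature extractor's type string from the config, looks up the matching extractor class, and passes an instance into the Faster R-CNN meta-architecture, which is also where a new CNN extractor would need to be registered. The stdlib sketch below shows only that dispatch pattern; the registry name, `build_model`, the dict-shaped config, and the placeholder classes are all hypothetical, not the builder's real API. The stride check mirrors the `ValueError` documented for the ResNet extractor.

```python
class ResNet152Extractor:
    """Placeholder for FasterRCNNResnet152FeatureExtractor."""

    def __init__(self, is_training, first_stage_features_stride):
        if first_stage_features_stride not in (8, 16):
            raise ValueError("first_stage_features_stride must be 8 or 16")
        self.is_training = is_training
        self.first_stage_features_stride = first_stage_features_stride


class NewCNNExtractor(ResNet152Extractor):
    """Placeholder for a user-written extractor (hypothetical)."""


class MetaArch:
    """Placeholder for FasterRCNNMetaArch: it just stores what it is given."""

    def __init__(self, num_classes, feature_extractor):
        self.num_classes = num_classes
        self.feature_extractor = feature_extractor


# Hypothetical registry: config type string -> extractor class.
FEATURE_EXTRACTOR_REGISTRY = {
    "faster_rcnn_resnet152": ResNet152Extractor,
    "faster_rcnn_new_cnn": NewCNNExtractor,  # registering a new CNN is one entry
}


def build_model(model_config, is_training):
    """Hypothetical builder: dispatch on the feature extractor type string."""
    fe_config = model_config["feature_extractor"]
    fe_type = fe_config["type"]
    if fe_type not in FEATURE_EXTRACTOR_REGISTRY:
        raise ValueError("Unknown feature extractor: %s" % fe_type)
    extractor = FEATURE_EXTRACTOR_REGISTRY[fe_type](
        is_training=is_training,
        first_stage_features_stride=fe_config["first_stage_features_stride"])
    return MetaArch(model_config["num_classes"], extractor)


config = {"num_classes": 90,
          "feature_extractor": {"type": "faster_rcnn_new_cnn",
                                "first_stage_features_stride": 16}}
detection_model = build_model(config, is_training=True)
```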