TensorFlow Object Detection API source-reading notes: architecture

Source: Internet  Editor: 程序博客网  Time: 2024/05/07 10:32

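An earlier post showed how to run inference with the pre-trained models that TF provides, which is very simple. Here we dig into the source code to understand the overall code architecture of the detection API; a deeper reading of each part is left for later posts.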

The official documentation is fairly thorough, so it is worth reading all of it first. The document most relevant to the core model is:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/defining_your_own_model.md
One somewhat cumbersome aspect is that parameter configuration is managed with protobuf files; see:
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/configuring_jobs.md

The interface for building your own model is the abstract base class DetectionModel, which has 5 abstract methods to implement: preprocess, predict, postprocess, loss and restore_map.

object_detection/core/model.py (excerpt; the abc metaclass setup and most bodies are omitted)

```python
class DetectionModel(object):

  def groundtruth_lists(self, field):
    """Access list of groundtruth tensors."""

  def groundtruth_has_field(self, field):
    """Determines whether the groundtruth includes the given field."""

  def provide_groundtruth(self,
                          groundtruth_boxes_list,
                          groundtruth_classes_list,
                          groundtruth_masks_list=None,
                          groundtruth_keypoints_list=None):
    """Provide groundtruth tensors."""

  @abstractmethod
  def preprocess(self, inputs):
    ...

  @abstractmethod
  def predict(self, preprocessed_inputs):
    ...

  @abstractmethod
  def postprocess(self, prediction_dict, **params):
    ...

  @abstractmethod
  def loss(self, prediction_dict):
    ...

  @abstractmethod
  def restore_map(self, from_detection_checkpoint=True):
    ...
```
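To make this contract concrete, here is a minimal pure-Python sketch that does not depend on TensorFlow. MiniDetectionModel and ConstantBoxModel are hypothetical names: the base class only mirrors the five abstract method names listed above, plain lists stand in for tensors, and the real methods' signatures and return values differ in detail.

```python
from abc import ABC, abstractmethod


class MiniDetectionModel(ABC):
    """Hypothetical stand-in for DetectionModel with the same five abstract methods."""

    def __init__(self):
        self._groundtruth = {}

    def provide_groundtruth(self, boxes_list, classes_list):
        # Concrete helper: the base class only stores the groundtruth lists.
        self._groundtruth = {"boxes": boxes_list, "classes": classes_list}

    def groundtruth_lists(self, field):
        return self._groundtruth[field]

    @abstractmethod
    def preprocess(self, inputs): ...

    @abstractmethod
    def predict(self, preprocessed_inputs): ...

    @abstractmethod
    def postprocess(self, prediction_dict, **params): ...

    @abstractmethod
    def loss(self, prediction_dict): ...

    @abstractmethod
    def restore_map(self, from_detection_checkpoint=True): ...


class ConstantBoxModel(MiniDetectionModel):
    """Toy model: one fixed box [ymin, xmin, ymax, xmax] per image, scored by brightness."""

    def preprocess(self, inputs):
        # Scale raw pixel values from [0, 255] to [0, 1].
        return [[p / 255.0 for p in image] for image in inputs]

    def predict(self, preprocessed_inputs):
        boxes = [[0.1, 0.1, 0.9, 0.9] for _ in preprocessed_inputs]
        scores = [sum(image) / len(image) for image in preprocessed_inputs]
        return {"boxes": boxes, "scores": scores}

    def postprocess(self, prediction_dict, score_threshold=0.5):
        # Keep only boxes whose score clears the threshold.
        kept = [box for box, score in zip(prediction_dict["boxes"], prediction_dict["scores"])
                if score >= score_threshold]
        return {"detection_boxes": kept}

    def loss(self, prediction_dict):
        # Sum of absolute coordinate differences against the stored groundtruth.
        gt_boxes = self.groundtruth_lists("boxes")
        total = sum(abs(p - g)
                    for pred_box, gt_box in zip(prediction_dict["boxes"], gt_boxes)
                    for p, g in zip(pred_box, gt_box))
        return {"localization_loss": total}

    def restore_map(self, from_detection_checkpoint=True):
        return {}  # nothing to restore in this toy


model = ConstantBoxModel()
preds = model.predict(model.preprocess([[255, 255], [0, 0]]))  # two 2-pixel "images"
print(model.postprocess(preds))  # {'detection_boxes': [[0.1, 0.1, 0.9, 0.9]]}
model.provide_groundtruth([[0.0, 0.0, 1.0, 1.0], [0.1, 0.1, 0.9, 0.9]], [[1], [1]])
print(model.loss(preds))
```

What matters here is the call order: preprocess, then predict, then either postprocess (inference) or loss (training, after provide_groundtruth has been called).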
object_detection/meta_architectures/faster_rcnn_meta_arch.py (excerpt)

```python
class FasterRCNNFeatureExtractor(object):
  """Faster R-CNN Feature Extractor definition."""

  def __init__(self,
               is_training,
               first_stage_features_stride,
               batch_norm_trainable=False,
               reuse_weights=None,
               weight_decay=0.0):
    ...

  @abstractmethod
  def preprocess(self, resized_inputs):
    """Feature-extractor specific preprocessing (minus image resizing)."""

  def extract_proposal_features(self, preprocessed_inputs, scope):
    """Extracts first stage RPN features."""

  @abstractmethod
  def _extract_proposal_features(self, preprocessed_inputs, scope):
    ...

  def extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features."""

  @abstractmethod
  def _extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features, to be overridden."""

  def restore_from_classification_checkpoint_fn(
      self,
      first_stage_feature_extractor_scope,
      second_stage_feature_extractor_scope):
    """Returns a map of variables to load from a foreign checkpoint."""


class FasterRCNNMetaArch(model.DetectionModel):
  """Faster R-CNN Meta-architecture definition."""
  # Reading note: for now, focus on where feature_extractor (a
  # FasterRCNNFeatureExtractor object) is used. Swapping in a different CNN is
  # fairly simple: you only need to write a new feature extractor, e.g. a
  # faster_rcnn_new_cnn_feature_extractor. The detection model that is finally
  # built is an instance of this class.

  def preprocess(self, inputs):
    """For Faster R-CNN, we perform image resizing in the base class --- each
    class subclassing FasterRCNNMetaArch is responsible for any additional
    preprocessing (e.g., scaling pixel values to be in [-1, 1]).
    """
    # Reading note: see the preprocess function implemented in the ResNet
    # feature extractor in the next code block.
```
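The split between public methods (extract_proposal_features) and underscored hooks (_extract_proposal_features) is what makes swapping the CNN cheap: the meta-architecture only calls the public methods, and a new backbone only fills in the hooks. Below is a minimal sketch of that template-method pattern, assuming toy lists instead of tensors; every class name is hypothetical, and the scope strings are merely recorded instead of opening TF variable scopes.

```python
from abc import ABC, abstractmethod


class MiniFeatureExtractor(ABC):
    """Hypothetical mirror of FasterRCNNFeatureExtractor's public/hook split."""

    def __init__(self):
        self.scopes_seen = []

    def extract_proposal_features(self, inputs, scope):
        # Public method: shared bookkeeping lives here (the real class wraps the
        # call in a TF variable scope); subclasses only implement the hook.
        self.scopes_seen.append(scope)
        return self._extract_proposal_features(inputs, scope)

    def extract_box_classifier_features(self, proposal_maps, scope):
        self.scopes_seen.append(scope)
        return self._extract_box_classifier_features(proposal_maps, scope)

    @abstractmethod
    def _extract_proposal_features(self, inputs, scope): ...

    @abstractmethod
    def _extract_box_classifier_features(self, proposal_maps, scope): ...


class MiniMetaArch:
    """Hypothetical meta-architecture: it only talks to the public extractor API."""

    def __init__(self, feature_extractor):
        self._feature_extractor = feature_extractor

    def predict(self, inputs):
        maps = self._feature_extractor.extract_proposal_features(inputs, "first_stage")
        return self._feature_extractor.extract_box_classifier_features(maps, "second_stage")


class ScaleBackbone(MiniFeatureExtractor):
    """Toy 'CNN': stage 1 scales every value, stage 2 sums them."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def _extract_proposal_features(self, inputs, scope):
        return [self.factor * x for x in inputs]

    def _extract_box_classifier_features(self, proposal_maps, scope):
        return sum(proposal_maps)


# Swapping the backbone changes only the extractor object, not MiniMetaArch.
print(MiniMetaArch(ScaleBackbone(2)).predict([1, 2, 3]))   # 12
print(MiniMetaArch(ScaleBackbone(10)).predict([1, 2, 3]))  # 60
```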
object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py (excerpt)

```python
# Reading note: this part is tightly coupled with slim, so let's look at it carefully.
class FasterRCNNResnetV1FeatureExtractor(
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
  """Faster R-CNN Resnet V1 feature extractor implementation."""

  def __init__(self,
               architecture,
               resnet_model,
               is_training,
               first_stage_features_stride,
               batch_norm_trainable=False,
               reuse_weights=None,
               weight_decay=0.0):
    ...

  def preprocess(self, resized_inputs):
    """Faster R-CNN Resnet V1 preprocessing."""
    channel_means = [123.68, 116.779, 103.939]
    return resized_inputs - [[channel_means]]

  def _extract_proposal_features(self, preprocessed_inputs, scope):
    """Extracts first stage RPN features.

    Reading note: uses the network's end_points to output the ResNet block3
    activations.
    """

  def _extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features.

    Reading note: splits off ResNet's block4. Note how variable_scope and
    arg_scope are used.
    """


class FasterRCNNResnet152FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
  """Faster R-CNN Resnet 152 feature extractor implementation."""

  def __init__(self,
               is_training,
               first_stage_features_stride,
               batch_norm_trainable=False,
               reuse_weights=None,
               weight_decay=0.0):
    """Constructor.

    Args:
      is_training: See base class.
      first_stage_features_stride: See base class.
      batch_norm_trainable: See base class.
      reuse_weights: See base class.
      weight_decay: See base class.

    Raises:
      ValueError: If `first_stage_features_stride` is not 8 or 16,
        or if `architecture` is not supported.
    """
    super(FasterRCNNResnet152FeatureExtractor, self).__init__(
        'resnet_v1_152', resnet_v1.resnet_v1_152, is_training,
        first_stage_features_stride, batch_norm_trainable,
        reuse_weights, weight_decay)
    # Reading note: tracing back through each class's __init__, the arguments
    # 'resnet_v1_152' and resnet_v1.resnet_v1_152 are used only in
    # FasterRCNNResnetV1FeatureExtractor above.
```
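A rough pure-Python sketch of the two ideas above: the V1 base class performs the per-channel mean subtraction, and the 152 subclass only forwards a fixed architecture name and network function to the parent constructor. All names are hypothetical, fake_resnet_v1_152 is a placeholder rather than slim's network, and nested lists stand in for a [batch, height, width, 3] tensor.

```python
CHANNEL_MEANS = [123.68, 116.779, 103.939]  # RGB means from the preprocess() excerpt


class MiniResnetV1Extractor:
    """Hypothetical stand-in for FasterRCNNResnetV1FeatureExtractor."""

    def __init__(self, architecture, resnet_model, is_training,
                 first_stage_features_stride):
        if first_stage_features_stride not in (8, 16):
            raise ValueError("first_stage_features_stride must be 8 or 16")
        self._architecture = architecture  # e.g. 'resnet_v1_152'
        self._resnet_model = resnet_model  # the network function to call later
        self._is_training = is_training
        self._first_stage_features_stride = first_stage_features_stride

    def preprocess(self, resized_inputs):
        # resized_inputs: [batch][height][width][3] nested lists (RGB).
        # Same arithmetic as the TF broadcast `resized_inputs - [[channel_means]]`.
        return [[[[p - m for p, m in zip(pixel, CHANNEL_MEANS)]
                  for pixel in row]
                 for row in image]
                for image in resized_inputs]


def fake_resnet_v1_152(inputs):
    """Hypothetical placeholder for resnet_v1.resnet_v1_152 (identity here)."""
    return inputs


class MiniResnet152Extractor(MiniResnetV1Extractor):
    """Like FasterRCNNResnet152FeatureExtractor: it only fixes the first two args."""

    def __init__(self, is_training, first_stage_features_stride):
        super().__init__('resnet_v1_152', fake_resnet_v1_152, is_training,
                         first_stage_features_stride)


extractor = MiniResnet152Extractor(is_training=False, first_stage_features_stride=16)
image = [[[123.68, 116.779, 103.939], [255.0, 255.0, 255.0]]]  # 1x2 RGB image
print(extractor.preprocess([image])[0][0][0])  # [0.0, 0.0, 0.0]
```

Keeping the name and network function in the parent constructor means each new ResNet depth costs only a few lines of forwarding code.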

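I also recommend running the test scripts. You will run into the following files; read them one by one in the order they appear in the tests, together with their corresponding test scripts.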

```python
# Builder function to construct tf-slim arg_scope for convolution, fc ops.
# Reading note: the test for this script makes it easy to see how hyperparameter
# configs are read (somewhat like a dict in OpenFOAM); the message type is
# object_detection.protos.hyperparams_pb2.Hyperparams.
from object_detection.builders import hyperparams_builder

# Contains routines for printing protocol messages in text format.
# Reading note: in the same test script, it is currently used mainly as
#   conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
#   text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
# where conv_hyperparams_text_proto is a string holding the configuration and
# conv_hyperparams_proto is a Hyperparams proto message, which is the first
# argument of hyperparams_builder.build.
from google.protobuf import text_format

# Function to build box predictor from configuration.
# Box predictors are classes that take a high level image feature map as input
# and produce two predictions:
#   (1) a tensor encoding box locations, and
#   (2) a tensor encoding classes for each box.
# Reading note: object_detection/core/box_predictor.py is left for later study.
# Note that conv_hyperparams_text_proto is embedded in box_predictor_text_proto
# and then passed along with it to class ConvolutionalBoxPredictor(BoxPredictor).
from object_detection.builders import box_predictor_builder

# Generates grid anchors on the fly as used in Faster RCNN.
# Reading note: to be read in detail next time.
from object_detection.anchor_generators import grid_anchor_generator

# Builder function for post processing operations.
from object_detection.builders import post_processing_builder

# Classification and regression loss functions for object detection.
from object_detection.core import losses

# Proto files. Reading note: next time, study how these files are written and
# read, together with the corresponding core and builder modules.
from object_detection.protos import box_predictor_pb2
from object_detection.protos import hyperparams_pb2
from object_detection.protos import post_processing_pb2
```
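To illustrate the "text config -> message object -> builder" flow described in the comments above without protobuf installed, here is a toy sketch. parse_text_config is a hypothetical, heavily simplified parser for a protobuf-text-format-like string (it is not text_format.Merge, and it ignores repeated fields, comments and escapes), and build_conv_hyperparams is a hypothetical builder whose output format and default values are assumptions.

```python
import re


def parse_text_config(text):
    """Toy parser for `key: value` scalars and `key { ... }` / `key: { ... }` blocks."""
    tokens = re.findall(r'"[^"]*"|[{}:]|[^\s{}:"]+', text)

    def scalar(tok):
        if tok.startswith('"'):
            return tok[1:-1]
        if tok in ("true", "false"):
            return tok == "true"
        for cast in (int, float):
            try:
                return cast(tok)
            except ValueError:
                pass
        return tok  # enum names such as RELU_6 stay strings

    def block(i):
        # Parse fields until a closing brace (or end of input); return (dict, index).
        result = {}
        while i < len(tokens) and tokens[i] != "}":
            key = tokens[i]
            i += 1
            if tokens[i] == ":":
                i += 1
            if tokens[i] == "{":
                result[key], i = block(i + 1)
                i += 1  # skip the closing "}"
            else:
                result[key] = scalar(tokens[i])
                i += 1
        return result, i

    return block(0)[0]


def build_conv_hyperparams(config):
    """Hypothetical builder: turns the parsed config into layer keyword arguments."""
    kwargs = {"activation": config.get("activation", "NONE")}
    regularizer = config.get("regularizer", {})
    for kind in ("l1", "l2"):
        if kind + "_regularizer" in regularizer:
            weight = regularizer[kind + "_regularizer"].get("weight", 1.0)
            kwargs["weights_regularizer"] = (kind, weight)
    initializer = config.get("initializer", {})
    if "truncated_normal_initializer" in initializer:
        stddev = initializer["truncated_normal_initializer"].get("stddev", 1.0)
        kwargs["weights_initializer"] = ("truncated_normal", stddev)
    return kwargs


CONV_HYPERPARAMS_TEXT = """
regularizer {
  l2_regularizer {
    weight: 0.0004
  }
}
initializer {
  truncated_normal_initializer {
    stddev: 0.03
  }
}
activation: RELU_6
"""

print(build_conv_hyperparams(parse_text_config(CONV_HYPERPARAMS_TEXT)))
```

The real code gets type checking and defaults from the .proto schema instead; this sketch only shows why a nested text config maps naturally onto a nested message (here, a dict) that a builder can walk.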
object_detection/builders/model_builder.py: "A function to build a DetectionModel from configuration." Much of what it does is already exercised in faster_rcnn_meta_arch_test_lib.py.
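A hedged sketch of the dispatch such a model builder has to perform: choose the meta-architecture from the config, then look up the feature-extractor class by its type string and construct it. The registry dict, the class names and the default stride of 16 are assumptions for illustration, and a plain dict stands in for the parsed pipeline config.

```python
class ToyExtractor:
    """Hypothetical feature extractor base; subclasses only set `architecture`."""
    architecture = None

    def __init__(self, is_training, first_stage_features_stride):
        self.is_training = is_training
        self.first_stage_features_stride = first_stage_features_stride


class ToyResnet101Extractor(ToyExtractor):
    architecture = "resnet_v1_101"


class ToyResnet152Extractor(ToyExtractor):
    architecture = "resnet_v1_152"


class ToyFasterRCNN:
    """Hypothetical meta-architecture object that holds the chosen extractor."""

    def __init__(self, num_classes, feature_extractor):
        self.num_classes = num_classes
        self.feature_extractor = feature_extractor


# Hypothetical registry: extractor type string from the config -> class.
FEATURE_EXTRACTOR_CLASSES = {
    "faster_rcnn_resnet101": ToyResnet101Extractor,
    "faster_rcnn_resnet152": ToyResnet152Extractor,
}


def build_model(model_config, is_training):
    """Dispatch on the meta-architecture key, then on the extractor type."""
    if "faster_rcnn" not in model_config:
        raise ValueError("Unknown meta architecture: %s" % list(model_config))
    params = model_config["faster_rcnn"]
    fe_config = params["feature_extractor"]
    extractor_cls = FEATURE_EXTRACTOR_CLASSES.get(fe_config["type"])
    if extractor_cls is None:
        raise ValueError("Unknown feature extractor: %s" % fe_config["type"])
    extractor = extractor_cls(is_training,
                              fe_config.get("first_stage_features_stride", 16))
    return ToyFasterRCNN(params["num_classes"], extractor)


config = {"faster_rcnn": {"num_classes": 90,
                          "feature_extractor": {"type": "faster_rcnn_resnet152",
                                                "first_stage_features_stride": 16}}}
model = build_model(config, is_training=False)
print(model.feature_extractor.architecture)  # resnet_v1_152
```

With this shape, adding a new backbone means writing one extractor class and registering it under a new type string; the builder and meta-architecture stay unchanged.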