TensorFlow object detection api------ssd_mobilenet使用

来源:互联网 发布:json和javascript 编辑:程序博客网 时间:2024/06/07 02:06

谷歌发布object detection api(https://github.com/tensorflow/models)已经有一段时间了,这个api的发布,让我们不用自己再去为faster-rcnn、ssd、rfcn单独造轮子了。

现记录一下以ssd_mobilenet这个模型的object detection api使用。

一、安装TensorFlow、以及用到的库

    按照官方指导操作即可:https://github.com/tensorflow/models/blob/master/object_detection/g3doc/installation.md。

二、下载models、ssd_mobilenet_v1_coco

    下载地址:https://github.com/tensorflow/models、https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md

    1、解压后,进入models目录,执行:

    protoc object_detection/protos/*.proto --python_out=.

    export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

    2、解压ssd_mobilenet_v1_coco_11_06_2017.tar.gz到,得到:frozen_inference_graph.pb  graph.pbtxt  model.ckpt.data-00000-of-00001  model.ckpt.index  model.ckpt.meta,

二、准备数据集

    1、使用将Pascal voc格式的数据集转换成record格式,即生成train.record、eval.record,参考models/object_detection/create_pascal_tf_record.py

    2、创建工程目录:

    +SSD_MOBILENET

        +data
            -label_map file(根据models/object_detection/ models/object_detection/data/pascal_label_map.pbtxt来修改)
            -train TFRecord file(即生成的train.record)
            -eval TFRecord file(即生成的eval.record)

            -model.ckpt(将解压ssd_mobilenet_v1_coco_11_06_2017.tar.gz得到的model.ckpt.data-00000-of-00001重命名为mode.ckpt)
            -model.ckpt.index(解压ssd_mobilenet_v1_coco_11_06_2017.tar.gz得到)
            -model.ckpt.meta(解压ssd_mobilenet_v1_coco_11_06_2017.tar.gz得到)

        +models
            + model
                -pipeline config file(根据models/object_detection/samples/configs/ssd_mobilenet_v1_pets.config来修改,具体到下面说明)
                +train
                +eval

    3、进入models目录执行:

     python object_detection/train.py \
     --logtostderr \
     --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG} \
     --train_dir=${PATH_TO_TRAIN_DIR}

    开始训练。

    其中:PATH_TO_YOUR_PIPELINE_CONFIG为SSD_MOBILENET/models/model/pipeline config file,PATH_TO_TRAIN_DIR为SSD_MOBILENET/models/model/train

    关于pipeline config file内容修改如下:

    第9行:num_classes:,你数据集有几类就改成几类

    第158行:fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt",即SSD_MOBILENET/data/model.ckpt

    第177行:input_path: "PATH_TO_BE_CONFIGURED/pet_train.record",即SSD_MOBILENET/data/train.record

    第179行:label_map_path: "PATH_TO_BE_CONFIGURED/pet_label_map.pbtxt",即SSD_MOBILENET/data/label_map file

    第191行:input_path: "PATH_TO_BE_CONFIGURED/pet_val.record",即SSD_MOBILENET/data/eval.record

    第193行:label_map_path: "PATH_TO_BE_CONFIGURED/pet_label_map.pbtxt",即SSD_MOBILENET/data/label_map file

    第183行:num_examples: ,eval有多少张就改成多少张

原创粉丝点击