Image Semantic Segmentation Code Implementation (2)


For the SegNet algorithm introduced in "Image Semantic Segmentation (2) - SegNet", this post mainly follows the official project page to train and test on the CamVid dataset.


Framework Installation

The official SegNet consists of two parts: the SegNet framework itself and a modified Caffe.

1) Clone SegNet-Tutorial

git clone https://github.com/alexgkendall/SegNet-Tutorial.git SegNet

It already contains the CamVid training data.

2) Clone caffe-segnet

git clone https://github.com/TimoSaemann/caffe-segnet-cudnn5.git caffe-segnet

This fork of caffe-segnet works with cuDNN 5, for systems that no longer support the cuDNN v2 required by the original caffe-segnet.

3) Move caffe-segnet into the SegNet directory; the resulting file tree is shown in Figure 1.

[Figure 1. Directory layout of the SegNet code]
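Since the figure is not reproduced here, the layout after step 3 is roughly the following (a sketch based on the directory names used later in this post; note that train_it.sh and webcam_my.py below refer to the Caffe directory as caffe-segnet-cudnn5, so keep the names consistent):

SegNet/
├── CamVid/              # training/test images, annotations, train.txt / test.txt
├── Example_Models/      # webcam demo prototxt and weights
├── Models/              # training / inference prototxt and solver
├── Scripts/             # helper scripts (webcam_demo.py, compute_bn_statistics.py, ...)
└── caffe-segnet/        # the modified Caffe cloned in step 2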


Data Preparation and Training

1) Data preparation before training

  • Edit CamVid/train.txt and CamVid/test.txt so that the image paths are absolute (a path-rewriting sketch follows this list)
  • Edit the data paths in the first layer of Models/segnet_train.prototxt and Models/segnet_inference.prototxt to absolute paths, and adjust batch_size to fit your GPU memory
  • Edit Models/segnet_solver.prototxt and set snapshot_prefix to an absolute path
  • Write the training script train_it.sh
#!/bin/bash
./caffe-segnet-cudnn5/build/tools/caffe train -gpu 0 -solver Models/segnet_solver.prototxt
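The list files shipped with the repo reference images with a fixed /SegNet/... prefix; below is a minimal sketch for rewriting them (a hypothetical helper script, not part of the repo, assuming each line holds an image path and a label path separated by a space):

# make_paths_absolute.py - hypothetical helper: re-root the image/label paths
# in the CamVid list files under your own checkout.
import os

SEGNET_ROOT = '/home/xxx/SegNet'                 # change to your SegNet directory

for list_file in ('CamVid/train.txt', 'CamVid/test.txt'):
    list_path = os.path.join(SEGNET_ROOT, list_file)
    with open(list_path) as f:
        pairs = [line.split() for line in f if line.strip()]
    with open(list_path, 'w') as f:
        for pair in pairs:
            # each line: "<image path> <label path>"; the shipped files use a
            # "/SegNet/..." prefix, which is replaced by the real location here
            f.write(' '.join(p.replace('/SegNet/', SEGNET_ROOT + '/', 1) for p in pair) + '\n')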

2) Build caffe-segnet and start training

  • Edit Makefile.config: uncomment USE_CUDNN := 1 and WITH_PYTHON_LAYER := 1, and add the HDF5 include/library paths
# Whatever else you find you need goes here.
INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include /usr/include/hdf5/serial
LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib /usr/lib/x86_64-linux-gnu /usr/lib/x86_64-linux-gnu/hdf5/serial
  • make -j8 && make pycaffe
  • Run bash train_it.sh to start training

3) Test webcam_demo (optional)

The official repo provides a webcam demo; plug in a USB camera and it can be run directly:

python Scripts/webcam_demo.py --model Example_Models/segnet_model_driving_webdemo.prototxt --weights Example_Models/segnet_weights_driving_webdemo.caffemodel --colours Scripts/camvid12.png

Here the colours argument is a 1 x 256 x 3 image, i.e. 256 three-channel colour entries; column i holds the colour for class index i (CamVid only uses the first 12).
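To see how such a lookup image is used, here is a minimal sketch (array shapes assumed from the description above; the prediction is a placeholder):

import cv2
import numpy as np

# camvid12.png: a 1 x 256 x 3 lookup table; column i holds the colour for
# class index i (only the first 12 entries are meaningful for CamVid).
label_colours = cv2.imread('Scripts/camvid12.png').astype(np.uint8)

# ind: H x W map of per-pixel class indices, e.g. argmax of the network output
ind = np.zeros((360, 480), dtype=np.uint8)       # placeholder prediction

ind_3ch = np.dstack([ind, ind, ind])             # replicate to 3 channels for cv2.LUT
rgb = cv2.LUT(ind_3ch, label_colours)            # H x W x 3 colour-coded segmentation
cv2.imwrite('segmentation_colour.png', rgb)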

4) After training, the batch normalization (BN) statistics must be computed for the inference model

  • First, edit the absolute path to caffe-segnet in Scripts/compute_bn_statistics.py and Scripts/test_segmentation_camvid.py
  • The script modifies the caffemodel directly (it runs the training set through the network to collect the BN statistics) and writes the result to Models/Inference/test_weights.caffemodel
python Scripts/compute_bn_statistics.py Models/segnet_train.prototxt CamVid/model/segnet_iter_40000.caffemodel Models/Inference/

Testing

Scripts/test_segmentation_camvid.py is a batch test script; it uses Models/segnet_inference.prototxt to evaluate the images listed in CamVid/test.txt.
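A typical invocation looks like the line below (the --model/--weights/--iter flags are those exposed by the script's argument parser; --iter is the number of test images, 233 for the CamVid test split):

python Scripts/test_segmentation_camvid.py --model Models/segnet_inference.prototxt --weights Models/Inference/test_weights.caffemodel --iter 233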

So how do we test a single image or a video? Referring to webcam_demo and test_segmentation_camvid.py, make the following changes (a single-image variant is sketched after the list).

  • Create Models/segnet_deploy.prototxt
    Copy the contents of Models/segnet_inference.prototxt and replace the first layer with
input: "data"input_dim: 1input_dim: 3input_dim: 360input_dim: 480
  • Create Scripts/webcam_my.py
import numpy as np
import matplotlib.pyplot as plt
import os.path
import json
import scipy
import argparse
import math
import pylab
from sklearn.preprocessing import normalize
caffe_root = '/home/xxx/SegNet/caffe-segnet-cudnn5/'            # Change this to the absolute directory to SegNet Caffe
import sys
sys.path.insert(0, caffe_root + 'python')
import caffe
import cv2
import time

# Import arguments
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True)
parser.add_argument('--weights', type=str, required=True)
args = parser.parse_args()

caffe.set_mode_gpu()

net = caffe.Net(args.model,
                args.weights,
                caffe.TEST)

input_shape = net.blobs['data'].data.shape
output_shape = net.blobs['prob'].data.shape

cv2.namedWindow("Input")
cv2.namedWindow("SegNet")

cap = cv2.VideoCapture(0) # Change this to your webcam ID, or file name for your video file

Sky = [128,128,128]
Building = [128,0,0]
Pole = [192,192,128]
Road_marking = [255,69,0]
Road = [128,64,128]
Pavement = [60,40,222]
Tree = [128,128,0]
SignSymbol = [192,128,128]
Fence = [64,64,128]
Car = [64,0,128]
Pedestrian = [64,64,0]
Bicyclist = [0,128,192]
Unlabelled = [0,0,0]

label_colours = np.array([Sky, Building, Pole, Road, Pavement, Tree, SignSymbol, Fence, Car, Pedestrian, Bicyclist, Unlabelled])

rval = True
while rval:
    start = time.time()
    rval, frame = cap.read()
    if rval == False:
        break
    end = time.time()
    print '%30s' % 'Grabbed camera frame in ', str((end - start)*1000), 'ms'

    start = time.time()
    frame = cv2.resize(frame, (input_shape[3],input_shape[2]))
    input_image = frame.transpose((2,0,1))
    input_image = np.asarray([input_image])
    end = time.time()
    print '%30s' % 'Resized image in ', str((end - start)*1000), 'ms'

    start = time.time()
    out = net.forward_all(data=input_image)
    end = time.time()
    print '%30s' % 'Executed SegNet in ', str((end - start)*1000), 'ms'

    start = time.time()
    image = net.blobs['data'].data
    predicted = net.blobs['prob'].data
    image = np.squeeze(image[0,:,:,:])
    output = np.squeeze(predicted[0,:,:,:])
    ind = np.argmax(output, axis=0)

    r = ind.copy()
    g = ind.copy()
    b = ind.copy()
    for l in range(0,11):
        r[ind==l] = label_colours[l,0]
        g[ind==l] = label_colours[l,1]
        b[ind==l] = label_colours[l,2]

    rgb = np.zeros((ind.shape[0], ind.shape[1], 3))
    rgb[:,:,0] = r/255.0
    rgb[:,:,1] = g/255.0
    rgb[:,:,2] = b/255.0

    image = image/255.0
    image = np.transpose(image, (1,2,0))
    output = np.transpose(output, (1,2,0))
    image = image[:,:,(2,1,0)]

    end = time.time()
    print '%30s' % 'Processed results in ', str((end - start)*1000), 'ms\n'

    cv2.imshow("Input", image)
    cv2.imshow("SegNet", rgb)

    key = cv2.waitKey(1)
    if key == 27: # exit on ESC
        break

cap.release()
cv2.destroyAllWindows()
  • Create the test script test_it.sh
#!/bin/bash
python Scripts/webcam_my.py --model Models/segnet_deploy.prototxt --weights Models/Inference/test_weights.caffemodel
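For a single image rather than a webcam or video stream, a stripped-down variant of webcam_my.py could look like the following minimal sketch (the file names single_image_my.py and prediction.png are placeholders):

# single_image_my.py - minimal single-image variant of Scripts/webcam_my.py
import sys
import numpy as np
import cv2

caffe_root = '/home/xxx/SegNet/caffe-segnet-cudnn5/'    # change to your caffe-segnet path
sys.path.insert(0, caffe_root + 'python')
import caffe

caffe.set_mode_gpu()
net = caffe.Net('Models/segnet_deploy.prototxt',
                'Models/Inference/test_weights.caffemodel',
                caffe.TEST)

input_shape = net.blobs['data'].data.shape              # (1, 3, 360, 480)

frame = cv2.imread(sys.argv[1])                         # path of the image to segment
frame = cv2.resize(frame, (input_shape[3], input_shape[2]))
input_image = np.asarray([frame.transpose((2, 0, 1))])  # HWC -> NCHW

net.forward_all(data=input_image)
output = np.squeeze(net.blobs['prob'].data[0, :, :, :]) # 12 x 360 x 480 class probabilities
ind = np.argmax(output, axis=0).astype(np.uint8)        # per-pixel class index

cv2.imwrite('prediction.png', ind)                      # raw label map; colour it as in webcam_my.py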

Judging from the test runs (on a GTX 1070), the frame rate is acceptable.

