YOLO实现之源码分析1

来源:互联网 发布:dh算法加密 编辑:程序博客网 时间:2024/06/07 13:18

前面介绍了论文的思想与模型,以及使用darknet实践,接下来让我们深入到源码。

main.c

首先追踪主函数。主函数开始解析命令行参数,然后根据不同的命令行参数进入不同的调用方法。接下来我们以YOLO为主线,进行追踪与分析。

int main(int argc, char **argv){    //test_resize("data/bad.jpg");    //test_box();    //test_convolutional_layer();    if(argc < 2){        fprintf(stderr, "usage: %s <function>\n", argv[0]);        return 0;    }    gpu_index = find_int_arg(argc, argv, "-i", 0);    if(find_arg(argc, argv, "-nogpu")) {        gpu_index = -1;    }#ifndef GPU    gpu_index = -1;#else    if(gpu_index >= 0){        cudaError_t status = cudaSetDevice(gpu_index);        check_error(status);    }#endif    if(0==strcmp(argv[1], "imagenet")){        run_imagenet(argc, argv);    } else if (0 == strcmp(argv[1], "average")){        average(argc, argv);    } else if (0 == strcmp(argv[1], "yolo")){    // 检测到命令行第一个参数输入的是yolo,则进入run_yolo函数        run_yolo(argc, argv);    } else if (0 == strcmp(argv[1], "cifar")){        run_cifar(argc, argv);    } else if (0 == strcmp(argv[1], "go")){        run_go(argc, argv);    } else if (0 == strcmp(argv[1], "rnn")){        run_char_rnn(argc, argv);    } else if (0 == strcmp(argv[1], "vid")){        run_vid_rnn(argc, argv);    } else if (0 == strcmp(argv[1], "coco")){        run_coco(argc, argv);    } else if (0 == strcmp(argv[1], "classifier")){        run_classifier(argc, argv);    } else if (0 == strcmp(argv[1], "art")){        run_art(argc, argv);    } else if (0 == strcmp(argv[1], "tag")){        run_tag(argc, argv);    } else if (0 == strcmp(argv[1], "compare")){        run_compare(argc, argv);    } else if (0 == strcmp(argv[1], "dice")){        run_dice(argc, argv);    } else if (0 == strcmp(argv[1], "writing")){        run_writing(argc, argv);    } else if (0 == strcmp(argv[1], "3d")){        composite_3d(argv[2], argv[3], argv[4]);    } else if (0 == strcmp(argv[1], "test")){        test_resize(argv[2]);    } else if (0 == strcmp(argv[1], "captcha")){        run_captcha(argc, argv);    } else if (0 == strcmp(argv[1], "nightmare")){        run_nightmare(argc, argv);    } else if (0 == strcmp(argv[1], "change")){        change_rate(argv[2], 
atof(argv[3]), (argc > 4) ? atof(argv[4]) : 0);    } else if (0 == strcmp(argv[1], "rgbgr")){        rgbgr_net(argv[2], argv[3], argv[4]);    } else if (0 == strcmp(argv[1], "denormalize")){        denormalize_net(argv[2], argv[3], argv[4]);    } else if (0 == strcmp(argv[1], "normalize")){        normalize_net(argv[2], argv[3], argv[4]);    } else if (0 == strcmp(argv[1], "rescale")){        rescale_net(argv[2], argv[3], argv[4]);    } else if (0 == strcmp(argv[1], "partial")){        partial(argv[2], argv[3], argv[4], atoi(argv[5]));    } else if (0 == strcmp(argv[1], "stacked")){        stacked(argv[2], argv[3], argv[4]);    } else if (0 == strcmp(argv[1], "visualize")){        visualize(argv[2], (argc > 3) ? argv[3] : 0);    } else if (0 == strcmp(argv[1], "imtest")){        test_resize(argv[2]);    } else {        fprintf(stderr, "Not an option: %s\n", argv[1]);    }    return 0;}

yolo.c

接下来分析yolo.c文件,该文件主要是使用darknet 实现YOLO论文的思想。

run_yolo(int argc, char **argv)

该函数是yolo.c文件的入口函数：先加载各类别的标签图片，再解析命令行参数（阈值、摄像头编号、配置文件、权重文件、测试文件），最后根据子命令（test/train/valid/recall/demo）调用相应的函数。

char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};void run_yolo(int argc, char **argv){    int i;    // 1. 提取类别标签    for(i = 0; i < 20; ++i){        char buff[256];        sprintf(buff, "data/labels/%s.png", voc_names[i]);        voc_labels[i] = load_image_color(buff, 0, 0);    }    // 2.读取得分阈值,默认是0.2    float thresh = find_float_arg(argc, argv, "-thresh", .2);    // 3. 读取是否指定摄像头,默认是0    int cam_index = find_int_arg(argc, argv, "-c", 0);    if(argc < 4){        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);        return;    }    char *cfg = argv[3];  // 读取配置文件    // 读取权重文件,默认是0    char *weights = (argc > 4) ? argv[4] : 0;      // 读取测试文件,默认是0    char *filename = (argc > 5) ? argv[5]: 0;    // 根据不同的需求(test,train,validate,demo),进入不同的函数    if(0==strcmp(argv[2], "test")) test_yolo(cfg, weights, filename, thresh);    else if(0==strcmp(argv[2], "train")) train_yolo(cfg, weights);    else if(0==strcmp(argv[2], "valid")) validate_yolo(cfg, weights);    else if(0==strcmp(argv[2], "recall")) validate_yolo_recall(cfg, weights);    else if(0==strcmp(argv[2], "demo")) demo_yolo(cfg, weights, thresh, cam_index, filename);}

test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)

该函数主要用于对静态图片进行YOLO检测：若指定了图片文件，则只处理该文件一次；否则循环从标准输入读取图片路径。检测结果会经过NMS处理后绘制到图片上，并显示出来、保存为predictions。

void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh){    // 解析配置文件的网络结构    network net = parse_network_cfg(cfgfile);    // 加载权重文件到网络    if(weightfile){        load_weights(&net, weightfile);    }    detection_layer l = net.layers[net.n-1];    set_batch_network(&net, 1);    srand(2222222);    clock_t time;    char buff[256];    char *input = buff;    int j;    float nms=.5;    bounding_box *boxes = calloc(l.side*l.side*l.n, sizeof(bounding_box));    float **probs = calloc(l.side*l.side*l.n, sizeof(float *));    for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *));    while(1){        if(filename){            strncpy(input, filename, 256);        } else {            printf("Enter Image Path: ");            fflush(stdout);            input = fgets(input, 256, stdin);            if(!input) return;            strtok(input, "\n");        }        image im = load_image_color(input,0,0);        image sized = resize_image(im, net.w, net.h);        float *X = sized.data;        time=clock();        float *predictions = network_predict(net, X);        printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));        convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);        if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms);        //draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, voc_labels, 20);        draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, voc_labels, 20);        show_image(im, "predictions");        save_image(im, "predictions");        show_image(sized, "resized");        free_image(im);        free_image(sized);#ifdef OPENCV        cvWaitKey(0);        cvDestroyAllWindows();#endif        if (filename) break;    }}

yolo_demo.c

yolo_demo.c文件主要实现YOLO对视频流（视频文件或摄像头）的实时检测。

demo_yolo(char *cfgfile, char *weightfile, float thresh, int cam_index, char *filename)

demo_yolo函数演示对摄像头（或指定的视频文件）画面进行实时目标检测：读取帧、检测与显示分别在不同线程中以流水线方式执行，并实时统计FPS。

void demo_yolo(char *cfgfile, char *weightfile, float thresh, int cam_index, char *filename){    demo_thresh = thresh;    printf("YOLO demo\n");    net = parse_network_cfg(cfgfile);    if(weightfile){        load_weights(&net, weightfile);    }    set_batch_network(&net, 1);    srand(2222222);    // 若有视频文件,打开视频文件;若没有视频文件,打开摄像头。    if(filename){        cap = cvCaptureFromFile(filename);    }else{    //驱动摄像头        cap = cvCaptureFromCAM(cam_index);    }    // 打开摄像头    if(!cap) error("Couldn't connect to webcam.\n");    cvNamedWindow("YOLO", CV_WINDOW_NORMAL);     cvResizeWindow("YOLO", 512, 512);    detection_layer l = net.layers[net.n-1];    int j;    // 开辟存储空间    boxes = (box *)calloc(l.side*l.side*l.n, sizeof(box));    probs = (float **)calloc(l.side*l.side*l.n, sizeof(float *));    for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float *));    // 多线程编程    pthread_t fetch_thread;    pthread_t detect_thread;    fetch_in_thread(0);    det = in;    det_s = in_s;    fetch_in_thread(0);    detect_in_thread(0);    disp = det;    det = in;    det_s = in_s;    while(1){        struct timeval tval_before, tval_after, tval_result;        gettimeofday(&tval_before, NULL);        if(pthread_create(&fetch_thread, 0, fetch_in_thread, 0)) error("Thread creation failed");        if(pthread_create(&detect_thread, 0, detect_in_thread, 0)) error("Thread creation failed");        show_image(disp, "YOLO");        free_image(disp);        cvWaitKey(1);        pthread_join(fetch_thread, 0);        pthread_join(detect_thread, 0);        disp  = det;        det   = in;        det_s = in_s;        gettimeofday(&tval_after, NULL);        timersub(&tval_after, &tval_before, &tval_result);        float curr = 1000000.f/((long int)tval_result.tv_usec);        fps = .9*fps + .1*curr;    }}

fetch_in_thread(void *ptr)

void *fetch_in_thread(void *ptr){    in = get_image_from_stream(cap);    in_s = resize_image(in, net.w, net.h);    return 0;}

detect_in_thread(void *ptr)

void *detect_in_thread(void *ptr){    float nms = .4;    detection_layer l = net.layers[net.n-1];    float *X = det_s.data;    float *predictions = network_predict(net, X);    free_image(det_s);    convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, demo_thresh, probs, boxes, 0);    if (nms > 0) do_nms(boxes, probs, l.side*l.side*l.n, l.classes, nms);    printf("\033[2J");    printf("\033[1;1H");    printf("\nFPS:%.0f\n",fps);    printf("Objects:\n\n");    draw_detections(det, l.side*l.side*l.n, demo_thresh, boxes, probs, voc_names, voc_labels, 20);    return 0;}

image.c

image.c文件主要处理图像相关的工作。

get_image_from_stream(CvCapture *cap)

get_image_from_stream函数主要完成从视频内抽帧的工作。

/*
 * Grab the next frame from an OpenCV capture and return it as a darknet
 * image with its channel order swapped by rgbgr_image().
 *
 * NOTE(review): cvQueryFrame returns NULL when no frame is available
 * (e.g. end of a video file). That pointer is passed straight to
 * ipl_to_image, so confirm whether ipl_to_image tolerates NULL.
 */
image get_image_from_stream(CvCapture *cap)
{
    image frame = ipl_to_image(cvQueryFrame(cap));
    rgbgr_image(frame);
    return frame;
}

resize_image(image im, int w, int h)

/*
 * Resize `im` to w x h with bilinear interpolation.
 *
 * The work is done in two separable passes:
 *   1. horizontal: im -> part (w x im.h);
 *   2. vertical: part -> resized (w x h).
 * Corners are aligned: the scale factors are (src-1)/(dst-1), and the last
 * destination column/row samples exactly the last source column/row.
 * Returns a new image from make_image(); the caller releases it.
 */
image resize_image(image im, int w, int h)
{
    image resized = make_image(w, h, im.c);
    image part = make_image(w, im.h, im.c);
    int r, c, k;
    /* A single-column/row target only ever samples the last source
       column/row, so its scale is unused; 0 avoids dividing by zero. */
    float w_scale = (w > 1) ? (float)(im.w - 1) / (w - 1) : 0;
    float h_scale = (h > 1) ? (float)(im.h - 1) / (h - 1) : 0;

    /* Pass 1: interpolate along x. */
    for (k = 0; k < im.c; ++k) {
        for (r = 0; r < im.h; ++r) {
            for (c = 0; c < w; ++c) {
                float val = 0;
                if (c == w-1 || im.w == 1) {
                    val = get_pixel(im, im.w-1, r, k);
                } else {
                    float sx = c*w_scale;
                    int ix = (int) sx;
                    float dx = sx - ix;
                    val = (1 - dx) * get_pixel(im, ix, r, k) + dx * get_pixel(im, ix+1, r, k);
                }
                set_pixel(part, c, r, k, val);
            }
        }
    }

    /* Pass 2: interpolate along y. */
    for (k = 0; k < im.c; ++k) {
        for (r = 0; r < h; ++r) {
            /* Handle the last row the same way pass 1 handles the last column:
               take the last source row exactly. Computing it as r*h_scale can
               round to just below im.h-1. iy would then truncate to im.h-2
               while the iy+1 term is skipped, darkening the bottom row. */
            int last_row = (r == h-1 || im.h == 1);
            int iy;
            float dy;
            if (last_row) {
                iy = im.h - 1;
                dy = 0;
            } else {
                float sy = r*h_scale;
                iy = (int) sy;
                dy = sy - iy;
            }
            for (c = 0; c < w; ++c) {
                float val = (1-dy) * get_pixel(part, c, iy, k);
                set_pixel(resized, c, r, k, val);
            }
            if (last_row) continue;
            for (c = 0; c < w; ++c) {
                float val = dy * get_pixel(part, c, iy+1, k);
                add_pixel(resized, c, r, k, val);
            }
        }
    }
    free_image(part);
    return resized;
}

重要数据结构定义

network

/*
 * Whole-network state.
 *
 * Few fields are exercised in the code shown in this article:
 *   - layers/n: test_yolo and demo_yolo read the last entry,
 *     net.layers[net.n-1], as a detection_layer;
 *   - w/h: input images are resized to net.w x net.h before
 *     network_predict().
 * The meaning of the other fields is inferred from their names only.
 * NOTE(review): confirm against the cfg parser and the trainer.
 */
typedef struct network{
    float *workspace;
    int n;                          /* number of entries in `layers` */
    int batch;                      /* presumably images per batch; set_batch_network(&net, 1) is used for inference */
    int *seen;
    float epoch;
    int subdivisions;
    /* presumably optimizer / learning-rate schedule hyperparameters — TODO confirm */
    float momentum;
    float decay;
    layer *layers;                  /* array of n layers */
    int outputs;
    float *output;
    learning_rate_policy policy;
    float learning_rate;
    float gamma;
    float scale;
    float power;
    int time_steps;
    int step;
    int max_batches;
    float *scales;
    int *steps;
    int num_steps;
    int inputs;
    int h, w, c;                    /* input size: images are resized to w x h; c presumably channels */
    int max_crop;
    int min_crop;
    /* Device-side buffers, only present in GPU builds. */
    #ifdef GPU
    float **input_gpu;
    float **truth_gpu;
    #endif
} network;

layer

/*
 * One layer of the network.
 *
 * A single struct is shared by every layer type; `type` says which kind it
 * is, and each kind uses only a subset of the fields. In the YOLO code of
 * this article, the last layer is read as a detection_layer, and its
 * side, n, classes and sqrt fields are passed to convert_yolo_detections().
 * That layer yields side*side*n candidate boxes, each with `classes`
 * scores (the sizes of the boxes/probs buffers).
 *
 * The group comments below are inferred from field names only.
 * NOTE(review): confirm against the per-layer implementation files.
 */
struct layer{
    /* kind of layer and its activation / cost function */
    LAYER_TYPE type;
    ACTIVATION activation;
    COST_TYPE cost_type;
    int batch_normalize;
    int shortcut;
    int batch;
    int forced;
    int flipped;
    /* presumably input/output sizes and geometry */
    int inputs;
    int outputs;
    int truths;
    int h,w,c;
    int out_h, out_w, out_c;
    int n;                          /* detection layer: boxes per grid cell (count = side*side*n) */
    int max_boxes;
    int groups;
    int size;
    int side;                       /* detection layer: grid is side x side */
    int stride;
    int pad;
    int sqrt;                       /* detection layer: forwarded to convert_yolo_detections() */
    int flip;
    int index;
    int binary;
    int xnor;
    int steps;
    int hidden;
    float dot;
    /* presumably data-augmentation parameters — TODO confirm */
    float angle;
    float jitter;
    float saturation;
    float exposure;
    float shift;
    int softmax;
    int classes;                    /* detection layer: number of class scores per box */
    int coords;
    int background;
    int rescore;
    int objectness;
    int does_cost;
    int joint;
    int noadjust;
    float alpha;
    float beta;
    float kappa;
    /* presumably loss-term weights of the detection cost — TODO confirm */
    float coord_scale;
    float object_scale;
    float noobject_scale;
    float class_scale;
    int dontload;
    int dontloadscales;
    float temperature;
    float probability;
    float scale;
    /* host-side buffers: parameters, their updates, activations and deltas */
    int *indexes;
    float *rand;
    float *cost;
    float *filters;
    char  *cfilters;
    float *filter_updates;
    float *state;
    float *state_delta;
    float *concat;
    float *concat_delta;
    float *binary_filters;
    float *biases;
    float *bias_updates;
    float *scales;
    float *scale_updates;
    float *weights;
    float *weight_updates;
    float *col_image;
    int   * input_layers;
    int   * input_sizes;
    float * delta;
    float * output;
    float * squared;
    float * norms;
    float * spatial_mean;
    /* presumably batch-normalization statistics — TODO confirm */
    float * mean;
    float * variance;
    float * mean_delta;
    float * variance_delta;
    float * rolling_mean;
    float * rolling_variance;
    float * x;
    float * x_norm;
    /* sub-layers; presumably used by recurrent layer types — TODO confirm */
    struct layer *input_layer;
    struct layer *self_layer;
    struct layer *output_layer;
    struct layer *input_gate_layer;
    struct layer *state_gate_layer;
    struct layer *input_save_layer;
    struct layer *state_save_layer;
    struct layer *input_state_layer;
    struct layer *state_state_layer;
    struct layer *input_z_layer;
    struct layer *state_z_layer;
    struct layer *input_r_layer;
    struct layer *state_r_layer;
    struct layer *input_h_layer;
    struct layer *state_h_layer;
    size_t workspace_size;
    /* device-side mirrors of the buffers above, only in GPU builds */
    #ifdef GPU
    float *z_gpu;
    float *r_gpu;
    float *h_gpu;
    int *indexes_gpu;
    float * prev_state_gpu;
    float * forgot_state_gpu;
    float * forgot_delta_gpu;
    float * state_gpu;
    float * state_delta_gpu;
    float * gate_gpu;
    float * gate_delta_gpu;
    float * save_gpu;
    float * save_delta_gpu;
    float * concat_gpu;
    float * concat_delta_gpu;
    float * filters_gpu;
    float * filter_updates_gpu;
    float *binary_input_gpu;
    float *binary_filters_gpu;
    float * mean_gpu;
    float * variance_gpu;
    float * rolling_mean_gpu;
    float * rolling_variance_gpu;
    float * variance_delta_gpu;
    float * mean_delta_gpu;
    float * col_image_gpu;
    float * x_gpu;
    float * x_norm_gpu;
    float * weights_gpu;
    float * weight_updates_gpu;
    float * biases_gpu;
    float * bias_updates_gpu;
    float * scales_gpu;
    float * scale_updates_gpu;
    float * output_gpu;
    float * delta_gpu;
    float * rand_gpu;
    float * squared_gpu;
    float * norms_gpu;
    /* cuDNN descriptors and chosen convolution algorithms, only in CUDNN builds */
    #ifdef CUDNN
    cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc;
    cudnnTensorDescriptor_t dsrcTensorDesc, ddstTensorDesc;
    cudnnFilterDescriptor_t filterDesc;
    cudnnFilterDescriptor_t dfilterDesc;
    cudnnConvolutionDescriptor_t convDesc;
    cudnnConvolutionFwdAlgo_t fw_algo;
    cudnnConvolutionBwdDataAlgo_t bd_algo;
    cudnnConvolutionBwdFilterAlgo_t bf_algo;
    #endif
    #endif
};
0 0
原创粉丝点击