YOLO实现之源码分析1
来源:互联网 发布:dh算法加密 编辑:程序博客网 时间:2024/06/07 13:18
前面介绍了论文的思想与模型,以及使用darknet进行实践,接下来让我们深入到源码。
main.c
首先追踪主函数。主函数开始解析命令行参数,然后根据不同的命令行参数进入不同的调用方法。接下来我们以YOLO为主线,进行追踪与分析。
int main(int argc, char **argv){ //test_resize("data/bad.jpg"); //test_box(); //test_convolutional_layer(); if(argc < 2){ fprintf(stderr, "usage: %s <function>\n", argv[0]); return 0; } gpu_index = find_int_arg(argc, argv, "-i", 0); if(find_arg(argc, argv, "-nogpu")) { gpu_index = -1; }#ifndef GPU gpu_index = -1;#else if(gpu_index >= 0){ cudaError_t status = cudaSetDevice(gpu_index); check_error(status); }#endif if(0==strcmp(argv[1], "imagenet")){ run_imagenet(argc, argv); } else if (0 == strcmp(argv[1], "average")){ average(argc, argv); } else if (0 == strcmp(argv[1], "yolo")){ // 检测到命令行第一个参数输入的是yolo,则进入run_yolo函数 run_yolo(argc, argv); } else if (0 == strcmp(argv[1], "cifar")){ run_cifar(argc, argv); } else if (0 == strcmp(argv[1], "go")){ run_go(argc, argv); } else if (0 == strcmp(argv[1], "rnn")){ run_char_rnn(argc, argv); } else if (0 == strcmp(argv[1], "vid")){ run_vid_rnn(argc, argv); } else if (0 == strcmp(argv[1], "coco")){ run_coco(argc, argv); } else if (0 == strcmp(argv[1], "classifier")){ run_classifier(argc, argv); } else if (0 == strcmp(argv[1], "art")){ run_art(argc, argv); } else if (0 == strcmp(argv[1], "tag")){ run_tag(argc, argv); } else if (0 == strcmp(argv[1], "compare")){ run_compare(argc, argv); } else if (0 == strcmp(argv[1], "dice")){ run_dice(argc, argv); } else if (0 == strcmp(argv[1], "writing")){ run_writing(argc, argv); } else if (0 == strcmp(argv[1], "3d")){ composite_3d(argv[2], argv[3], argv[4]); } else if (0 == strcmp(argv[1], "test")){ test_resize(argv[2]); } else if (0 == strcmp(argv[1], "captcha")){ run_captcha(argc, argv); } else if (0 == strcmp(argv[1], "nightmare")){ run_nightmare(argc, argv); } else if (0 == strcmp(argv[1], "change")){ change_rate(argv[2], atof(argv[3]), (argc > 4) ? 
atof(argv[4]) : 0); } else if (0 == strcmp(argv[1], "rgbgr")){ rgbgr_net(argv[2], argv[3], argv[4]); } else if (0 == strcmp(argv[1], "denormalize")){ denormalize_net(argv[2], argv[3], argv[4]); } else if (0 == strcmp(argv[1], "normalize")){ normalize_net(argv[2], argv[3], argv[4]); } else if (0 == strcmp(argv[1], "rescale")){ rescale_net(argv[2], argv[3], argv[4]); } else if (0 == strcmp(argv[1], "partial")){ partial(argv[2], argv[3], argv[4], atoi(argv[5])); } else if (0 == strcmp(argv[1], "stacked")){ stacked(argv[2], argv[3], argv[4]); } else if (0 == strcmp(argv[1], "visualize")){ visualize(argv[2], (argc > 3) ? argv[3] : 0); } else if (0 == strcmp(argv[1], "imtest")){ test_resize(argv[2]); } else { fprintf(stderr, "Not an option: %s\n", argv[1]); } return 0;}
yolo.c
接下来分析yolo.c文件,该文件主要是使用darknet实现YOLO论文的思想。
run_yolo(int argc, char **argv)
该函数是yolo.c文件的入口函数:先加载各类别的标签图片,再解析命令行参数,最后根据子命令(test/train/valid/recall/demo)分发到对应的函数。
char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};void run_yolo(int argc, char **argv){ int i; // 1. 提取类别标签 for(i = 0; i < 20; ++i){ char buff[256]; sprintf(buff, "data/labels/%s.png", voc_names[i]); voc_labels[i] = load_image_color(buff, 0, 0); } // 2.读取得分阈值,默认是0.2 float thresh = find_float_arg(argc, argv, "-thresh", .2); // 3. 读取是否指定摄像头,默认是0 int cam_index = find_int_arg(argc, argv, "-c", 0); if(argc < 4){ fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]); return; } char *cfg = argv[3]; // 读取配置文件 // 读取权重文件,默认是0 char *weights = (argc > 4) ? argv[4] : 0; // 读取测试文件,默认是0 char *filename = (argc > 5) ? argv[5]: 0; // 根据不同的需求(test,train,validate,demo),进入不同的函数 if(0==strcmp(argv[2], "test")) test_yolo(cfg, weights, filename, thresh); else if(0==strcmp(argv[2], "train")) train_yolo(cfg, weights); else if(0==strcmp(argv[2], "valid")) validate_yolo(cfg, weights); else if(0==strcmp(argv[2], "recall")) validate_yolo_recall(cfg, weights); else if(0==strcmp(argv[2], "demo")) demo_yolo(cfg, weights, thresh, cam_index, filename);}
test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)
该函数主要完成对单张图片的目标检测:解析网络配置并加载权重,读取图片并缩放到网络输入尺寸,执行前向预测,经非极大值抑制后绘制、显示并保存检测结果。若未指定图片文件,则循环从标准输入读取图片路径。
void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh){ // 解析配置文件的网络结构 network net = parse_network_cfg(cfgfile); // 加载权重文件到网络 if(weightfile){ load_weights(&net, weightfile); } detection_layer l = net.layers[net.n-1]; set_batch_network(&net, 1); srand(2222222); clock_t time; char buff[256]; char *input = buff; int j; float nms=.5; bounding_box *boxes = calloc(l.side*l.side*l.n, sizeof(bounding_box)); float **probs = calloc(l.side*l.side*l.n, sizeof(float *)); for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *)); while(1){ if(filename){ strncpy(input, filename, 256); } else { printf("Enter Image Path: "); fflush(stdout); input = fgets(input, 256, stdin); if(!input) return; strtok(input, "\n"); } image im = load_image_color(input,0,0); image sized = resize_image(im, net.w, net.h); float *X = sized.data; time=clock(); float *predictions = network_predict(net, X); printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time)); convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0); if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms); //draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, voc_labels, 20); draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, voc_labels, 20); show_image(im, "predictions"); save_image(im, "predictions"); show_image(sized, "resized"); free_image(im); free_image(sized);#ifdef OPENCV cvWaitKey(0); cvDestroyAllWindows();#endif if (filename) break; }}
yolo_demo.c
yolo_demo.c文件主要实现yolo处理视频文件。
demo_yolo(char *cfgfile, char *weightfile, float thresh, int cam_index, char *filename)
demo_yolo函数演示如何对视频文件或摄像头采集的画面做实时目标检测:若指定了视频文件则打开该文件,否则打开摄像头。
void demo_yolo(char *cfgfile, char *weightfile, float thresh, int cam_index, char *filename){ demo_thresh = thresh; printf("YOLO demo\n"); net = parse_network_cfg(cfgfile); if(weightfile){ load_weights(&net, weightfile); } set_batch_network(&net, 1); srand(2222222); // 若有视频文件,打开视频文件;若没有视频文件,打开摄像头。 if(filename){ cap = cvCaptureFromFile(filename); }else{ //驱动摄像头 cap = cvCaptureFromCAM(cam_index); } // 打开摄像头 if(!cap) error("Couldn't connect to webcam.\n"); cvNamedWindow("YOLO", CV_WINDOW_NORMAL); cvResizeWindow("YOLO", 512, 512); detection_layer l = net.layers[net.n-1]; int j; // 开辟存储空间 boxes = (box *)calloc(l.side*l.side*l.n, sizeof(box)); probs = (float **)calloc(l.side*l.side*l.n, sizeof(float *)); for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float *)); // 多线程编程 pthread_t fetch_thread; pthread_t detect_thread; fetch_in_thread(0); det = in; det_s = in_s; fetch_in_thread(0); detect_in_thread(0); disp = det; det = in; det_s = in_s; while(1){ struct timeval tval_before, tval_after, tval_result; gettimeofday(&tval_before, NULL); if(pthread_create(&fetch_thread, 0, fetch_in_thread, 0)) error("Thread creation failed"); if(pthread_create(&detect_thread, 0, detect_in_thread, 0)) error("Thread creation failed"); show_image(disp, "YOLO"); free_image(disp); cvWaitKey(1); pthread_join(fetch_thread, 0); pthread_join(detect_thread, 0); disp = det; det = in; det_s = in_s; gettimeofday(&tval_after, NULL); timersub(&tval_after, &tval_before, &tval_result); float curr = 1000000.f/((long int)tval_result.tv_usec); fps = .9*fps + .1*curr; }}
fetch_in_thread(void *ptr)
/*
 * Thread body: grab the next frame from `cap` into `in` and store a copy
 * resized to the network input size (net.w x net.h) in `in_s`.
 *
 * ptr is unused. Always returns a null pointer.
 */
void *fetch_in_thread(void *ptr)
{
    (void)ptr;

    in = get_image_from_stream(cap);
    in_s = resize_image(in, net.w, net.h);

    return NULL;
}
detect_in_thread(void *ptr)
void *detect_in_thread(void *ptr){ float nms = .4; detection_layer l = net.layers[net.n-1]; float *X = det_s.data; float *predictions = network_predict(net, X); free_image(det_s); convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, demo_thresh, probs, boxes, 0); if (nms > 0) do_nms(boxes, probs, l.side*l.side*l.n, l.classes, nms); printf("\033[2J"); printf("\033[1;1H"); printf("\nFPS:%.0f\n",fps); printf("Objects:\n\n"); draw_detections(det, l.side*l.side*l.n, demo_thresh, boxes, probs, voc_names, voc_labels, 20); return 0;}
image.c
image.c文件主要处理图像相关的工作。
get_image_from_stream(CvCapture *cap)
get_image_from_stream函数主要完成从视频内抽帧的工作。
/*
 * Pull one frame from the capture and return it as an `image`.
 *
 * The frame is converted with ipl_to_image() and then passed through
 * rgbgr_image() before being returned.
 * NOTE(review): the result of cvQueryFrame() is not checked; presumably it
 * can be NULL when the stream ends - confirm how ipl_to_image() copes.
 */
image get_image_from_stream(CvCapture *cap)
{
    IplImage *frame = cvQueryFrame(cap);
    image out = ipl_to_image(frame);

    rgbgr_image(out);

    return out;
}
resize_image(image im, int w, int h)
/*
 * Resize `im` to w x h using linear interpolation in two passes:
 * first along the width into a temporary w x im.h image, then along the
 * height into the result.
 *
 * Returns a new image made with make_image(w, h, im.c). The temporary
 * image is freed before returning; `im` itself is not freed here.
 */
image resize_image(image im, int w, int h)
{
    image resized = make_image(w, h, im.c);
    image part = make_image(w, im.h, im.c);
    int r, c, k;

    /*
     * Source step per destination pixel. A 1-pixel target would make the
     * divisor zero; with h == 1 the vertical pass would then compute
     * (int)(0 * inf), a NaN-to-int conversion. Use a zero step instead.
     * Results for w > 1 and h > 1 are unchanged.
     */
    float w_scale = (w > 1) ? (float)(im.w - 1) / (w - 1) : 0;
    float h_scale = (h > 1) ? (float)(im.h - 1) / (h - 1) : 0;

    /* Pass 1: interpolate every source row horizontally into `part`. */
    for (k = 0; k < im.c; ++k) {
        for (r = 0; r < im.h; ++r) {
            for (c = 0; c < w; ++c) {
                float val = 0;
                if (c == w-1 || im.w == 1) {
                    /* Last column (or 1-pixel-wide source): no neighbour. */
                    val = get_pixel(im, im.w-1, r, k);
                } else {
                    float sx = c*w_scale;
                    int ix = (int) sx;
                    float dx = sx - ix;
                    val = (1 - dx) * get_pixel(im, ix, r, k) + dx * get_pixel(im, ix+1, r, k);
                }
                set_pixel(part, c, r, k, val);
            }
        }
    }

    /* Pass 2: interpolate the rows of `part` vertically into `resized`. */
    for (k = 0; k < im.c; ++k) {
        for (r = 0; r < h; ++r) {
            float sy = r*h_scale;
            int iy = (int) sy;
            float dy = sy - iy;
            for (c = 0; c < w; ++c) {
                float val = (1-dy) * get_pixel(part, c, iy, k);
                set_pixel(resized, c, r, k, val);
            }
            /* Last row (or 1-pixel-high source): no row below to blend in. */
            if (r == h-1 || im.h == 1) continue;
            for (c = 0; c < w; ++c) {
                float val = dy * get_pixel(part, c, iy+1, k);
                add_pixel(resized, c, r, k, val);
            }
        }
    }

    free_image(part);
    return resized;
}
重要数据结构定义
network
/*
 * Top-level description of a network.
 *
 * Elsewhere in this file, `layers[n-1]` is read as the final layer and
 * `w`/`h` are used as the size input images are resized to. The grouping
 * comments below are inferred from member names only - confirm against the
 * code that reads them.
 */
typedef struct network{
    float *workspace;
    int n;                  /* callers index layers[n-1] as the last layer */
    int batch;
    int *seen;
    float epoch;
    int subdivisions;
    float momentum;
    float decay;
    layer *layers;
    int outputs;
    float *output;

    /* Presumably learning-rate schedule settings - TODO confirm. */
    learning_rate_policy policy;
    float learning_rate;
    float gamma;
    float scale;
    float power;
    int time_steps;
    int step;
    int max_batches;
    float *scales;
    int *steps;
    int num_steps;

    /* Input description; callers resize images to w x h. */
    int inputs;
    int h;
    int w;
    int c;
    int max_crop;
    int min_crop;

#ifdef GPU
    float **input_gpu;
    float **truth_gpu;
#endif
} network;
layer
/*
 * One layer of a network. A single struct carries the members for every
 * layer type.
 *
 * Elsewhere in this file the detection code reads `side`, `n`, `classes`
 * and `sqrt`: side*side*n is the number of candidate boxes allocated, and
 * `classes` is the length of each per-box probability row. The grouping
 * comments below are inferred from member names only - confirm against the
 * layer implementations.
 */
struct layer{
    LAYER_TYPE type;
    ACTIVATION activation;
    COST_TYPE cost_type;

    /* Integer settings and dimensions. */
    int batch_normalize;
    int shortcut;
    int batch;
    int forced;
    int flipped;
    int inputs;
    int outputs;
    int truths;
    int h;
    int w;
    int c;
    int out_h;
    int out_w;
    int out_c;
    int n;                  /* callers allocate side*side*n boxes */
    int max_boxes;
    int groups;
    int size;
    int side;               /* callers allocate side*side*n boxes */
    int stride;
    int pad;
    int sqrt;               /* forwarded to convert_yolo_detections() */
    int flip;
    int index;
    int binary;
    int xnor;
    int steps;
    int hidden;

    float dot;
    float angle;
    float jitter;
    float saturation;
    float exposure;
    float shift;

    int softmax;
    int classes;            /* length of each per-box probability row */
    int coords;
    int background;
    int rescore;
    int objectness;
    int does_cost;
    int joint;
    int noadjust;

    float alpha;
    float beta;
    float kappa;

    float coord_scale;
    float object_scale;
    float noobject_scale;
    float class_scale;

    int dontload;
    int dontloadscales;

    float temperature;
    float probability;
    float scale;

    /* Host-side buffers. */
    int *indexes;
    float *rand;
    float *cost;
    float *filters;
    char *cfilters;
    float *filter_updates;
    float *state;
    float *state_delta;

    float *concat;
    float *concat_delta;

    float *binary_filters;

    float *biases;
    float *bias_updates;

    float *scales;
    float *scale_updates;

    float *weights;
    float *weight_updates;

    float *col_image;
    int *input_layers;
    int *input_sizes;
    float *delta;
    float *output;
    float *squared;
    float *norms;

    float *spatial_mean;
    float *mean;
    float *variance;

    float *mean_delta;
    float *variance_delta;

    float *rolling_mean;
    float *rolling_variance;

    float *x;
    float *x_norm;

    /* Pointers to other layers - presumably for recurrent types; confirm. */
    struct layer *input_layer;
    struct layer *self_layer;
    struct layer *output_layer;

    struct layer *input_gate_layer;
    struct layer *state_gate_layer;
    struct layer *input_save_layer;
    struct layer *state_save_layer;
    struct layer *input_state_layer;
    struct layer *state_state_layer;

    struct layer *input_z_layer;
    struct layer *state_z_layer;

    struct layer *input_r_layer;
    struct layer *state_r_layer;

    struct layer *input_h_layer;
    struct layer *state_h_layer;

    size_t workspace_size;

#ifdef GPU
    /* Members compiled only when GPU is defined. */
    float *z_gpu;
    float *r_gpu;
    float *h_gpu;

    int *indexes_gpu;
    float *prev_state_gpu;
    float *forgot_state_gpu;
    float *forgot_delta_gpu;
    float *state_gpu;
    float *state_delta_gpu;
    float *gate_gpu;
    float *gate_delta_gpu;
    float *save_gpu;
    float *save_delta_gpu;
    float *concat_gpu;
    float *concat_delta_gpu;

    float *filters_gpu;
    float *filter_updates_gpu;

    float *binary_input_gpu;
    float *binary_filters_gpu;

    float *mean_gpu;
    float *variance_gpu;

    float *rolling_mean_gpu;
    float *rolling_variance_gpu;

    float *variance_delta_gpu;
    float *mean_delta_gpu;

    float *col_image_gpu;

    float *x_gpu;
    float *x_norm_gpu;
    float *weights_gpu;
    float *weight_updates_gpu;

    float *biases_gpu;
    float *bias_updates_gpu;

    float *scales_gpu;
    float *scale_updates_gpu;

    float *output_gpu;
    float *delta_gpu;
    float *rand_gpu;
    float *squared_gpu;
    float *norms_gpu;

#ifdef CUDNN
    /* Members compiled only when CUDNN is also defined. */
    cudnnTensorDescriptor_t srcTensorDesc;
    cudnnTensorDescriptor_t dstTensorDesc;
    cudnnTensorDescriptor_t dsrcTensorDesc;
    cudnnTensorDescriptor_t ddstTensorDesc;
    cudnnFilterDescriptor_t filterDesc;
    cudnnFilterDescriptor_t dfilterDesc;
    cudnnConvolutionDescriptor_t convDesc;
    cudnnConvolutionFwdAlgo_t fw_algo;
    cudnnConvolutionBwdDataAlgo_t bd_algo;
    cudnnConvolutionBwdFilterAlgo_t bf_algo;
#endif
#endif
};
0 0
- YOLO实现之源码分析1
- YOLO源码分析之训练
- YOLO源码分析之data.c
- YOLO源码分析之detector.c
- yolo源码分析之demo.py
- yolo源码分析
- yolo 源码分析
- YOLO源码解析之yolo.c
- yolo源码学习 1
- yolo v2 源码分析(一)
- darknet yolo源码解读
- YOLO算法回归模型之回归的分析
- Tomcat源码分析之:ServletOutputStream的实现
- Java集合之HashMap源码实现分析
- nginx源码分析之http解码实现
- ceph源码分析之Log实现
- Cocos2dx源码分析之JumpBy的实现
- Java集合之HashMap源码实现分析
- WebSocket入门教程(五)-- WebSocket实例:简单多人聊天室
- ZBrush刻画脸部的这些要领看你是否了解
- Navicat for MySQL 选项设置有哪些技巧
- [转]大型网站架构系列:分布式消息队列(一)
- javascript 失去焦点(onblur)与获得焦点(onfocus),载入焦点(oInp.focus()),取消焦点(oInp.blur()),全选select()
- YOLO实现之源码分析1
- JavaScript event兼容问题及ev.clientX左距离,ev.clientY上距离,及onmousemove移动事件
- 使用IntelliJ IDEA,gradle开发Java web应用步骤
- Apache服务器主配置文件httpd.conf详解2
- Activity的四种启动模式
- 从阿里、腾讯系软件产品体验差距说起
- Thinkphp设置session共享使用 redis实现
- 算法及正则表达式
- JavaScript 冒泡原理及示例