category命令评估YOLO模型对每种物体检测的性能

来源:互联网 发布:php简单文字特效代码 编辑:程序博客网 时间:2024/06/10 15:21

将下面代码添加到darknet/src/detector.c中:

void print_category(FILE **fps, char *path, box *boxes, float **probs, int total, int classes, int w, int h, float thresh, float iou_thresh){    int i, j;    char labelpath[4096];    find_replace(path, "images", "labels", labelpath);    find_replace(labelpath, "JPEGImages", "labels", labelpath);    find_replace(labelpath, ".jpg", ".txt", labelpath);    find_replace(labelpath, ".JPEG", ".txt", labelpath);    int num_labels = 0;    box_label *truth = read_boxes(labelpath, &num_labels);    for (i = 0; i < total; ++i){        int class_id = max_index(probs[i], classes);        float prob = probs[i][class_id];        if (prob < thresh)continue;        float best_iou = 0;        int best_iou_id = 0;        int correct = 0;        for (j = 0; j < num_labels; ++j) {            box t = { truth[j].x*w, truth[j].y*h, truth[j].w*w, truth[j].h*h };            float iou = box_iou(boxes[i], t);            //fprintf(stderr, "box p: %f, %f, %f, %f\n", boxes[i].x, boxes[i].y, boxes[i].w, boxes[i].h);            //fprintf(stderr, "box t: %f, %f, %f, %f\n", t.x, t.y, t.w, t.h);            //fprintf(stderr, "iou : %f\n", iou);            if (iou > best_iou){                best_iou = iou;                best_iou_id = j;            }        }        if (best_iou > iou_thresh && truth[best_iou_id].id == class_id){            correct = 1;        }        float xmin = boxes[i].x - boxes[i].w / 2.;        float xmax = boxes[i].x + boxes[i].w / 2.;        float ymin = boxes[i].y - boxes[i].h / 2.;        float ymax = boxes[i].y + boxes[i].h / 2.;        if (xmin < 0) xmin = 0;        if (ymin < 0) ymin = 0;        if (xmax > w) xmax = w;        if (ymax > h) ymax = h;        fprintf(fps[class_id], "%s, %d, %d, %f, %f, %f, %f, %f, %f\n", path, class_id, correct, prob, best_iou, xmin, ymin, xmax, ymax);    }}void validate_detector_category(char *datacfg, char *cfgfile, char *weightfile, char *outfile){    network net = parse_network_cfg(cfgfile);    int j;    list *options = 
read_data_cfg(datacfg);    char *valid_images = option_find_str(options, "valid", "data/train.list");    char *name_list = option_find_str(options, "names", "data/names.list");    char *prefix = option_find_str(options, "results", "results");    char **names = get_labels(name_list);    char *mapf = option_find_str(options, "map", 0);    int *map = 0;    if (mapf) map = read_map(mapf);    if (weightfile){        load_weights(&net, weightfile);    }    set_batch_network(&net, 1);    fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);    srand(time(0));    list *plist = get_paths(valid_images);    char **paths = (char **)list_to_array(plist);    layer l = net.layers[net.n - 1];    int classes = l.classes;    char buff[1024];    FILE **fps = 0;    if (!outfile) outfile = "paul_";    fps = calloc(classes, sizeof(FILE *));    for (j = 0; j < classes; ++j){        _snprintf(buff, 1024, "%s/%s%s.txt", prefix, outfile, names[j]);        fps[j] = fopen(buff, "w");    }    box *boxes = calloc(l.w*l.h*l.n, sizeof(box));    float **probs = calloc(l.w*l.h*l.n, sizeof(float *));    for (j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));    int m = plist->size;    int i = 0;    int t;    float thresh = .25;    float iou_thresh = .5;    float nms = .45;    int nthreads = 4;    image *val = calloc(nthreads, sizeof(image));    image *val_resized = calloc(nthreads, sizeof(image));    image *buf = calloc(nthreads, sizeof(image));    image *buf_resized = calloc(nthreads, sizeof(image));    pthread_t *thr = calloc(nthreads, sizeof(pthread_t));    load_args args = { 0 };    args.w = net.w;    args.h = net.h;    args.type = IMAGE_DATA;    for (t = 0; t < nthreads; ++t){        args.path = paths[i + t];        args.im = &buf[t];        args.resized = &buf_resized[t];        thr[t] = load_data_in_thread(args);    }    time_t start = time(0);    for (i = nthreads; i < m + nthreads; i += nthreads){        
fprintf(stderr, "%d\n", i);        for (t = 0; t < nthreads && i + t - nthreads < m; ++t){            pthread_join(thr[t], 0);            val[t] = buf[t];            val_resized[t] = buf_resized[t];        }        for (t = 0; t < nthreads && i + t < m; ++t){            args.path = paths[i + t];            args.im = &buf[t];            args.resized = &buf_resized[t];            thr[t] = load_data_in_thread(args);        }        for (t = 0; t < nthreads && i + t - nthreads < m; ++t){            char *path = paths[i + t - nthreads];            float *X = val_resized[t].data;            network_predict(net, X);            int w = val[t].w;            int h = val[t].h;            get_region_boxes(l, w, h, thresh, probs, boxes, 0, map, .5, 0);            if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, classes, nms);            print_category(fps, path, boxes, probs, l.w*l.h*l.n, classes, w, h, thresh, iou_thresh);            free_image(val[t]);            free_image(val_resized[t]);        }    }    for (j = 0; j < classes; ++j){        if (fps) fclose(fps[j]);    }    fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));}

修改 run_detector()函数:

void run_detector(int argc, char **argv){char *prefix = find_char_arg(argc, argv, "-prefix", 0);float thresh = find_float_arg(argc, argv, "-thresh", .24);float hier_thresh = find_float_arg(argc, argv, "-hier", .5);int cam_index = find_int_arg(argc, argv, "-c", 0);int frame_skip = find_int_arg(argc, argv, "-s", 0);if(argc < 4){fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);return;}char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);char *outfile = find_char_arg(argc, argv, "-out", 0);int *gpus = 0;int gpu = 0;int ngpus = 0;if(gpu_list){printf("%s\n", gpu_list);int len = strlen(gpu_list);ngpus = 1;int i;for(i = 0; i < len; ++i){if (gpu_list[i] == ',') ++ngpus;}gpus = calloc(ngpus, sizeof(int));for(i = 0; i < ngpus; ++i){gpus[i] = atoi(gpu_list);gpu_list = strchr(gpu_list, ',')+1;}} else {gpu = gpu_index;gpus = &gpu;ngpus = 1;}int clear = find_arg(argc, argv, "-clear");char *datacfg = argv[3];char *cfg = argv[4];char *weights = (argc > 5) ? argv[5] : 0;char *filename = (argc > 6) ? argv[6]: 0;if(0==strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, outfile);else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear);else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);else if(0==strcmp(argv[2], "valid2")) validate_detector_flip(datacfg, cfg, weights, outfile);else if(0==strcmp(argv[2], "recall")) validate_detector_recall(cfg, weights);    //yim 2017.05.16    else if (0 == strcmp(argv[2], "category"))validate_detector_category(datacfg, cfg, weights, outfile);else if(0==strcmp(argv[2], "demo")) {list *options = read_data_cfg(datacfg);int classes = option_find_int(options, "classes", 20);char *name_list = option_find_str(options, "names", "data/names.list");char **names = get_labels(name_list);demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix, hier_thresh);}}

执行命令:

"detector" "category" "E:/projects&code/darknet_yolo/cfg/voc.data" "E:/projects&code/darknet_yolo/cfg/tiny-yolo-voc.cfg" "E:/projects&code/darknet_yolo/tiny-yolo-voc.weights"

results目录(由.data文件中的results选项指定,默认为results)下会生成各类物体的val结果:有多少种物体,就会生成多少个txt文件,文件名为“前缀+类别名.txt”(前缀由-out参数指定,默认为paul_)。每个txt文件的每一行依次包含path, class_id, correct, prob, best_iou, xmin, ymin, xmax, ymax信息。

接下来使用evalute.py工具可以解析这些txt文件做一个总结性的评估,evalute.py脚本如下:

# coding=utf-8
# Companion tool for the darknet "category" command.
# "category" is a command added to detector.c; it writes one result file per
# object class. Example:
#   ./darknet detector category cfg/paul.data cfg/yolo-paul.cfg backup/yolo-paul_final.weights
# Put this script in the results directory and run it there. It prints, for
# every class: id, avg_iou, avg_correct_iou, avg_precision, avg_recall, avg_score.
# It also writes low_list.log and high_list.log, listing the classes whose
# precision/recall are below, respectively above, the thresholds used in main().
import os
from os import listdir, getcwd
from os.path import join
import shutil

# Total number of object classes.
class_num = 20

# Inputs used when the script is run directly.
VAL_IMAGE_LIST = "E:/ImageSets/VOCdevkit/VOC2012/2012_val.txt"
NAME_LIST = "E:/projects&code/darknet_yolo/data/voc.txt"


def _ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    if not denominator:
        return 0.0
    return numerator / float(denominator)


def _parse_result_line(line):
    """Parse one "path, class_id, correct, prob, best_iou, xmin, ymin, xmax, ymax" line.

    Returns (path, class_id, correct, prob, best_iou), or None for a blank or
    malformed line. The line is split from the right so that a path which
    itself contains ", " is still parsed correctly.
    """
    parts = line.rstrip().rsplit(', ', 8)
    if len(parts) != 9:
        return None
    try:
        return parts[0], int(parts[1]), int(parts[2]), float(parts[3]), float(parts[4])
    except ValueError:
        return None


# Validation result of one object class.
class CategoryValidation:
    id = 0  # category id
    path = ""  # path of the result file
    total_num = 0  # number of ground-truth boxes of this class in the label files
    proposals_num = 0  # number of boxes predicted as this class
    correct_num = 0  # predictions flagged correct (IoU with ground truth > 0.5 and right class)
    iou_num = 0  # number of predictions whose best IoU is > 0.5
    iou_sum = 0  # sum of those IoUs
    correct_iou_sum = 0  # sum of the IoUs of the correct predictions
    score_sum = 0  # sum of the probabilities of the correct predictions
    avg_iou = 0  # iou_sum / iou_num (regardless of whether the class is right)
    avg_correct_iou = 0  # correct_iou_sum / correct_num
    avg_precision = 0  # correct_num / proposals_num
    avg_recall = 0  # correct_num / total_num
    avg_score = 0  # score_sum / correct_num

    def __init__(self, path, val_cat_num):
        """Read the result file `path` and compute the statistics of its class.

        `val_cat_num` maps a class id to the number of ground-truth boxes of
        that class. Every average whose denominator is zero is reported as 0.0
        instead of raising ZeroDivisionError.
        """
        self.path = path
        seen_record = False
        with open(path) as f:
            for line in f:
                record = _parse_result_line(line)
                if record is None:
                    continue
                _, class_id, correct, prob, best_iou = record
                if not seen_record:
                    # The class id comes from the first record; that record is
                    # also counted below (it used to be dropped).
                    self.id = class_id
                    self.total_num = val_cat_num[self.id]
                    seen_record = True
                self.proposals_num = self.proposals_num + 1.00
                if correct:
                    self.correct_num = self.correct_num + 1.00
                    self.score_sum = self.score_sum + prob
                    self.correct_iou_sum = self.correct_iou_sum + best_iou
                if best_iou > 0.5:
                    self.iou_num = self.iou_num + 1
                    self.iou_sum = self.iou_sum + best_iou
        self.avg_iou = _ratio(self.iou_sum, self.iou_num)
        self.avg_correct_iou = _ratio(self.correct_iou_sum, self.correct_num)
        self.avg_precision = _ratio(self.correct_num, self.proposals_num)
        self.avg_recall = _ratio(self.correct_num, self.total_num)
        self.avg_score = _ratio(self.score_sum, self.correct_num)

    def _export(self, prefix, want_correct):
        """Copy the records whose `correct` flag equals `want_correct` to "<prefix><id>.txt"."""
        new_f_name = prefix + str(self.id) + ".txt"
        with open(self.path) as f, open(new_f_name, 'w') as new_f:
            for line in f:
                record = _parse_result_line(line)
                if record is not None and bool(record[2]) == want_correct:
                    new_f.write(line)

    # Export the list of correctly recognised detections.
    def get_correct_list(self):
        self._export("correct_list_", True)

    # Export the list of wrongly recognised detections.
    def get_error_list(self):
        self._export("error_list_", False)

    def print_eva(self):
        """Print the statistics of this class on one line."""
        print("id=%d, avg_iou=%f, avg_correct_iou=%f, avg_precision=%f, avg_recall=%f, avg_score=%f \n" % (
            self.id,
            self.avg_iou,
            self.avg_correct_iou,
            self.avg_precision,
            self.avg_recall,
            self.avg_score))


def IsSubString(SubStrList, Str):
    """Return True when every string in SubStrList occurs in Str."""
    return all(substr in Str for substr in SubStrList)


# Return the sorted names of the entries of FindPath that contain every
# string in FlagStr (all entries when FlagStr is empty).
def GetFileList(FindPath, FlagStr=[]):
    # The default list is only read, never modified.
    FileList = [fn for fn in os.listdir(FindPath)
                if len(FlagStr) == 0 or IsSubString(FlagStr, fn)]
    FileList.sort()
    return FileList


# Count the ground-truth boxes of every class.
# `path` is the image list file. The result is a list indexed by class id whose
# values are the number of labelled boxes of that class.
def get_val_cat_num(path):
    val_cat_num = [0] * class_num
    with open(path) as f:
        for line in f:
            image_path = line.rstrip()
            if not image_path:
                continue
            # Same image-path -> label-path mapping as print_category() in detector.c.
            label_path = image_path.replace('images', 'labels')
            label_path = label_path.replace('JPEGImages', 'labels')
            label_path = label_path.replace('.jpg', '.txt')
            label_path = label_path.replace('.JPEG', '.txt')
            with open(label_path) as label_list:
                for label in label_list:
                    fields = label.split()
                    if not fields:
                        continue  # blank line in the label file
                    cat_id = int(fields[0])
                    val_cat_num[cat_id] = val_cat_num[cat_id] + 1.00
    return val_cat_num


# Read the class names.
# `path` is the name list file, one class per line: either just the name, or
# "id,name". The result is a list indexed by class id.
def get_name_list(path):
    name_list = []
    with open(path) as f:
        for line in f:
            entry = line.rstrip()
            if not entry:
                continue
            # Previously the second character of the line was stored by mistake.
            if ',' in entry:
                entry = entry.split(',', 2)[1]
            name_list.append(entry)
    return name_list


def main():
    """Evaluate every result file in the current directory and write the two summary logs."""
    wd = getcwd()
    val_result_list = GetFileList(wd, ['txt'])
    val_cat_num = get_val_cat_num(VAL_IMAGE_LIST)
    name_list = get_name_list(NAME_LIST)
    with open("low_list.log", 'w') as low_list, open("high_list.log", 'w') as high_list:
        for result in val_result_list:
            cat = CategoryValidation(result, val_cat_num)
            cat.print_eva()
            if cat.avg_precision < 0.3 or cat.avg_recall < 0.3:
                low_list.write("id=%d, name=%s, avg_precision=%f, avg_recall=%f \n" % (
                    cat.id, name_list[cat.id], cat.avg_precision, cat.avg_recall))
            if cat.avg_precision > 0.6 and cat.avg_recall > 0.6:
                high_list.write("id=%d, name=%s, avg_precision=%f, avg_recall=%f \n" % (
                    cat.id, name_list[cat.id], cat.avg_precision, cat.avg_recall))


if __name__ == "__main__":
    main()

将本工具放在results目录下执行,会打印出各类物体的评估(evaluate)结果,包括id、avg_iou、avg_correct_iou、avg_precision、avg_recall、avg_score,例如:

id=0, avg_iou=0.632979, avg_correct_iou=0.632979, avg_precision=0.619048, avg_recall=0.702703, avg_score=0.734685 id=1, avg_iou=0.656112, avg_correct_iou=0.661061, avg_precision=0.589744, avg_recall=0.575000, avg_score=0.779845 id=2, avg_iou=0.662430, avg_correct_iou=0.663795, avg_precision=0.620253, avg_recall=0.662162, avg_score=0.670480 id=3, avg_iou=0.628282, avg_correct_iou=0.628282, avg_precision=0.415385, avg_recall=0.397059, avg_score=0.650444 id=4, avg_iou=0.661582, avg_correct_iou=0.665570, avg_precision=0.236364, avg_recall=0.156627, avg_score=0.664535 id=5, avg_iou=0.667193, avg_correct_iou=0.661994, avg_precision=0.526316, avg_recall=0.625000, avg_score=0.676449 id=6, avg_iou=0.624276, avg_correct_iou=0.625075, avg_precision=0.384181, avg_recall=0.412121, avg_score=0.647013 id=7, avg_iou=0.652051, avg_correct_iou=0.653301, avg_precision=0.666667, avg_recall=0.845070, avg_score=0.683803 id=8, avg_iou=0.626261, avg_correct_iou=0.624698, avg_precision=0.326389, avg_recall=0.361538, avg_score=0.657096 id=9, avg_iou=0.651088, avg_correct_iou=0.643851, avg_precision=0.518519, avg_recall=0.700000, avg_score=0.658830 id=10, avg_iou=0.592246, avg_correct_iou=0.584612, avg_precision=0.160000, avg_recall=0.210526, avg_score=0.709824 id=11, avg_iou=0.646738, avg_correct_iou=0.644954, avg_precision=0.567568, avg_recall=0.724138, avg_score=0.692331 id=12, avg_iou=0.647156, avg_correct_iou=0.651284, avg_precision=0.680000, avg_recall=0.755556, avg_score=0.770852 id=13, avg_iou=0.640733, avg_correct_iou=0.641990, avg_precision=0.614035, avg_recall=0.636364, avg_score=0.658294 id=14, avg_iou=0.636807, avg_correct_iou=0.637161, avg_precision=0.606667, avg_recall=0.688351, avg_score=0.633678 id=15, avg_iou=0.631992, avg_correct_iou=0.631992, avg_precision=0.327869, avg_recall=0.317460, avg_score=0.593521 id=16, avg_iou=0.613670, avg_correct_iou=0.626057, avg_precision=0.300000, avg_recall=0.545455, avg_score=0.653449 id=17, avg_iou=0.610414, avg_correct_iou=0.611625, 
avg_precision=0.477273, avg_recall=0.656250, avg_score=0.679039 id=18, avg_iou=0.642675, avg_correct_iou=0.642675, avg_precision=0.736842, avg_recall=0.608696, avg_score=0.700961 id=19, avg_iou=0.637432, avg_correct_iou=0.640944, avg_precision=0.395833, avg_recall=0.441860, avg_score=0.657191 

同时,执行脚本的目录下会生成low_list.log和high_list.log:前者记录精度或recall低于0.3(未达标)的物体种类,后者记录精度和recall均高于0.6(达标)的物体种类。

参考博文:http://blog.csdn.net/hrsstudy/article/details/65644517?utm_source=itdadao&utm_medium=referral

原创粉丝点击