yolo训练时的训练数据读取代码注释!代码质量很高。

来源:互联网 发布:华中数据交易所 编辑:程序博客网 时间:2024/05/22 00:55

1、yolo训练时的加载图像数据的主要函数就是:load_data_region(),其里面包含了很多图像预处理,像各种曝光度调节、图像crop等操作。具体实现如下:

//data.cdata load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter, float hue, float saturation, float exposure){    char **random_paths = get_random_paths(paths, n, m);    int i;    data d = {0};    d.shallow = 0;//这里的每一行row存储着一张图片数据,其中n是batch size大小    d.X.rows = n;    d.X.vals = calloc(d.X.rows, sizeof(float*));    d.X.cols = h*w*3;//这里的y是标签    d.y = make_matrix(n, 5*boxes);    for(i = 0; i < n; ++i){//加载图片        image orig = load_image_color(random_paths[i], 0, 0);        int oh = orig.h;        int ow = orig.w;/*这里是为数据添加抖动干扰,提高网络的泛化能力(其实就是crop,数据增广的一种).配置文件的jitter=0.2,则宽高最多裁剪掉或者增加原始宽高的1/5.*/        int dw = (ow*jitter);        int dh = (oh*jitter);//这里进行产生随机值        int pleft  = rand_uniform(-dw, dw);        int pright = rand_uniform(-dw, dw);        int ptop   = rand_uniform(-dh, dh);        int pbot   = rand_uniform(-dh, dh);        int swidth =  ow - pleft - pright;        int sheight = oh - ptop - pbot;//这里计算的比例是为了计算抖动后实际样本的位置区域。        float sx = (float)swidth  / ow;        float sy = (float)sheight / oh;        int flip = random_gen()%2;//对图片进行裁剪        image cropped = crop_image(orig, pleft, ptop, swidth, sheight);        float dx = ((float)pleft/ow)/sx;        float dy = ((float)ptop /oh)/sy;//把裁剪后的图片归一化到416*416        image sized = resize_image(cropped, w, h);        if(flip) flip_image(sized);//对图片进色调、曝光度等的调整        random_distort_image(sized, hue, saturation, exposure);        d.X.vals[i] = sized.data;//对相应数据的标签进行读取和图像坐标进行恢复。其具体实现看其代码实现        fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy);        free_image(orig);free_image(sized);        free_image(cropped);    }    free(random_paths);    return d;}

训练样本坐标恢复代码:

//data.cvoid fill_truth_region(char *path, float *truth, int classes, int num_boxes, int flip, float dx, float dy, float sx, float sy){    char labelpath[4096];    find_replace(path, "images", "labels", labelpath);    find_replace(labelpath, "JPEGImages", "labels", labelpath);    find_replace(labelpath, ".jpg", ".txt", labelpath);    find_replace(labelpath, ".png", ".txt", labelpath);    find_replace(labelpath, ".JPG", ".txt", labelpath);    find_replace(labelpath, ".JPEG", ".txt", labelpath);    int count = 0;    box_label *boxes = read_boxes(labelpath, &count);    randomize_boxes(boxes, count);    correct_boxes(boxes, count, dx, dy, sx, sy, flip);    float x,y,w,h;    int id;    int i;    for (i = 0; i < count; ++i) {        x =  boxes[i].x;        y =  boxes[i].y;        w =  boxes[i].w;        h =  boxes[i].h;        id = boxes[i].id;        if (w < .01 || h < .01) continue;        int col = (int)(x*num_boxes);        int row = (int)(y*num_boxes);        x = x*num_boxes - col;        y = y*num_boxes - row;        int index = (col+row*num_boxes)*(5+classes);        if (truth[index]) continue;        truth[index++] = 1;        if (id < classes) truth[index+id] = 1;        index += classes;        truth[index++] = x;        truth[index++] = y;        truth[index++] = w;        truth[index++] = h;    }    free(boxes);}





























原创粉丝点击