整理使用SVM实现分类的步骤

来源:互联网 发布:游族网络定增价格2017 编辑:程序博客网 时间:2024/06/09 16:12

     之前SVM进行项目中待识别物体的分类,过去挺长时间结果有点生疏了,这里梳理一下。因为当初是根据《深入理解OpenCV》自动分类的例程,所以这里仍以此为例。这里只是列出了SVM使用环节的代码,而且本来就是一个小例子,其中测试集训练集的选取、特征的选择等都是最简单的考虑,真正应用,要进行更加全面的考虑。

    这个例程是从传送带中找到螺母、螺钉和垫圈,使用的特征是外轮廓的长宽比和面积的大小,也就是说,使用这两项就能训练SVM模型预测某一个轮廓属于哪一种类别。

 

#include <iostream>#include <string>#include <sstream>#include <cmath>using namespace std;#include "opencv2/core.hpp"#include "opencv2/imgproc.hpp"#include "opencv2/highgui.hpp"#include "opencv2/core/utility.hpp"#include "opencv2/ml.hpp"using namespace cv;using namespace cv::ml;Mat light_pattern;//创建一个SVM模型Ptr<SVM> svm;/*** Extract the features for all objects in one image  检测一幅图中所有物体的特征* * @param Mat img input image* @param vector<int> left output of left coordinates for each object 某个轮廓中心的x坐标* @param vector<int> top output of top coordintates for each object 某个轮廓中心的y坐标* @return vector< vector<float> > a matrix of rows of features for each object detectedoutput                    第一层有n个元素,n即有多少个零件,第二次有2个元素,分别是面积和长宽比两个特征**/vector< vector<float> > ExtractFeatures(Mat img, vector<int>* left=NULL, vector<int>* top=NULL){    vector< vector<float> > output;    vector<vector<cv::Point> > contours; //存放轮廓数据    Mat input1= img.clone();          //把原图复制一份,然后用这个副本去做轮廓检测,因为findContours函数会改变图像    Mat input;    vector<Vec4i> hierarchy;    cvtColor(input1, input, COLOR_BGR2GRAY);     findContours(input, contours, hierarchy, RETR_CCOMP, CHAIN_APPROX_SIMPLE);    // Check the number of objects detected    if(contours.size() == 0 )    {        return output;    }    RNG rng( 0xFFFFFFFF );    for(int i=0; i<contours.size(); i++)    {        //先生成一个全黑的图像        Mat mask= Mat::zeros(img.rows, img.cols, CV_8UC1);//图像上画出轮廓来,颜色为1        drawContours(mask, contours, i, Scalar(1), FILLED, LINE_8, hierarchy, 1);//sum函数统计当前轮廓对应的图中每个通道的总和,这里取出0通道的非零点的总和        Scalar area_s= sum(mask);        float area= area_s[0];            if(area>500)//满足第一个条件面积条件{             //宽高比    RotatedRect r= minAreaRect(contours[i]);            float width= r.size.width;            float height= r.size.height;            float ar=(width<height)?height/width:width/height;            //output第一层有n个元素,n即有多少个零件,第二次有2个元素,分别是面积和长宽比两个特征    vector<float> row;            row.push_back(area);            row.push_back(ar);            output.push_back(row);            //如果轮廓中心到图像左上角和到图像顶部的坐标忽略了,那就加上            if(left!=NULL)            {                left->push_back((int)r.center.x);            }            if(top!=NULL)    {                top->push_back((int)r.center.y);            }            waitKey(10);          }      }      return output;} //使用光纹删除背景Mat removeLight(Mat img, Mat pattern){    Mat aux;    // Require change our image to 32 float for division    Mat img32, pattern32;    img.convertTo(img32, CV_32F);    pattern.convertTo(pattern32, CV_32F);cout << pattern.channels() << endl;cout << img.channels() << endl;    //imshow("a",pattern32);    //waitKey(0);    //Divide the imabe by the pattern    aux= 1-(img32 / pattern32);    // Scale it to convert o 8bit format    aux=aux*255;    // Convert 8 bits format    aux.convertTo(aux, CV_8U);    //equalizeHist( aux, aux );    return aux;}/**对输入图像,进行滤波,去光照(背景),然后二值化* Preprocess an input image to extract components and stats* @params Mat input image to preprocess* @return Mat binary image*/Mat preprocessImage(Mat input){  Mat result;  // Remove noise  Mat img_noise, img_box_smooth;  medianBlur(input, img_noise, 3);  //Apply the light pattern  Mat img_no_light;  img_noise.copyTo(img_no_light);   img_no_light= removeLight(img_noise, light_pattern);      // Binarize image for segment  threshold(img_no_light, result, 30, 255, THRESH_BINARY);  return result;}/**一次性读取完num_for_test张图片,每读一张照片然后提取特征,放到特征的向量里面这里就是告诉电脑,有个图他是螺母(标签),他的面积特征是多少,长宽比特征是多少,电脑就根据这个去训练模型* Read all images in a folder creating the train and test vectors* @param label 因为螺母螺钉各自在不同文件夹里面,所以每文件夹内图片的标签是一样的,这里用标签0表示螺母,1垫圈,2螺钉* @param number of images used for test and evaluate algorithm error* @param trainingData 训练环节特征提取出来的两个特征,把这两个都放到trainingData这个向量里面* @param reponsesData 训练环节的图片的label,数量是(总数-num_for_test)里面的所有特征数之和 * @param testData     测试环节的特征数据,也是两个特征放到一个里面* @param testResponsesData 存储测试环节的label* @return true if can read the folder images, false in error case**/bool readFolderAndExtractFeatures(string folder, int label, int num_for_test,   vector<float> &trainingData, vector<int> &responsesData,    vector<float> &testData, vector<float> &testResponsesData){    //使用VideoCapture读取文件夹内的图片VideoCapture images;//这个VideoCapture::open可以返回能不能读当前文件夹的图    if(images.open(folder)==false){        cout << "Can not open the folder images" << endl;        return false;    }    Mat frame;    int img_index=0;//VideoCapture::read则是每次从这个文件夹内读一张给frame,读完了这个while也就结束    while( images.read(frame) ){        //调用了图像预处理函数,功能是去光照背景,并二值化        Mat pre= preprocessImage(frame);        // Extract features,提取当前图片的特征,每一幅如螺母的图里面,可能有一个或几个螺母        vector< vector<float> > features= ExtractFeatures(pre);        for(int i=0; i< features.size(); i++){//前num_for_test张图片用来做训练            if(img_index >= num_for_test)    {                trainingData.push_back(features[i][0]);                trainingData.push_back(features[i][1]);                responsesData.push_back(label);                }//后面的用来做测试else{                testData.push_back(features[i][0]);                testData.push_back(features[i][1]);                testResponsesData.push_back((float)label);                }        }        img_index++;    }    return true;  }//创建并训练SVM模型//注意要先创建一个Ptr<SVM> svm;void trainAndTest(){//它们四个的详细意义见readFolderAndExtractFeatures函数    vector< float > trainingData;    vector< int > responsesData;    vector< float > testData;    vector< float > testResponsesData;//把多少张图拿来用作训练集    int num_for_test= 20;    //把螺母螺钉垫片三种图片各自放在三个不同的文件夹里,那么读取不同文件夹时,同时就可以判断是不同的label    readFolderAndExtractFeatures("data\\nut\\tuerca_%04d.pgm", 0, num_for_test, trainingData, responsesData, testData, testResponsesData);    readFolderAndExtractFeatures("data\\ring\\arandela_%04d.pgm", 1, num_for_test, trainingData, responsesData, testData, testResponsesData);    readFolderAndExtractFeatures("data\\screw\\tornillo_%04d.pgm", 2, num_for_test, trainingData, responsesData, testData, testResponsesData);      cout << "Num of train samples: " << responsesData.size() << endl;    cout << "Num of test samples: " << testResponsesData.size() << endl;      // 把向量里面的特征和标签转换为Mat格式以便传给tranning函数    Mat trainingDataMat(trainingData.size()/2, 2, CV_32FC1, &trainingData[0]);    Mat responses(responsesData.size(), 1, CV_32SC1, &responsesData[0]);    Mat testDataMat(testData.size()/2, 2, CV_32FC1, &testData[0]);    Mat testResponses(testResponsesData.size(), 1, CV_32FC1, &testResponsesData[0]);      svm = SVM::create();    svm->setType(SVM::C_SVC);    svm->setKernel(SVM::CHI2);    //设置使用的内核和停止学习过程的标准,最大迭代次数为100    svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));    //trainingData 训练环节特征提取出来的两个特征,把这两个都放到trainingData这个向量里面    //reponsesData 训练环节的图片的label,数量是(总数 - num_for_test)里面的所有特征数之和    svm->train(trainingDataMat, ROW_SAMPLE, responses);//使用剩下的图作为测试集,来评估模型的准确性    if(testResponsesData.size()>0)    {        cout << "Evaluation" << endl;        cout << "==========" << endl;        // Test the ML Model        Mat testPredict;        svm->predict(testDataMat, testPredict);        cout << "Prediction Done" << endl;        //预测结果与测试集的标签不同就说明预测错误        Mat errorMat= testPredict!=testResponses;        float error= 100.0f * countNonZero(errorMat) / testResponsesData.size();        cout << "Error: " << error << "\%" << endl;    }else{        plotTrainData(trainingDataMat, responses);    }}int main(  ){        //pgm是灰度图像格式中一种最简单的格式标准
    //加载待分类的图和背景图
    //背景图是用来进行图像背景的归一化,每次都把图像减去背景    String img_file = "data\\test.pgm";    String light_pattern_file= "data\\pattern.pgm";        //Load image to process    Mat img= imread(img_file, -1);    if(img.data==NULL){        cout << "Error loading image "<< img_file << endl;        return 0;    }    Mat img_output= img.clone();    cvtColor(img_output, img_output, COLOR_GRAY2BGR);    light_pattern= imread(light_pattern_file, 1);     if(light_pattern.data==NULL){        cout << "ERROR: Not light patter loaded" << endl;        return 0;    }    medianBlur(light_pattern, light_pattern, 3);//使用准备好的数据,训练SVM模型    trainAndTest();      //对待分类的图像预处理cvtColor(img, img, COLOR_GRAY2BGR);    Mat pre= preprocessImage(img);    //提取待分类图像的特征,得到待提取图片中每个轮廓(零件)的特征值    vector<int> pos_top, pos_left;    vector< vector<float> > features= ExtractFeatures(pre, &pos_left, &pos_top);    cout << "Num objects extracted features " << features.size() << endl;//对于每个轮廓,使用训练好的SVM预测属于哪一类零件    for(int i=0; i< features.size(); i++){              cout << "Data Area AR: " << features[i][0] << " " << features[i][1] << endl;        Mat trainingDataMat(1, 2, CV_32FC1, &features[i][0]);        cout << "Features to predict: " << trainingDataMat << endl;        float result= svm->predict(trainingDataMat);        cout << result << endl;    }    return 0;}

原创粉丝点击