模式识别之(一)SVM的opencv3.0实现

来源:互联网 发布:天刀阿暖捏脸数据 编辑:程序博客网 时间:2024/05/16 01:34

之所以决定写这篇文章,并不是因为我有多了解SVM的推导公式和原理,我只是知道SVM的使用条件和分类作用,并用opencv实现了分类功能。下面我会把与原理相关的参考网址列出来,一方面方便大家,另一方面也方便自己整理学习。这些文章对原理和公式写得都很细致,有助于理解学习。

1.参考网址

(1)SVM原理:http://blog.csdn.net/v_july_v/article/details/7624837

这篇博文是关于svm原理的文章,如果需要深入学习,或者粗略了解,都可以参考学习。

(2)opencv2.0和opencv3.0的实现区别:http://www.coin163.com/it/2260012286132365332/svm-opencv

这篇文章主要讲解opencv2.0和opencv3.0在SVM参数设置上的区别,两者在用法上存在一些差异。

(3)opencv3.0中的SVM训练 mnist 手写字体识别:http://www.itnose.net/detail/6525586.html

这篇文章里面最可取的就是分析了mnist数据集存在的坑

(4)opencv3的参数详解,很仔细:http://livezingy.com/svm-in-opencv3-1/

这篇文章对opencv3参数的解析很清晰,一目了然。不过我在使用时遇到了一些问题:修改程序之后虽然可以运行,也完成了功能,但遇到的问题还没有彻底解决(见下文“遇到问题”一节)。

2.相关的程序

(1)mnist数据集的网址:http://yann.lecun.com/exdb/mnist/

(2)相关程序

//mnist.h

#ifndef MNIST_H  
#define MNIST_H
#include <iostream>
#include <string>
#include <fstream>
#include <opencv2/opencv.hpp>

// NOTE(review): `using namespace` directives in a header leak into every file that
// includes it; kept because mnist.cpp relies on the unqualified cv/std names.
using namespace cv;
using namespace std;

// Swaps the byte order of a 32-bit int. The MNIST header fields are stored
// big-endian, so this converts them to host order on a little-endian machine.
int reverseInt(int i);
// Reads an MNIST idx3 image file. Returns one flattened image per row as CV_32FC1,
// with pixel values scaled to [0, 1]; returns an empty Mat if the file cannot be opened.
Mat read_mnist_image(const string fileName);
// Reads an MNIST idx1 label file. Returns an N x 1 CV_32SC1 Mat of labels;
// returns an empty Mat if the file cannot be opened.
Mat read_mnist_label(const string fileName);
#endif


//mnist.cpp

#include "mnist.h"

// Number of test samples; 10000 matches the size of the MNIST t10k test set.
// NOTE(review): within this file it is only referenced by a commented-out debug
// line in read_mnist_label.
int testNum{10000};
// Reverses the byte order of a 32-bit integer (e.g. 0x01020304 -> 0x04030201).
//
// MNIST idx headers are big-endian; on a little-endian host the raw int read from
// the file must be byte-swapped to obtain the real value.
//
// The swap is done on an unsigned copy so that moving a byte into the top 8 bits
// never shifts into the sign bit of a signed int (undefined before C++14).
// Assumes a 32-bit int, as on all platforms OpenCV targets.
//
// @param i  value whose bytes should be reversed.
// @return   the byte-reversed value.
int reverseInt(int i)
{
    const unsigned int u = static_cast<unsigned int>(i);
    const unsigned int swapped = ((u & 0x000000FFu) << 24) |
                                 ((u & 0x0000FF00u) << 8) |
                                 ((u & 0x00FF0000u) >> 8) |
                                 ((u & 0xFF000000u) >> 24);
    return static_cast<int>(swapped);
}

// Reads an MNIST idx3 image file (e.g. "train-images.idx3-ubyte").
//
// @param fileName  path of the idx3 file.
// @return  a number_of_images x (n_rows * n_cols) CV_32FC1 matrix holding one
//          flattened image per row, pixel values scaled to [0, 1]. Returns an
//          empty Mat if the file cannot be opened, the header is invalid
//          (magic number != 2051 or non-positive sizes), or the file is truncated.
//
// As a sanity check the first and last images are displayed with imshow.
// NOTE(review): imshow windows only repaint while an event loop runs (cv::waitKey);
// this function does not call it.
Mat read_mnist_image(const string fileName)
{
    ifstream file(fileName, ios::binary);
    if (!file.is_open())
    {
        cerr << "Cannot open image file: " << fileName << endl;
        return Mat();
    }
    cout << "成功打开图像集\n";

    int magic_number = 0;
    int number_of_images = 0;
    int n_rows = 0;
    int n_cols = 0;
    file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    file.read(reinterpret_cast<char*>(&number_of_images), sizeof(number_of_images));
    file.read(reinterpret_cast<char*>(&n_rows), sizeof(n_rows));
    file.read(reinterpret_cast<char*>(&n_cols), sizeof(n_cols));
    if (!file)
    {
        cerr << "Truncated header in image file: " << fileName << endl;
        return Mat();
    }
    // Raw header values as read (still big-endian), printed for debugging.
    cout << "magic_number=\n" << magic_number << "number_of_images=\n " << number_of_images << " " << n_rows << " " << n_cols << endl;

    // The idx header fields are big-endian; convert them to host order.
    magic_number = reverseInt(magic_number);
    number_of_images = reverseInt(number_of_images);
    n_rows = reverseInt(n_rows);
    n_cols = reverseInt(n_cols);
    cout << "MAGIC NUMBER = " << magic_number << " ;NUMBER OF IMAGES = " << number_of_images << " ; NUMBER OF ROWS = " << n_rows << " ; NUMBER OF COLS = " << n_cols << endl;

    // 2051 (0x00000803) marks an idx3 file of unsigned bytes. Anything else means the
    // wrong file was passed (e.g. a label file), so refuse instead of reading garbage.
    if (magic_number != 2051 || number_of_images <= 0 || n_rows <= 0 || n_cols <= 0)
    {
        cerr << "Invalid MNIST image header in: " << fileName << endl;
        return Mat();
    }

    cout << "开始读取Image数据......\n";
    // Read every pixel byte in one call rather than one byte at a time.
    Mat raw(number_of_images, n_rows * n_cols, CV_8UC1);
    const streamsize expected = static_cast<streamsize>(raw.total());
    file.read(reinterpret_cast<char*>(raw.data), expected);
    if (file.gcount() != expected)
    {
        cerr << "Image file is truncated: " << fileName << endl;
        return Mat();
    }

    // Scale 0..255 bytes to floats in [0, 1].
    Mat DataMat;
    raw.convertTo(DataMat, CV_32FC1, 1.0 / 255.0);

    // Show the first and last images to verify the data was parsed correctly.
    // A single row is continuous, so it can be reshaped to n_rows x n_cols without copying.
    imshow("first image", DataMat.row(0).reshape(1, n_rows));
    imshow("last image", DataMat.row(number_of_images - 1).reshape(1, n_rows));

    // The ifstream closes itself when it goes out of scope.
    return DataMat;
}


// Reads an MNIST idx1 label file (e.g. "train-labels.idx1-ubyte").
//
// @param fileName  path of the idx1 file.
// @return  a number_of_items x 1 CV_32SC1 matrix of labels. Returns an empty Mat
//          if the file cannot be opened, the header is invalid (magic number != 2049
//          or a non-positive item count), or the file is truncated.
//
// Labels are stored as CV_32S (int), which is what cv::ml::SVM expects for
// classification responses; read them back with at<int>, not at<unsigned int>.
Mat read_mnist_label(const string fileName)
{
    ifstream file(fileName, ios::binary);
    if (!file.is_open())
    {
        cerr << "Cannot open label file: " << fileName << endl;
        return Mat();
    }
    cout << "成功打开标签\n";

    int magic_number = 0;
    int number_of_items = 0;
    file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    file.read(reinterpret_cast<char*>(&number_of_items), sizeof(number_of_items));
    if (!file)
    {
        cerr << "Truncated header in label file: " << fileName << endl;
        return Mat();
    }
    // The idx header fields are big-endian; convert them to host order.
    magic_number = reverseInt(magic_number);
    number_of_items = reverseInt(number_of_items);
    cout << "MAGIC NUMBER = " << magic_number << "  ; NUMBER OF ITEMS = " << number_of_items << endl;

    // 2049 (0x00000801) marks an idx1 file of unsigned bytes.
    if (magic_number != 2049 || number_of_items <= 0)
    {
        cerr << "Invalid MNIST label header in: " << fileName << endl;
        return Mat();
    }
    //-test-: uncomment to load only the first testNum labels.
    //number_of_items = testNum;

    cout << "开始读取Label数据......\n";
    // Read all label bytes in one call rather than one byte at a time.
    Mat raw(number_of_items, 1, CV_8UC1);
    const streamsize expected = static_cast<streamsize>(raw.total());
    file.read(reinterpret_cast<char*>(raw.data), expected);
    if (file.gcount() != expected)
    {
        cerr << "Label file is truncated: " << fileName << endl;
        return Mat();
    }

    Mat LabelMat;
    raw.convertTo(LabelMat, CV_32SC1);

    // Print the first and last labels as a quick sanity check.
    cout << "first label = " << LabelMat.at<int>(0, 0) << endl;
    cout << "last label = " << LabelMat.at<int>(number_of_items - 1, 0) << endl;
    return LabelMat;
}


//mnist_detection

#include "mnist.h" 
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include "opencv2/imgcodecs.hpp"
#include <opencv2/highgui.hpp>
#include <opencv2/ml.hpp>
#include <string>
#include <iostream>
using namespace std;
using namespace cv;
using namespace cv::ml;
// MNIST data files, opened relative to the current working directory.
// Test set (t10k):
std::string testImage{"t10k-images.idx3-ubyte"};
std::string testLabel{"t10k-labels.idx1-ubyte"};
// Training set:
std::string trainImage{"train-images.idx3-ubyte"};
std::string trainLabel{"train-labels.idx1-ubyte"};


int main()
{
//读取训练数据
Mat trainData;
Mat labels;
trainData = read_mnist_image(trainImage);
labels = read_mnist_label(trainLabel);
cout << " 训练数据行数:" << trainData.rows << "训练数据列数:" << trainData.cols << endl;
cout << " 训练标签行数:" << labels.rows << "训练数据列数: " << labels.cols << endl;
cout << "训练数据读取完成" << endl;
//训练参数
Ptr<SVM> svm = SVM::create();
svm->setType(SVM::C_SVC);
svm->setKernel(SVM::RBF);
svm->setGamma(0.01);
svm->setC(10.0);
svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));
cout << "参数设置完成" << endl;
//训练分类器
cout << "开始训练分类器" << endl;
svm->train(trainData, ROW_SAMPLE, labels);
cout << "分类器训练完成" << endl;
//保存训练器
//svm->save("mnist_dataset/mnist_svm.xml");
//cout << "save as /mnist_dataset/mnist_svm.xml" << endl;
//下载分类器
//cout << "开始导入SVM文件...\n";
//Ptr<SVM> svm1 = StatModel::load<SVM>("mnist_dataset/mnist_svm.xml");
//cout << "成功导入SVM文件...\n";
//读取测试数据
Mat testData;
Mat tLabel;
testData = read_mnist_image(testImage);
tLabel = read_mnist_label(testLabel);
cout << "测试数据读取完成" << endl;

float count = 0;
for (int i = 0; i < testData.rows; i++) 
{
Mat sample = testData.row(i);
float res=0.0;
res=svm->predict(sample);
res = abs(res - tLabel.at<unsigned int>(i, 0)) <= FLT_EPSILON ? 1.f : 0.f;
count += res;
}
cout << "正确的识别个数: " << count << endl;
cout << "错误率为:" << (10000 - count + 0.0) / 10000 * 100.0 << "%\n";
system("pause");
return 0;
}

(3)遇到问题

①按如下方式设置参数时,编译报错,提示不存在Params这个类。这是因为OpenCV 3.0正式版已移除ml::SVM::Params结构体,改为通过SVM对象的setter方法设置参数:

ml::SVM::Params params;params.svmType = ml::SVM::C_SVC;params.kernelType = ml::SVM::POLY;params.gamma = 3;

解决方案,采用如下方式进行参数赋值:

svm->setType(SVM::C_SVC);
svm->setKernel(SVM::RBF);
svm->setGamma(0.01);
svm->setC(10.0);
svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));

②将训练好的SVM保存到文件后,再重新加载(load)使用时,得到的模型为空

//保存训练器
svm->save("mnist_dataset/mnist_svm.xml");
cout << "save as /mnist_dataset/mnist_svm.xml" << endl;
//下载分类器
cout << "开始导入SVM文件...\n";
Ptr<SVM> svm1 = StatModel::load<SVM>("mnist_dataset/mnist_svm.xml");
cout << "成功导入SVM文件...\n";

解决方案:不保存、也不重新加载,直接使用训练好的svm进行预测。这只是绕过了问题,并没有找到原因;可以先确认mnist_dataset目录是否存在、save是否确实写出了xml文件:

res=svm->predict(sample);



1 0
原创粉丝点击