机器学习(3)——KNN算法及手写数字的识别(一)

来源:互联网 发布:淘宝直通车的规则 编辑:程序博客网 时间:2024/05/22 11:28

机器学习——KNN算法及手写数字的识别(一)

邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。

搬出一张最常见的图,来直观的展示什么是KNN算法:

  kNN比较好理解,其一般过程如下:

    对未知类别属性的数据集中的每个点依次执行以下操作:

1、计算已知类别数据集中的点与当前点之前的距离

2、按照距离递增次序排序

3、选取与当前点距离最小的k个点

4、确定前k个点所在类别的出现概率

5、返回前k个点出现频率最高的类别作为当前点的预测分类

 

下面我们给出一个利用kNN算法实现手写数字识别的例子,这个例子在Machine Learning in Action一书中是用Python描述的,这里为了加深理解我用C++进行重写。


1、读取指定文件夹下的所有训练样本文件名:

// path: 路径, files 文件名, format 文件格式 [3/7/2015 pan]void GetAllFormatFiles(string path, vector<string>& files, string format){//文件句柄    long   hFile = 0;//文件信息    struct _finddata_t fileinfo;string p;if ((hFile = _findfirst(p.assign(path).append("\\*" + format).c_str(), &fileinfo)) != -1){do{if ((fileinfo.attrib &  _A_SUBDIR)){if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0){//files.push_back(p.assign(path).append("\\").append(fileinfo.name) );  GetAllFormatFiles(p.assign(path).append("\\").append(fileinfo.name), files, format);}}else{files.push_back(p.assign(path).append("\\").append(fileinfo.name));}} while (_findnext(hFile, &fileinfo) == 0);_findclose(hFile);}}


2、返回原始队列索引的排序算法

vector<int> insertSort(vector<int> nums){vector<int> sortedIndx;for (int i = 0; i<nums.size(); i++){sortedIndx.push_back(i);}for (int j = 1; j < nums.size(); j++){int key = nums[j];int indx = sortedIndx[j];int i = j - 1;while (i>=0&&nums[i]>key){nums[i + 1] = nums[i];sortedIndx[i + 1] = sortedIndx[i];i--;}nums[i + 1] = key;sortedIndx[i + 1] = indx;}return sortedIndx;}


3、读取文本文件内容,存入一维向量

vector<int> imageToVector(string fileName){vector<int> returnVector;fstream infile;infile.open(fileName, ios::in);while (!infile.eof()){char buffer[256];infile.getline(buffer, 256);for (int i = 0; i < 32; i++){returnVector.push_back(buffer[i] - 48);// 字符转int [3/4/2015 pan]}}return returnVector;}


4KNN算法的实现

// inX 待分类向量,dataSet 训练数据集,labels 训练数据的类别(0,1,2,3,4,5,6,7,8,9),k  [3/4/2015 pan]int classify(vector<int> inX, vector<vector<int>> dataSet, vector<int> labels, int k){int dataSetSize = dataSet.size();int labelsum = 0;vector<int> distances;for (int i = 0; i < dataSetSize; i++){int sum = 0;for (int j = 0; j < inX.size(); j++){int tmp = inX[j] - dataSet[i][j];tmp *= tmp;sum += tmp;}sum=sqrt(sum);distances.push_back(sum);}vector<int> sortedDistIndix;sortedDistIndix = insertSort(distances);for (int i = 0; i < k; i++){labelsum += labels[sortedDistIndix[i]];}return labelsum / k + 0.5;}


5、最终测试

int _tmain(int argc, _TCHAR* argv[]){string path = "G:\\A编程练习\\机器学习&Python\\handWritingTest\\digits\\trainingDigits";string format = "txt";vector<string> files;GetAllFormatFiles(path, files, format);fstream infile;vector<vector<int>> traingMat;vector<int> labels;for (int i = 0; i < files.size(); i++){//cout << path.size();string str;str.assign(files[i], path.size()+1,1);const char *c = str.c_str();labels.push_back(*c - 48);traingMat.push_back(imageToVector(files[i]));}string testFileName = "G:\\A编程练习\\机器学习&Python\\handWritingTest\\digits\\testDigits\\8_73.txt";vector<int> inX = imageToVector(testFileName);int result = classify(inX, traingMat, labels, 8);return 0;}


手写数字样本:

样本下载链接:

http://download.csdn.net/detail/panan160/8480055


0 0
原创粉丝点击