【机器学习算法源码阅读】之KNN算法
来源:互联网 发布:淘宝营销活动有哪些 编辑:程序博客网 时间:2024/06/03 18:10
前言:之前学过统计学习这门课,基本上是了解过主流的机器学习算法。但是一直没有自己从程序的角度去深入理解它们。现在准备阅读相关算法的实现源码来进一步理解这些算法。
参考资料:python《机器学习实战》
C++ Shark开源库源码
一.KNN算法原理
KNN算法可以视为是最简单的分类算法。它是一种Lazy learning,并不需要训练出来实际的数学模型,甚至也可以认为这种算法不需要训练的过程。假设我们的训练集里面有x1,x2,….,xn一共n个m维的训练样本,每个样本都有对应的标签y1,y2…,yn。现在给定测试向量t,t也是m维的向量,我们需要做的就是判断t的类别。
KNN算法首先定义一种距离度量标准来度量t和xi的远近程度,最简单的度量标准就是欧氏距离。接下来,需要找出训练集中与t距离最近的k个训练样本,这k个样本各自所属的类别也是已知的。最后,我们选取这k个样本中所属类别最多的类别作为t的类别。KNN算法基于非常朴素的事实:如果两个样本非常相似,那么它们所属的类别也应该基本相同。
KNN算法的优点:简单,精度高
KNN算法的缺点:计算复杂度高,空间复杂度高,当训练样本数目很大的时候难以实现。
二.KNN算法的C++实现(shark库源码分析)
测试源码:
#include <Rng/GlobalRng.h>#include <ReClaM/ArtificialDistributions.h>#include <ReClaM/Dataset.h>#include <ReClaM/KernelNearestNeighbor.h>#include <ReClaM/ClassificationError.h>#include <stdio.h>#include <iostream>using namespace std;int main(){ Rng::seed(10);//初始化随机数种子 double gamma = 0.5; RBFKernel k(gamma); //定义RBF核,使用exp(-parameter(0) * dist2)来归一化距离 cout << endl; cout << "*** kernel nearest neighbor classifier ***" << endl; cout << endl; // create the xor problem with uniformly distributed examples unsigned int n = 3; cout << "Generating 100 training and 10000 test examples ..." << flush; Chessboard chess(2, 2); Dataset dataset; dataset.CreateFromSource(chess, 100, 10000); const Array<double>& x = dataset.getTrainingData(); const Array<double>& y = dataset.getTrainingTarget(); cout << " done." << endl; // create the kernel mean classifier cout << "Creating the 3-nearest-neighbor classifier ..." << flush; KernelNearestNeighbor knn(x, y, &k, n); cout << " done." << endl; // estimate the accuracy on the test set cout << "Testing ..." << flush; ClassificationError ce; double acc = 1.0 - ce.error(knn, dataset.getTestData(), dataset.getTestTarget());//执行实际的分类 cout << " done." << endl; cout << "Estimated accuracy: " << 100.0 * acc << "%" << endl << endl; // lines below are for self-testing this example, please ignore if (acc >= 0.92) exit(EXIT_SUCCESS); else exit(EXIT_FAILURE);}
寻找k个近邻的核心算法程序:
double KernelNearestNeighbor::classify(Array<double> pattern){ int i, j, m, u, c, l = training_input.dim(0); double dist2, best; double norm2 = kernel->eval(pattern, pattern); std::vector<int> used; // sorted list of neighbors for (i = 0; i < numberOfNeighbors; i++)//for循环,每次寻找一个最近邻 { // find the nearest neighbor not already in the list best = 1e100; m = 0; for (j = 0; j < i; j++) { u = used[j]; for (; m < u; m++) { dist2 = diag(m) + norm2 - 2.0 * kernel->eval(training_input[m], pattern); if (dist2 < best) { best = dist2; c = m; } } m++; } for (; m < l; m++) { dist2 = diag(m) + norm2 - 2.0 * kernel->eval(training_input[m], pattern); if (dist2 < best) { best = dist2; c = m; } } // insert the nearest neighbor into the sorted list for (j = 0; j < i; j++) if (used[j] >= c) break; if (j == i) used.push_back(c); else used.insert(used.begin() + j, c); } double mean = 0.0; for (i = 0; i < numberOfNeighbors; i++) mean += training_target(used[i], 0); return (mean > 0.0) ? 1.0 : -1.0;}
shark库实现的寻找K近邻算法的复杂度是O(k*n),每次需要遍历整个训练集,第一次寻找距离最小的样本,第二次寻找距离第二小的样本,直至寻找出距离第K小的样本。想起之前写的博客《【July程序员编程艺术】之最小的k个数问题》,可以采用最大堆的数据结构,那样可以把复杂度优化到O(n*log k)。
三.KNN算法的python实现(机器学习实战源码)
def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0] //得到数据集大小 diffMat = tile(inX, (dataSetSize,1)) - dataSet//将输入向量扩展成矩阵,然后减去训练矩阵,得到差值矩阵 sqDiffMat = diffMat**2//差值矩阵每个元素求平方 sqDistances = sqDiffMat.sum(axis=1)//每一行求和,即求解每个训练样本与测试向量的距离 distances = sqDistances**0.5 sortedDistIndicies = distances.argsort() //排序 classCount={} for i in range(k)://进行投票 voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)//对投票结果排序 return sortedClassCount[0][0]
python实现使用了numpy库,因此可以直接使用矩阵进行运算,相比于C++的实现要方便很多。python实现中还对训练样本的特征向量值做了一下归一化,这也是很有意义的操作:
def autoNorm(dataSet): minVals = dataSet.min(0) maxVals = dataSet.max(0) ranges = maxVals - minVals normDataSet = zeros(shape(dataSet)) m = dataSet.shape[0] normDataSet = dataSet - tile(minVals, (m,1)) normDataSet = normDataSet/tile(ranges, (m,1)) #element wise divide return normDataSet, ranges, minVals
- 【机器学习算法源码阅读】之KNN算法
- 机器学习之kNN算法
- 机器学习之KNN 算法
- 机器学习之KNN算法
- 机器学习之KNN算法
- 机器学习之kNN算法
- 机器学习之knn算法
- 机器学习之KNN 算法
- 机器学习之KNN算法
- 机器学习算法之KNN
- 机器学习之KNN算法
- 机器学习之KNN算法
- 机器学习算法之KNN算法
- 《机器学习》 KNN算法
- 机器学习:KNN算法
- 机器学习-KNN 算法
- 【机器学习】kNN算法
- 机器学习 -- kNN算法
- 随笔--2015.10.22
- cocoapod安装和使用
- AllJoyn SYSTEM ARCHITECTURE
- oracle利用分隔符,组合查询想表达的任何话
- linux系统下的各种串口调试工具
- 【机器学习算法源码阅读】之KNN算法
- HDU1042 N!(java)
- C++ MFC中的CMenu---动态添加菜单/菜单项
- JustifyTextView 解决TextView中英文混排排版问题,android文字排版不齐,
- ListView分析
- jsoup UnsupportedMimeTypeExceptio
- 九度考研真题 浙大 2010-2浙大1006:ZOJ问题
- Chrome 正在等待可用的套字节 问题
- java.lang.reflect.InvocationTargetException