Machine Learning in action:k-Nearest Neighbor
来源:互联网 发布:java面试宝典2017版 编辑:程序博客网 时间:2024/05/19 23:04
K最近邻(k-Nearest Neighbor,KNN)思想
k最近邻(k-Nearest Neighbor,KNN)算法的核心思想:如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别[1] 。
其决策过程如图1所示。可以简单理解为:当测试样本要归类时,以该样本为中心绘制包含k个样本的圆,这样,数一数圆内已知样本数量最多的种类即为该测试样本的类别。
图1:kNN(ref:百度百科)
算法
- 准备数据,对数据进行预处理
- 选用合适的数据结构存储训练数据和测试元组
- 设定参数,如k
- 维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组。
- 随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列
- 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L 与优先级队列中的最大距离Lmax
- 进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。
- 遍历完毕,计算优先级队列中k 个元组的多数类,并将其作为测试元组的类别。
- 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k 值。
示例1:使用k-近邻算法改进约会网站的配对效果
问题提出
某女使用约会网站找男神,发现男神有三种:不喜欢的(didntLike),魅力一般的(smallDoses),极具魅力的(largeDoses)。从哪些特征来区分这三种人呢,该女使用了以下特征
- 每年获得的飞行常客里程数
- 玩视频游戏所耗时间百分比
- 每周吃的冰激凌公升数
样本示例:
409208.3269760.953952largeDoses144887.1534691.673904smallDoses260521.4418710.805124didntLike7513613.1473940.428964didntLike383441.6697880.134296didntLike7299310.1417401.032955didntLike359486.8307921.213192largeDoses4266613.2763690.543880largeDoses674978.6315770.749278didntLike3548312.2731691.508053largeDoses502423.7234980.831917didntLike632758.3858791.669485didntLike55694.8754350.728658smallDoses510524.6800980.625224didntLike7737215.2995700.331351didntLike436731.8894610.191283didntLike613647.5167541.269164didntLike6967314.2391950.261333didntLike156690.0000001.250185smallDoses
样本点此下载
具体的算法描述
- 收集数据: 提供文本文件。
- 准备数据: 使用python解析文本文件。
- 分析数据:使用Matplotlib画二维扩散图。
- 训练算法:此步驟不适用于k近邻算法。
- 测试算法:使用海伦提供的部分数据作为测试样本。测试样本和非测试样本的区别在于:测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。
- 使用算法:产生简单的命令行程序,然后海伦可以输入一些特征数据以判断对方是否为自己喜欢的类型。
python kNN实现代码
######################################### # kNN: k Nearest Neighbors # Input: newInput: vector to compare to existing dataset (1xN) # dataSet: size m data set of known vectors (NxM) # labels: data set labels (1xM vector) # k: number of neighbors to use for comparison # Output: the most popular class label # Author: Machine Learning in Aciton# From :http://blog.csdn.net/zouxy09/article/details/16955347######################################### from numpy import *import operator# classify using kNN def kNNClassify(newInput, dataSet, labels, k): numSamples = dataSet.shape[0] # shape[0] stands for the num of row # step 1: calculate Euclidean distance # tile(A, reps): Construct an array by repeating A reps times # the following copy numSamples rows for dataSet diff = tile(newInput, (numSamples,1)) - dataSet # subtract element-wise squaredDiff = diff ** 2 # squared for the subtract squaredDist = sum(squaredDiff, axis = 1) # sum is performed by row distance = squaredDist ** 0.5 ## step 2: sort the distance # argsort() returns the indices that would sort an array in a ascending order sortedDistIndices = argsort(distance) # indice in ascending order classCount = {} # define a dictionary (can be append element) for i in xrange(k): # xrange is similar with range, but make generator ## step 3: choose the min k distance voteLabel = labels[sortedDistIndices[i]] ## step 4: count the times labels occur # when the key voteLabel is not in dictionary classCount, get() # will return 0 classCount[voteLabel] = classCount.get(voteLabel, 0) + 1 ## step 5: the max voted class will return maxCount = 0 for key, value in classCount.items(): if value > maxCount: maxCount = value maxIndex = key return maxIndex # for function loadDataFromFile(fileName), change str to int<strong>def str2unm(s): return {'didntLike':1, 'smallDoses':2, 'largeDoses':3}[s]</strong># Prepare: parsing data from a text file to make a metricsdef loadDataFromFile(fileName): fileOpen = open(fileName) arrayOfLines = fileOpen.readlines() # read all content once and make content to list numberOfLines = len(arrayOfLines) dataMetrics = zeros((numberOfLines,3)) # initialize the metrics N*3 classLabelVector = [] index = 0 for line in arrayOfLines: line = line.strip() # delete blank character listFromLine = line.split('\t') # separation using tab dataMetrics[index, :] = listFromLine[0:3] # put each line in data metrics classLabelVector.append(str2unm(listFromLine[-1])) # put the last row in Vector index = index + 1 return dataMetrics, classLabelVector#Prepare: normalizing numeric values def normalizeFeatures(dataSet): minValue = dataSet.min(0) #(0) means take the minimums from the columns maxValue = dataSet.max(0) valueRange = maxValue - minValue normDataSet = zeros(shape(dataSet)) m = dataSet.shape[0] normDataSet = dataSet - tile(minValue, (m,1)) normDataSet = normDataSet/tile(valueRange, (m,1)) return normDataSet, valueRange, minValue#Test: testing the classifier as a whole programdef datingClassfierTest(): hoRatio = 0.1 datingDataMat,datingLabels = loadDataFromFile('datingTestSet.txt') normMat, ranges, minVals = normalizeFeatures(datingDataMat) m = normMat.shape[0] numTestVecs = int(m*hoRatio) errorCount = 0.0 for i in range(numTestVecs): classifierResult = kNNClassify(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3) print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]) if (classifierResult != datingLabels[i]): errorCount += 1.0 print "the total error rate is: %f" % (errorCount/float(numTestVecs))
样本的可视化:
import kNNimport matplotlibimport matplotlib.pyplot as pltfrom numpy import * datingDataMatirx, datingLabel = kNN.loadDataFromFile("datingTestSet.txt")fig = plt.figure()ax = fig.add_subplot(111)ax.scatter(datingDataMatirx[:,1], datingDataMatirx[:,2], 15.0*array(datingLabel), 15.0*array(datingLabel))plt.show()
图2 样本集的分布
测试代码
import kNNimport matplotlibimport matplotlib.pyplot as pltfrom numpy import * kNN.datingClassfierTest()
测试结果:
the classifier came back with: 3, the real answer is: 3the classifier came back with: 2, the real answer is: 2the classifier came back with: 1, the real answer is: 1the classifier came back with: 1, the real answer is: 1the classifier came back with: 1, the real answer is: 1the classifier came back with: 1, the real answer is: 1the classifier came back with: 3, the real answer is: 3the classifier came back with: 3, the real answer is: 3the classifier came back with: 1, the real answer is: 1the classifier came back with: 3, the real answer is: 3the classifier came back with: 1, the real answer is: 1
the total error rate is: 0.050000
使用算法:构建完整的可用系统
根据单个样本特征预测其类型代码
def classifyPerson(): resultList = ['not at all', 'in samll doses', 'in large doses'] try: percentTats = float(raw_input(\ "percentage of time spent playing video games?")) ffMiles = float(raw_input("frequent flier miles earned per year?")) iceCream = float(raw_input("liters of ice cream consumed per year?")) datingDataMat,datingLabels = loadDataFromFile('datingTestSet.txt') normMat, ranges, minVals = normalizeFeatures(datingDataMat) inArr = array([ffMiles, percentTats, iceCream]) classifierResult = kNNClassify((inArr- minVals)/ranges,normMat,datingLabels,3) print "You will probably like this person: ",\ resultList[classifierResult - 1] except: print 'please input numbers'
测试示例:
>>> kNN.classifyPerson()percentage of time spent playing video games?10frequent flier miles earned per year?10000liters of ice cream consumed per year?0.5You will probably like this person: in small doses
示例2:手写识别系统
准备数据:将图像转换为测试向量
样本展示
数据下载
## a handwriting recognition system# read str image data from filedef img2vector(filename): returnVector = zeros((1,1024)) fileOpen = open(filename) for i in range(32): lineStr = fileOpen.readline() for j in range(32): returnVector[0,32*i + j] = int(lineStr[j]) return returnVector
数据测试
def handWritingClassTest(): hwLabels = [] trainingFileList = listdir('trainingDigits') m = len(trainingFileList) trainingMat = zeros((m,1024)) for i in range(m): fileNameStr = trainingFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0]) hwLabels.append(classNumStr) trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr) testFileList = listdir('testDigits') errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0]) vectorUnderTest = img2vector('testDigits/%s' % fileNameStr) classifierResult = kNNClassify(vectorUnderTest, trainingMat, hwLabels, 3) print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr) if (classifierResult != classNumStr): errorCount += 1.0 print "\nthe total number of errors is: %d" % errorCount print "\nthe total error rate is: %f" % (errorCount/float(mTest))
实验结果
调用:kNN.handWritingClassTest()
the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the classifier came back with: 9, the real answer is: 9the total number of errors is: 11the total error rate is: 0.011628
参考文献
- Harrington P.Machine learning in action[M]. Manning Publications Co., 2012.
- 百度百科,kNN
- http://blog.csdn.net/zouxy09/article/details/16955347
0 0
- Machine Learning in action:k-Nearest Neighbor
- Machine Learning In Action -- kNN (k Nearest Neighbors)
- machine learning in action 之二 —— k-Nearest Neighbors
- Machine Learning—k-nearest neighbor classification(k近邻分类)
- Supervised Learning 001: k-Nearest Neighbor
- Supervised Learning 002: k-Nearest Neighbor
- Supervised Learning 003: k-Nearest Neighbor
- machine learning in action
- Machine Learning in Action
- Machine Learning In Action
- Machine Learning In Action
- Machine Learning In Action
- Machine Learning In Action
- Machine Learning In Action
- Machine Learning In Action
- Machine Learning In Action
- Machine Learning In Action
- Machine Learning In Action
- poj 3977 Subset
- 【网络协议】HTTP协议笔记
- B站挂了,无题YY
- 一个关于socket在非阻塞模式下connect是否成功的例子
- POJ 2185 Milking Grid (二维KMP next数组)
- Machine Learning in action:k-Nearest Neighbor
- Python: socket,address already in use处理
- postgresql sql create table
- 设置Outlook不删除服务器邮件备份
- UVa230 - Borrowers
- Activity软键盘弹出布局调整
- 项目需求实例(JDBC通用查询)
- Android 学习笔记 线程操作 异步任务 AsyncTask
- HDU 1058 Humble Numbers(递推)