机器学习算法(1) KNN
来源:互联网 发布:网络上找不到打印机 编辑:程序博客网 时间:2024/06/05 09:34
这一段时间在学习《Machine Learning in Action》这本书。本篇作为学习总结,简单介绍了KNN算法的实现原理,并且实现‘约会对象的预测’和‘手写数字识别’两个小例子。
例子来自《Machine Learning in Action》 Peter Harrington
算法概述
KNN算法是利用测量不同特征值之间的距离实现分类的算法。存在一个样本数据集合(训练样本)。每个数据都存在标签,即我们知道数据所属于的类型。当输入新的数据后(没有标签),将新数据的每个特征和样本集数据进行特征对比,然后提取样本中特征最相似的数据的分类标签,作为新数据的分类标签。
预测约会对象
问题概述
收集到三种类型的交往对象(不喜欢/一般/喜欢)及其数据,每个对象对应有三种特征值:
* 每年飞行里程数
* 玩游戏时间百分比
* 每周消费冰激凌质量
同过以上数据预测一个新的数据样本属于何种分类。
数据集
datingTestSet2.txt
42666 13.276369 0.543880 367497 8.631577 0.749278 135483 12.273169 1.508053 350242 3.723498 0.831917 163275 8.385879 1.669485 1... ...
读取数据
从数据文本中读取训练数据,把特征值和标签分别提取,存在returnMat
,classLabelVector
中。
def file2matrix(filename): fr = open(filename) numberOfLines = len(fr.readlines()) # 获得文件的行数 returnMat = zeros((numberOfLines,3)) classLabelVector = [] fr = open(filename) index = 0 for line in fr.readlines(): line = line.strip() # 去掉空格末尾 listFromLine = line.split('\t') returnMat[index,:] = listFromLine[0:3] classLabelVector.append(int(listFromLine[-1])) index += 1 return returnMat,classLabelVector
归一化数据
三个特征值的数量级不同意,之间计算距离会到时每个特征值的权重不同,所以需要做归一化处理
def autoNorm(dataSet): minVals = dataSet.min(0) maxVals = dataSet.max(0) ranges = maxVals - minVals normDataSet = zeros(shape(dataSet)) m = dataSet.shape[0] normDataSet = dataSet - tile(minVals, (m,1)) normDataSet = normDataSet/tile(ranges, (m,1)) return normDataSet, ranges, minVals
KNN实现
通过训练数据对新数据点进行预测,计算新店和每一个训练数据点的距离,取最近的k个点。在k个点中统计出现数量最多的类型作为新数据点的类型返回。
def classify0(inX, dataSet, labels, k): dataSetSize = dataSet.shape[0] # 计算行数 diffMat = tile(inX, (dataSetSize,1)) - dataSet sqDiffMat = diffMat**2 sqDistances = sqDiffMat.sum(axis=1) distances = sqDistances**0.5 sortedDistIndicies = distances.argsort() classCount={} for i in range(k): voteIlabel = labels[sortedDistIndicies[i]] classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) # 第二列 从大到小排序 return sortedClassCount[0][0]
测试
对已知数据进行KNN分类,对比分类结果和真是类型,计算错误率。
def datingClassTest(): hoRatio = 0.1 datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') normMat, ranges, minVals = autoNorm(datingDataMat) m = normMat.shape[0] numTestVecs = int(m*hoRatio) errorCount = 0.0 for i in range(numTestVecs): classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3) print ("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])) if (classifierResult != datingLabels[i]): errorCount += 1.0 print ("the total error rate is: %f" % (errorCount/float(numTestVecs))) print (errorCount)
结果
the total error rate is: 0.0500005.0
应用
输入一个含有以上三个特征的数据点,进行预测
def classifyPerson(): resultList = ['not at all','in small','in large doses'] inArr=([72011,4.932976,0.632026]) datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') normMat, ranges, minVals = autoNorm(datingDataMat) classifierResult = classify0((inArr-minVals)/ranges,normMat,datingLabels,3) print("You will probably like this person: ",resultList[classifierResult-1])
结果
You will probably like this person: not at all
手写体数字识别
问题概述
采集手写数字的图像,进行灰度处理,得到0-1
表示的数据集合。通过已有的训练数据,对新的数据进行分类,即识别数字。
数据集合
3_25.txt 手写体数字3的数据表示
0000000000000000001010000000000000000000000000011111111000000000000000000011111111111110000000000000000011111111111111110000000000000011111111111111111110000000000001111111111111111111100000000000011111110000000111111000000000000011110000000000111110000000000000000000000000001111100000000000000000000000000011111000000000000000000000000000011111000000000000000000000000011111110000000000000000000000001111111000000000000000000000001111111100000000000000000000000011111110000000000000000000000111111111100000000000000000000111111111111100000000000000000011111111111111000000000000000000111111111111111100000000000000001111111111111111000000000000000001100000001111110000000000000000000000000011111100000000000000000000000000001111000000000000000000000000000011111000000000000000000000000001111110000000000000000000000000011111100000000000000000000000011111111000000000000000010011111111111000000000000000001111111111111110000000000000000011111111111110000000000000000000111111111111100000000000000000000111111000000000000000
读取数据
每一个txt文档储存有一个手写数字图像的0-1
表示形式。将32*32的形式转换成1*1024的向量形式。
def img2vector(filename): returnVect = zeros((1,1024)) fr = open(filename) for i in range(32): lineStr = fr.readline() for j in range(32): returnVect[0,32*i+j] = int(lineStr[j]) return returnVect
测试
将转换好的数据传入之前使用的那个KNN分类函数(classify0()
)中。因为手写体的真是数据结果保存在文件名中,如3_25.txt
表示手写数字3的第25个数据。所以代码中涉及到了字符串提取和分割的内容。
def handwritingClassTest(): hwLabels = [] trainingFileList = listdir('trainingDigits') m = len(trainingFileList) trainingMat = zeros((m,1024)) for i in range(m): fileNameStr = trainingFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0]) hwLabels.append(classNumStr) trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr) testFileList = listdir('testDigits') errorCount = 0.0 mTest = len(testFileList) for i in range(mTest): fileNameStr = testFileList[i] fileStr = fileNameStr.split('.')[0] classNumStr = int(fileStr.split('_')[0]) vectorUnderTest = img2vector('testDigits/%s' % fileNameStr) classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) print ("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)) if (classifierResult != classNumStr): errorCount += 1.0 print ("\nthe total number of errors is: %d" % errorCount) print ("\nthe total error rate is: %f" % (errorCount/float(mTest)))
结果
the total number of errors is: 10the total error rate is: 0.010571
可以看到KNN算法对手写体数字的识别率也是很高的。
以上完整代码见GitHub。
- 机器学习算法(1) KNN
- 《机器学习》 KNN算法
- 机器学习:KNN算法
- 机器学习-KNN 算法
- 【机器学习】kNN算法
- 机器学习 -- kNN算法
- 机器学习---kNN算法
- 机器学习--kNN算法
- 机器学习--KNN算法
- 机器学习算法-kNN
- 机器学习knn算法
- 机器学习经典算法1--knn
- 机器学习算法(1)-KNN
- 1、机器学习算法KNN -- Java代码
- 机器学习(1)-KNN算法理解
- 机器学习-kNN 算法(1)
- 机器学习1-KNN算法设计part1
- 机器学习算法---kNN算法
- 【1233】推箱子(右)
- android如何让布局保持位于键盘上方(一直在键盘上面)
- GemFile 学习——环境搭建
- mybatis中的#{}和${}区别
- Android中注解的实际运用
- 机器学习算法(1) KNN
- ODBC Excel驱动程序登陆失败
- [ESSAY]what are you optimistic/pessimistic about?
- SystemUI之功能介绍和UI布局介绍
- C++ primer 学习笔记(一个学过谭老师的C++书籍, 并在一年间间断使用过C++的码农)
- BZOJ 3170: [Tjoi 2013]松鼠聚会 切比雪夫距离
- ssl与ssh协议的一些笔记
- Rythm.js 使用教程详解
- windows10安装多个版本的jdk