机器学习实战-KNN算法
来源:互联网 发布:淘宝旺旺名怎么设置 编辑:程序博客网 时间:2024/06/06 07:28
import numpyimport osfrom numpy import arrayfrom numpy import tileimport operatorimport matplotlib.pyplot as plt#数据例子def createDataSet(): group=array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]]) labels=['A','A','B','B'] #标签与点一一对应 return group,labels'''******************************主要分类函数************************************************************'''#取距离最近的k各点, 返回 k个点中频率最多的类别作为分类def classify0(point,dataArray,labels,k):#(测试[,,...] 比较集array 标签集 OneDimension=dataArray.shape[0] tmpArray=(tile(point,(OneDimension,1))-dataArray)**2 #point平铺成二维矩与其计算各点距离 sqrtArray=tmpArray.sum(1) sortedArrayIndex=sqrtArray.argsort()#按索引点排序 -列表 #print(sortedArrayIndex) classCount={} #空字典 for i in range(k): lab=labels[sortedArrayIndex[i]] #取相应索引点的标签 classCount[lab]=classCount.get(lab,0)+1 #字典中有该key则取其映射值(这里为int),否则返回0 sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True) #指明关键字 return sortedClassCount[0][0]#测试'''group0,labels0=createDataSet()print(group0,labels0)print(classify0([0,0],group0,labels0,3))''''''*********************************************************************''''''***************************约会配对分类*******************************''''''*********************************************************************'''#获取文件数据 返回数据Array 标签listdef file2matrix(filename): file=open(filename) fileList=file.readlines()#返回全部行 ,行后有\n---列表 returnMat=numpy.zeros((len(fileList),3)) index=0 labels=[] for st in fileList: st=st.strip()#移除字符串头尾指定的字符(默认为空字符) 这里移除\n strList=st.split('\t')#str.split(sep=None, maxsplit=-1 无限制) returnMat[index,:]=strList[0:3] labels.append(int(strList[-1])) index+=1 return returnMat,labels#数据归一化def Normalize(dataMat): #Array min_value=dataMat.min(0) max_value=dataMat.max(0) range_value=max_value-min_value normMat=dataMat-tile(min_value,(dataMat.shape[0],1)) normMat=normMat/tile(range_value,(dataMat.shape[0],1)) return normMat,range_value,min_value#测试KNN错误率def datingClassTest(): datingData,datingLabels=file2matrix('datingTestSet2.txt') datingData,datingRange,datingMinValue=Normalize(datingData) testnum=int(datingData.shape[0]/10) #100 error_count=0; for i in range(testnum): label=classify0(datingData[i],datingData[testnum:datingData.shape[0]],datingLabels[testnum:datingData.shape[0]],3) if label!=datingLabels[i]: error_count+=1; print('错误率:%f'%(error_count/float(testnum))) #测试'''datingDataMat,datingLabels=file2matrix('datingTestSet2.txt')datingDataMat,datingDataRange,datingDataMin=Normalize(datingDataMat)fg=plt.figure()subfg1=fg.add_subplot(111)subfg1.scatter(datingDataMat[:,0],datingDataMat[:,1],15.0*array(datingLabels),15.0*array(datingLabels))#subfg1.scatter(datingDataMat[:,0],datingDataMat[:,1],15.0*array(datingLabels),1*array(tile([1],(array(datingLabels).shape[0],1))))plt.xlabel('玩视频耗时百分比')plt.ylabel('周消耗冰激凌公升数')plt.show()datingClassTest()''''''*********************************************************************'''def classifyPerson(): datingData,datingLabels=file2matrix('datingTestSet2.txt') datingData,datingRange,datingMinValue=Normalize(datingData) resultClass=['不喜欢','一般','有魅力'] miles=float(input('每年飞行里程数:')) game=float(input('玩游戏小号百分比:')) ice=float(input('每周冰淇淋公升:')) data=array(([miles,game,ice]-datingMinValue)/datingRange) label=classify0(data,datingData,datingLabels,3) print('类型是:',resultClass[label-1]) #数据的分类标签1,2,3 #测试 '''classifyPerson()''''''*********************************************************************''''''*****************************手写识别*********************************''''''*********************************************************************'''def img2vector(filename): file=open(filename) returnVec=numpy.zeros((1,1024)) for i in range(32): fileString=file.readline() for j in range(32): returnVec[0,i*32+j]=fileString[j] file.close() return returnVecdef handWriteClassTest(): trainList=os.listdir('trainingDigits') DT=len(trainList) trainArray=numpy.zeros((DT,1024)) labels=[] for i in range(DT): filename=trainList[i] labels.append(int(filename[0])) trainArray[i,:]=img2vector('trainingDigits/%s'%filename) testList=os.listdir('trainingDigits') DS=len(testList) error_count=0; for j in range(DS): filename=testList[j] label=int(filename[0]) testArray=img2vector('trainingDigits/%s'%filename) testLabel=classify0(testArray,trainArray,labels,3) if label!=testLabel: error_count+=1 error_rate=error_count/DS print('错误率:%f'%error_rate) #测试 '''handWriteClassTest()'''
阅读全文
1 0
- 机器学习实战-KNN算法
- 机器学习实战 KNN算法
- 《机器学习实战》-- KNN算法
- 机器学习实战 kNN算法
- 机器学习实战-KNN算法
- 机器学习实战--KNN算法
- 机器学习实战-KNN 算法
- 机器学习实战:KNN算法
- 机器学习实战-KNN算法
- 机器学习实战学习笔记-KNN算法
- 机器学习实战之KNN算法详解
- 机器学习实战(第一章)---KNN算法
- 机器学习实战之KNN算法
- 机器学习实战之KNN算法
- 机器学习实战--KNN 算法 笔记
- Python机器学习实战kNN分类算法
- 机器学习实战——kNN算法
- 【机器学习实战】-01KNN近邻算法
- 无限轮播
- 文章标题
- Anaconda (python)安装
- Mecanim动画系统学习(三)
- 典型数据库架构设计与实践 | 架构师之路
- 机器学习实战-KNN算法
- Java 开发中如何正确踩坑
- 【RabbitMQ】work模式
- org.springframework.beans.factory.BeanCreationException: Error creating bean with name 'sessionFacto
- 用 OpenCV 和OpenNI 2输出kinect 的深度、彩色图
- tensorflow实现迁移学习
- multigpu tensorflow
- java字符串分解 StringTokenizer用法
- window openssl 安装与使用