机器学习实战——K-近邻算法【1:从文本中解析数据并可视化】

来源:互联网 发布:淘宝卖家手写卡片内容 编辑:程序博客网 时间:2024/05/21 04:23

机器学习实战学习笔记系列

机器学习实战——K-近邻算法【1:从文本中解析数据并可视化】

《机器学习实战》书中代码有一个小错误,直接运行时会报错。
ValueError: invalid literal for int() with base 10: ‘largeDoses’
因为根据源码不能直接读取字符型,应该将测试数据中最后一列字符型改为int型,这里作者已经修改了,并提供了修改后的数据datingTestSet2.txt文件,所以调用测试时
datingDataMat, datingLabels = kNN.file2matrix(‘datingTestSet.txt’)应该改为
datingDataMat, datingLabels = kNN.file2matrix(‘datingTestSet2.txt’).

源代码:

from numpy import *import operatorfrom os import listdirdef classify0(inX, dataSet, labels, k):    dataSetSize = dataSet.shape[0]    diffMat = tile(inX, (dataSetSize,1)) - dataSet    sqDiffMat = diffMat**2    sqDistances = sqDiffMat.sum(axis=1)    distances = sqDistances**0.5    sortedDistIndicies = distances.argsort()         classCount={}              for i in range(k):        voteIlabel = labels[sortedDistIndicies[i]]        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)    return sortedClassCount[0][0]def createDataSet():    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])    labels = ['A','A','B','B']    return group, labelsdef file2matrix(filename):    fr = open(filename)    numberOfLines = len(fr.readlines())         #get the number of lines in the file    returnMat = zeros((numberOfLines,3))        #prepare matrix to return    classLabelVector = []                       #prepare labels return       fr = open(filename)    index = 0    for line in fr.readlines():        line = line.strip()        listFromLine = line.split('\t')        returnMat[index,:] = listFromLine[0:3]        classLabelVector.append(int(listFromLine[-1]))        index += 1    return returnMat,classLabelVector

可视化数据

>>> import matplotlib>>> import matplotlib.pyplot as plt>>> fig = plt.figure()>>> ax = fig.add_subplot(111)>>> ax.scatter(datingDataMat[:,1], datingDataMat[:,2])<matplotlib.collections.PathCollection object at 0x00000264FFDE3668>>>> plt.show()

这里写图片描述

为了标记不同样本的分了,我们可以调用Scatter()模块,通过不同的颜色标记出来。由于Scatter函数中调用了array(),所以我们要首先导入numpy模块

from numpy import *>>>ax.scatter(datingDataMat[:,1], datingDataMat[:,2], 15.0*array(datingLabels), 15.0*array(datingLabels))<matplotlib.collections.PathCollection object at 0x00000264FFE16BA8>>>> plt.show()

这里写图片描述

可以看出图中有青色,紫色和黄色三种颜色分别代表了三种类别

阅读全文
1 0