[机器学习实战] k-近邻算法

来源:互联网 发布:软件架构师书籍 编辑:程序博客网 时间:2024/06/02 06:00

原理

存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k各最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。

k-近邻算法的优缺点

有点:精度高、对异常值不敏感、无数据输入假定
缺点:计算复杂度高、空间复杂度高
使用数据范围:数值型和标称型
通常k是不大于20的整数

k-近邻算法的一般流程

(1)收集数据:可以使用任何方法
(2)准备数据:距离计算所需要的数值,最好是结构化的数据格式
(3)分析数据:可以使用任何方法
(4)训练算法:此步骤不适用于k-近邻算法
(5)测试算法:计算错误率
(6)使用算法:首选需要输入样本数据和结构化的输出结果,然后运行k-近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理

k值的选择

k值越小,整体模型变得越复杂,预测结果对近邻的实例点敏感,容易发生过拟合。k值越大,模型变得简单,可以减小学习的估计误差,但学习的近似误差会增大。在应用中,k值一般取一个比较小的数值,通常采用交叉验证法来选取最优的k值。

常用函数

(1)对arr重复x行y列构成新的arr
      tile(arr, (x, y))
(2)对arr重复x列构成新的arr
     tile(arr, x)
(3)对矩阵纵向上求和
     mat.sum(axis=0)
(4)对矩阵横向求和
    mat.sum(axis=1)  
(5)对dict排序,选择第1列作为key(下标从0开始)
    sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
(6)对array进行排序,返回排序后的下标数组
    array.argsort()
(7)重新加载模块,模块有更新的情况下
    reload(module)
(8)对每一列,取最小值,形成新的array
    array.min(0)
(9)显示path目录下的所有文件
    listdir(path)
注意:NumPy库提供的数组操作并不支持Python自带的数组类型,因此在编写代码时要注意不要使用错误的数组类型

样例代码

DataUtil.py
1. 用于随机生成数据集
2. 用于随机生成测试向量
3. 用于归一化
4. 用于按照比例随机切分训练集和测试集
# -*- coding: utf-8 -*-from numpy import *class DataUtil:    def __init__(self):        pass    def randomDataSet(self, row, column, classes):        '''rand data set'''        if row <= 0 or column <= 0 or classes <= 0:            return None, None        dataSet = random.rand(row, column)        dataLabel = [random.randint(classes) for i in range(row)]        return dataSet, dataLabel    def file2DataSet(self, filePath):        '''read data set from file'''        f = open(filePath)        lines = f.readlines()        dataSet = None        dataLabel = []        i = 0        for line in lines:            items = line.strip().split('\t')            if dataSet is None:                dataSet = zeros((len(lines), len(items)-1))            dataSet[i,:] = items[0:-1]            dataLabel.append(items[-1])            i += 1        return dataSet, dataLabel    def randomX(self, column):        '''rand a vector'''        return random.rand(1, column)[0]    def norm(self, dataSet):        '''normalize'''        minVals = dataSet.min(0)        maxVals = dataSet.max(0)        ranges = maxVals - minVals        m = dataSet.shape[0]        return (dataSet - tile(minVals, (m, 1)))/tile(ranges, (m, 1))    def spitData(self, dataSet, dataLabel, ratio):        '''split data with ratio'''        totalSize = dataSet.shape[0]        trainingSize = int(ratio*totalSize)        testingSize = totalSize - trainingSize        # random data        trainingSet = zeros((trainingSize, dataSet.shape[1]))        trainingLabel = []        testingSet = zeros((testingSize, dataSet.shape[1]))        testingLabel = []        trainingIndex = 0        testingIndex = 0        for i in range(totalSize):            r = random.randint(1, totalSize)            if (r <= trainingSize and trainingIndex < trainingSize) or testingIndex >= testingSize:                trainingSet[trainingIndex,:] = dataSet[i,:]                trainingLabel.append(dataLabel[i])                trainingIndex += 1            else:                testingSet[testingIndex,:] = dataSet[i,:]                testingLabel.append(dataLabel[i])                testingIndex += 1        return trainingSet, trainingLabel, testingSet, testingLabel

kNN.py
1. k-近邻算法的实现
# -*- coding: utf-8 -*-import operatorfrom numpy import *class kNN:    def __init__(self):        pass    def classify(self, dataSet, dataLabel, vectorX, k):        # data validate        (row, column) = dataSet.shape        if row <= 0 or column <= 0 or row != len(dataLabel) or column != len(vectorX) or k <= 0:            return None, None        # calculate distance and sort        dataX = tile(vectorX, (row, 1))        distance = (((dataX - dataSet)**2).sum(axis=1))**0.5        sortedIndice = distance.argsort()        # classify        classCount = {}        for i in range(k):            if i >= row:                break            label = dataLabel[sortedIndice[i]]            classCount[label] = classCount.get(label, 0) + 1        # sort and return result        return distance, sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)[0][0]

Test4knn.py
1. 用于测试k-近邻算法
# -*- coding: utf-8 -*-from com.fighting.util.DataUtil import *from com.fighting.knn.kNN import *import matplotlib.pyplot as pltdef knn():    '''test knn'''    row, column, classes, k = (100, 5, 3, 10)    # load data set    dataUtil = DataUtil()    dataSet, dataLabel = dataUtil.randomDataSet(row, column, classes)    print 'dataSet: '    print dataSet    print 'dataLabel: '    print dataLabel    # normalize    dataSet = dataUtil.norm(dataSet)    print 'norm-dataSet:'    print dataSet    # plot the data    fig = plt.figure()    ax = fig.add_subplot(111)    ax.scatter(dataSet[:,0], dataSet[:,1], 15*array(dataLabel), 15*array(dataLabel))    plt.show()    # random vector X    vectorX = dataUtil.randomX(dataSet.shape[1])    print 'vectorX: '    print vectorX    # classify    knn = kNN()    distance, clz = knn.classify(dataSet, dataLabel, vectorX, k)    print 'distance: '    print distance    print 'clz=%d' % clzdef dating():    '''test dating classify'''    # load data set    dataUtil = DataUtil()    dataSet, dataLabel = dataUtil.file2DataSet('../../../datasets/knn/datingTestSet.txt')    dataSet = dataUtil.norm(dataSet)    # split training set and testing set    ratio = 0.8    trainingSet, trainingLabel, testingSet, testingLabel = dataUtil.spitData(dataSet, dataLabel, ratio)    testingSize = testingSet.shape[0]    # training and testing    knn = kNN()    for k in range(1, 11):        error = 0        for i in range(testingSize):            distance, clz = knn.classify(trainingSet, trainingLabel, testingSet[i,], k)            if clz != testingLabel[i]:                error += 1        print '%d, %.2f' % (k, error*1.0/testingSize)def f2d():    '''test file2dataset'''    dataUtil = DataUtil()    dataSet, dataLabel = dataUtil.file2DataSet('../../../datasets/knn/datingTestSet.txt')    print 'dataSet:'    print dataSet    print 'dataLabel:'    print dataLabelif __name__ == '__main__':    knn()    #dating()    #f2d()



原创粉丝点击