[机器学习实战] k-近邻算法
来源:互联网 发布:软件架构师书籍 编辑:程序博客网 时间:2024/06/02 06:00
原理
存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k各最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。
k-近邻算法的优缺点
有点:精度高、对异常值不敏感、无数据输入假定
缺点:计算复杂度高、空间复杂度高
使用数据范围:数值型和标称型
通常k是不大于20的整数
k-近邻算法的一般流程
(1)收集数据:可以使用任何方法
(2)准备数据:距离计算所需要的数值,最好是结构化的数据格式
(3)分析数据:可以使用任何方法
(4)训练算法:此步骤不适用于k-近邻算法
(5)测试算法:计算错误率
(6)使用算法:首选需要输入样本数据和结构化的输出结果,然后运行k-近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理
k值的选择
k值越小,整体模型变得越复杂,预测结果对近邻的实例点敏感,容易发生过拟合。k值越大,模型变得简单,可以减小学习的估计误差,但学习的近似误差会增大。在应用中,k值一般取一个比较小的数值,通常采用交叉验证法来选取最优的k值。
常用函数
(1)对arr重复x行y列构成新的arr
tile(arr, (x, y))
(2)对arr重复x列构成新的arr
tile(arr, x)
(3)对矩阵纵向上求和
mat.sum(axis=0)
(4)对矩阵横向求和
mat.sum(axis=1)
(5)对dict排序,选择第1列作为key(下标从0开始)
sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
(6)对array进行排序,返回排序后的下标数组
array.argsort()
(7)重新加载模块,模块有更新的情况下
reload(module)
(8)对每一列,取最小值,形成新的array
array.min(0)
(9)显示path目录下的所有文件
listdir(path)
注意:NumPy库提供的数组操作并不支持Python自带的数组类型,因此在编写代码时要注意不要使用错误的数组类型
样例代码
DataUtil.py
1. 用于随机生成数据集
2. 用于随机生成测试向量
3. 用于归一化
4. 用于按照比例随机切分训练集和测试集
# -*- coding: utf-8 -*-from numpy import *class DataUtil: def __init__(self): pass def randomDataSet(self, row, column, classes): '''rand data set''' if row <= 0 or column <= 0 or classes <= 0: return None, None dataSet = random.rand(row, column) dataLabel = [random.randint(classes) for i in range(row)] return dataSet, dataLabel def file2DataSet(self, filePath): '''read data set from file''' f = open(filePath) lines = f.readlines() dataSet = None dataLabel = [] i = 0 for line in lines: items = line.strip().split('\t') if dataSet is None: dataSet = zeros((len(lines), len(items)-1)) dataSet[i,:] = items[0:-1] dataLabel.append(items[-1]) i += 1 return dataSet, dataLabel def randomX(self, column): '''rand a vector''' return random.rand(1, column)[0] def norm(self, dataSet): '''normalize''' minVals = dataSet.min(0) maxVals = dataSet.max(0) ranges = maxVals - minVals m = dataSet.shape[0] return (dataSet - tile(minVals, (m, 1)))/tile(ranges, (m, 1)) def spitData(self, dataSet, dataLabel, ratio): '''split data with ratio''' totalSize = dataSet.shape[0] trainingSize = int(ratio*totalSize) testingSize = totalSize - trainingSize # random data trainingSet = zeros((trainingSize, dataSet.shape[1])) trainingLabel = [] testingSet = zeros((testingSize, dataSet.shape[1])) testingLabel = [] trainingIndex = 0 testingIndex = 0 for i in range(totalSize): r = random.randint(1, totalSize) if (r <= trainingSize and trainingIndex < trainingSize) or testingIndex >= testingSize: trainingSet[trainingIndex,:] = dataSet[i,:] trainingLabel.append(dataLabel[i]) trainingIndex += 1 else: testingSet[testingIndex,:] = dataSet[i,:] testingLabel.append(dataLabel[i]) testingIndex += 1 return trainingSet, trainingLabel, testingSet, testingLabel
kNN.py
1. k-近邻算法的实现
# -*- coding: utf-8 -*-import operatorfrom numpy import *class kNN: def __init__(self): pass def classify(self, dataSet, dataLabel, vectorX, k): # data validate (row, column) = dataSet.shape if row <= 0 or column <= 0 or row != len(dataLabel) or column != len(vectorX) or k <= 0: return None, None # calculate distance and sort dataX = tile(vectorX, (row, 1)) distance = (((dataX - dataSet)**2).sum(axis=1))**0.5 sortedIndice = distance.argsort() # classify classCount = {} for i in range(k): if i >= row: break label = dataLabel[sortedIndice[i]] classCount[label] = classCount.get(label, 0) + 1 # sort and return result return distance, sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)[0][0]
Test4knn.py
1. 用于测试k-近邻算法
# -*- coding: utf-8 -*-from com.fighting.util.DataUtil import *from com.fighting.knn.kNN import *import matplotlib.pyplot as pltdef knn(): '''test knn''' row, column, classes, k = (100, 5, 3, 10) # load data set dataUtil = DataUtil() dataSet, dataLabel = dataUtil.randomDataSet(row, column, classes) print 'dataSet: ' print dataSet print 'dataLabel: ' print dataLabel # normalize dataSet = dataUtil.norm(dataSet) print 'norm-dataSet:' print dataSet # plot the data fig = plt.figure() ax = fig.add_subplot(111) ax.scatter(dataSet[:,0], dataSet[:,1], 15*array(dataLabel), 15*array(dataLabel)) plt.show() # random vector X vectorX = dataUtil.randomX(dataSet.shape[1]) print 'vectorX: ' print vectorX # classify knn = kNN() distance, clz = knn.classify(dataSet, dataLabel, vectorX, k) print 'distance: ' print distance print 'clz=%d' % clzdef dating(): '''test dating classify''' # load data set dataUtil = DataUtil() dataSet, dataLabel = dataUtil.file2DataSet('../../../datasets/knn/datingTestSet.txt') dataSet = dataUtil.norm(dataSet) # split training set and testing set ratio = 0.8 trainingSet, trainingLabel, testingSet, testingLabel = dataUtil.spitData(dataSet, dataLabel, ratio) testingSize = testingSet.shape[0] # training and testing knn = kNN() for k in range(1, 11): error = 0 for i in range(testingSize): distance, clz = knn.classify(trainingSet, trainingLabel, testingSet[i,], k) if clz != testingLabel[i]: error += 1 print '%d, %.2f' % (k, error*1.0/testingSize)def f2d(): '''test file2dataset''' dataUtil = DataUtil() dataSet, dataLabel = dataUtil.file2DataSet('../../../datasets/knn/datingTestSet.txt') print 'dataSet:' print dataSet print 'dataLabel:' print dataLabelif __name__ == '__main__': knn() #dating() #f2d()
阅读全文
0 0
- 机器学习实战之K-近邻算法
- 机器学习实战笔记 K近邻算法
- 《机器学习实战》之K-近邻算法
- 机器学习实战-k近邻算法
- 机器学习实战 k-近邻算法
- 【机器学习实战】-k近邻算法
- 《机器学习实战》—K-近邻算法
- 【机器学习实战一:K-近邻算法】
- 机器学习实战(k-近邻算法)
- 机器学习实战笔记:K近邻算法
- 机器学习实战笔记 k-近邻算法
- 机器学习实战之k-近邻算法
- 机器学习实战--k近邻算法
- 机器学习实战:K近邻算法(kNN)
- 【机器学习实战02】k-近邻算法
- 机器学习实战-K-近邻算法
- 机器学习实战之K近邻算法
- 【机器学习实战-python3】k-近邻算法
- Attention and Augmented Recurrent Neural Networks
- computed 计算属性无法双向绑定
- MySQL横表和纵表的相互转换
- 98道常见Hadoop面试题及答案解析
- GPIO介绍
- [机器学习实战] k-近邻算法
- 自己的maven jar 包发布到私服服务器不成功,问题解决方案
- Android进阶之使用Scheme实现从网页启动APP
- Redis系列-4.哈希(Hash)结构
- B树与B+树简明扼要的区别
- Java中接口及抽象类的实例化问题
- phpcms后台本地运行速度慢
- WHCTF 2017 逆向题 CRACKME、BABYRE、EASYHOOK 的解题思路
- 我从Angular 2转向Vue.js, 也没有选择React