机器学习实战——KNN及部分函数注解

来源:互联网 发布:淘宝店铺保证金在哪交 编辑:程序博客网 时间:2024/06/16 02:33


书籍:《机器学习实战》中文版
IDE:PyCharm Edu 4.02
环境:Adaconda3  python3.6

本系列主要是代码学习记录,其中设计的理论知识,不做过多解释。书中代码使用的是python2,所以代码会有些许变化,并对其中部分函数进行注解。


#!/usr/bin/env python3# -*- coding: utf-8 -*-import matplotlib.pyplot as pltfrom numpy import *import operatorfrom os import listdir# 例子一:KNN算法def createDataSet():    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])    labels = ['A','A','B','B']    return group,labelsgroup,labels = createDataSet()def classify0(inX,dataSet,labels,k):    # 计算inX与训练集之间的距离,并排序    dataSetSize = dataSet.shape[0]  #行数    diffMat = tile(inX,(dataSetSize,1))-dataSet    sqDiffMat = diffMat**2    sqDistance = sqDiffMat.sum(axis=1)    distances = sqDistance**0.5    sortedDistIndicies = distances.argsort()  #返回索引值    classCount = {}    #对前K个的标签进行统计    for i in range(k):        votelLabel = labels[sortedDistIndicies[i]]        classCount[votelLabel] = classCount.get(votelLabel,0)+1    # 对统计的标签数量进行降序排序    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)    #print(sortedClassCount)    return sortedClassCount[0][0]# test = classify0([0,0],group,labels,3)# print(test)#例子二:约会数据# 定义文本读取并转换为矩阵格式的函数def file2matrix(filename):    # 得到文件行数,构造矩阵    fr = open(filename)    arrayLines = fr.readlines()    numberOfLines = len(arrayLines)    returnMat = zeros((numberOfLines,3))    labelVector = []    index = 0    for line in arrayLines:        line = line.strip()   #去掉每行头尾空白        # print(line)        listFromLine = line.split('\t')  #获取列表元素        returnMat[index,:] = listFromLine[0:3]        labelVector.append(int(listFromLine[-1]))        index += 1    return returnMat,labelVector#可视化数据# data,label = file2matrix('datingTestSet2.txt')# fig = plt.figure()# ax = fig.add_subplot(111)# ax.scatter(data[:,0],data[:,1],c=label)# plt.show()# 数据归一化def autoNorm(dataSet):    minVals = dataSet.min(0)   #参数0表示求每列的最小值;    maxVals = dataSet.max(0)    ranges = maxVals-minVals    normDataSet = zeros(shape(dataSet))    m = dataSet.shape[0]    normDataSet = dataSet-tile(minVals,(m,1))    normDataSet = normDataSet/tile(ranges,(m,1))    return normDataSet,ranges,minVals# 分类器评估def datingClassTest():    ratio = 0.1    datingData,datingLabels = file2matrix('datingTestSet2.txt')    normMat,ranges,minVals = autoNorm(datingData)    m = normMat.shape[0]    numTest = int(m*ratio)    errorCount = 0.0    for i in range(numTest):        result = classify0(normMat[i,:],normMat[numTest:m,:],datingLabels[numTest:m],3)        print("the classifier's result: %d,the real answer: %d"\              % (result,datingLabels[i]))        if (result!=datingLabels[i]):            errorCount += 1.0        print("the total error rate is : %f" % (errorCount/float(numTest)))#print(datingClassTest())# 用户使用程序段def classifyPerson():    resultList = ['not at all','in small doses','in large doses']    percentTats = float(input("玩游戏和看视频花费的时间比率?"))    ffMiles = float(input("飞行里程数?"))    iceCream = float(input("每周消费的冰淇淋公升数?"))    datingData,datingLabels = file2matrix('datingTestSet2.txt')    normMat,ranges,minVals = autoNorm(datingData)    inArr = array([ffMiles,percentTats,iceCream])    classierResult = classify0((inArr-minVals)/ranges,normMat,datingLabels,3)    print("You will probably like this person:",resultList[classierResult-1])#print(classifyPerson())# 例子三:手写识别系统def img2vector(filename):    returnVect = zeros((1,1024))    fr = open(filename)    for i in range(32):        lineStr = fr.readline()        for j in range(32):            returnVect[0,32*i+j] = int(lineStr[j])    return returnVecttest = img2vector('trainingDigits/0_13.txt')def handwritingClassTest():    hwLabels = []    # 获取目录内容    trainFileList = listdir('trainingDigits')    #获取目录    m = len(trainFileList)    trainMat = zeros((m,1024))    #从文件名中解析分类数字    for i in range(m):        fileNameStr = trainFileList[i]        fileStr = fileNameStr.split('.')[0]        classNumStr = int(fileStr.split('_')[0])        hwLabels.append(classNumStr)        trainMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)    testFileList = listdir('testDigits')    errorCount = 0.0    mTst = len(testFileList)    for i in range(mTst):        fileNameStr = testFileList[i]        fileStr = fileNameStr.split('.')[0]        classNumStr = int(fileStr.split('_')[0])        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)        classierResult = classify0(vectorUnderTest,trainMat,hwLabels,3)        print("分类器结果:%s,实际结果:%s" % (classierResult,classNumStr))        if (classierResult!=classNumStr):            errorCount +=1.0    print("错误总数:%d" % errorCount)    print("错误率:%f" % (errorCount/float(mTst)))print(handwritingClassTest())

部分函数注释:
1、二维矩阵函数参数:0代表按列操作;1代表按行操作。
2、numpy shape[0]:矩阵第一维度的长度。
3、numpy tile(A,reps):根据reps设定的形式,进行重复。比如tile(A,(4,1)) ,结果是4行1列的,其中每一个元素都是A。
4、numpy sum(axis=1):可以设置参数axis。0表示按列相加,1表示按行相加。
5、numpy argsort():排序,返回数组值从小到大的索引值。
6、字典对象的get(key,default=None):返回指定键的值,若给键不存在则返回默认值。
7、内置函数 sorted():
sort 与 sorted 区别:
sort 是应用在 list 上的方法,sorted 可以对所有可迭代的对象进行排序操作。
list 的 sort 方法返回的是对已经存在的列表进行操作,而内建函数 sorted 方法返回的是一个新的 list,而不是在原来的基础上进行的操作。
sorted(iterable[, cmp[, key[, reverse]]])
参数说明:
iterable -- 可迭代对象。
cmp -- 比较的函数,这个具有两个参数,参数的值都是从可迭代对象中取出,此函数必须遵守的规则为,大于则返回1,小于则返回-1,等于则返回0。
key -- 主要是用来进行比较的元素,只有一个参数,具体的函数的参数就是取自于可迭代对象中,指定可迭代对象中的一个元素来进行排序。key接受一个函数。
reverse -- 排序规则,reverse = True 降序 , reverse = False 升序(默认)。
operator.itemgetter(1):operator.itemgetter函数获取的不是值,而是定义了一个函数,通过该函数作用到对象上才能获取值。
                        用于指定获取对象的哪些维的数据。这里指对字典中的value进行排序。


8、readlines():一次读取全部文本,按行返回,结果为一个列表。
                          一般用for...in..语句获取每一行内容,且常进行操作
                         line = line.strip()   #去掉每行头尾空白
                         listFromLine = line.split('\t')  #获取列表元素
9、matplotlib.pyplot scatter():                                         
具体见 http://blog.csdn.net/anneqiqi/article/details/64125186
scatter(x,y,s=,c=):
x,y相同长度的数组数据;
s 标量或数组,可选,默认20,散点图每个点的大小
c 色彩序列,可选。

10、numpy min():
import numpy as np  
a = np.array([[1,5,3],[4,2,6]])  
print(a.min()) #无参,所有中的最小值  
print(a.min(0)) # axis=0; 每列的最小值  
print(a.min(1)) # axis=1;每行的最小值

11、细节问题
for...in.. 中:不用忘记列表生成方法range()
numpy数组对象 array()
图片显示:plt.show()





原创粉丝点击