机器学习经典算法9-k-means

来源:互联网 发布:网络学校在线教育 编辑:程序博客网 时间:2024/05/01 09:50

1.基本介绍

  knn是k邻居分类算法,该算法基于训练集中已给出类别的样本数据,来对测试集中的数据进行分类;而实际中可能碰到不知道确定分类的情况,这里需要k-means来解决这种unsupervised learning。

2.基本算法

          kmeans的基本流程如下:
                1.初始化:k个中心节点作为类别初始质心
                2.对于数据集中的每个数据点
                        对于k个质心中的每个质心
                                计算数据点到质心的距离
                       选取和数据点最近的质心,并将其归类
               3.判断2过程中是否有数据点的分类发生改变,如果发生改变,则执行4,否则退出
               4.根据最新分类情况,对k个中心质心特征进行更新,并转到2。
         说明:在判断算法是否收敛时,这里是通过检测是否有数据点的分类发生改变,也就是说一直要运行到每个数据点的分类不变;其实还有一种做法是检测整体的数据点到其质心的距离平方和的变化

3.算法的改进

         原始的k-means对k的设定以及质心特征或坐标的设定十分敏感,很容易陷入局部最优的困境中。为了提高聚类的性能,这里介绍二分k-means。
         二分k-means:
        1.初始化。将所有点看做一个簇,该簇的质心设定为各个特征值的平均,计算各个点到该质心的距离
        2.当质心数目小于k时
                2.1对于每个分类
                      2.1.1 对该分类进行二分的kmeans
                      2.1.2计算二分之后的误差情况+分本分类的误差=二分后总误差
                2.2选取二分后总误差最小的分类,生成新的分类结果

4算法示例

这里randCent进行质点的初始化,使用随机方法;distEclud计算两个点之间的距离。
from numpy import *import matplotlib.pyplot as pltdef loadDataSet(filename):    dataMat = []    fr = open(filename)    for line in fr.readlines():        curLine=line.strip('\n').split('\t')        fltLine=map(float,curLine)        dataMat.append(fltLine)    fr.close()    return dataMatdef distEclud(vecA,vecB):    return sqrt(sum(power(vecA-vecB,2)))def randCent(dataSet,k):    n=shape(dataSet)[1]    centroid=mat(zeros((k,n)))    for j in range(n):        minJ = min(dataSet[:,j])        rangeJ = float(max(dataSet[:,j])-minJ)        centroid[:,j]=minJ+rangeJ*random.rand(k,1)    return centroiddef kMeans(dataSet, k, disMeas=distEclud, createCent=randCent):    m=shape(dataSet)[0]    clusterAssment=mat(zeros((m,2)))    centroids=createCent(dataSet,k)    clusterChanged=True    while clusterChanged:        clusterChanged=False        for i in range(m):            minDist=inf            minIndex=-1            for j in range(k):                distJI = disMeas(centroids[j,:],dataSet[i,:])                if distJI < minDist:                    minDist=distJI                    minIndex=j            if clusterAssment[i,0]!=minIndex:                clusterChanged=True                clusterAssment[i,:]=minIndex,minDist**2        for cent in range(k):            ptsInClust = dataSet[nonzero(clusterAssment[:,0].A==cent)[0]]            centroids[cent,:]=mean(ptsInClust,axis=0)    return centroids, clusterAssmentdef biKmeans(dataSet, k, distMeas=distEclud):    m=shape(dataSet)[0]    clusterAssment=mat(zeros((m,2)))    centroid0=mean(dataSet, axis=0).tolist()[0]    centList=[centroid0]    for j in range(m):        clusterAssment[j,1]=distMeas(mat(centroid0), dataSet[j,:])**2    while(len(centList)<k):        lowestSSE= inf        for i in range(len(centList)):            ptsInCurrCluster=\                               dataSet[nonzero(clusterAssment[:,0].A==i)[0],:]            centroidMat,splitClustAss=kMeans(ptsInCurrCluster,2,distMeas)            sseSplit=sum(splitClustAss[:,1])            sseNotSplit=sum(clusterAssment[nonzero(clusterAssment[:,0].A!=i)[0],1])            if(sseSplit+sseNotSplit)<lowestSSE:                bestCentToSplit=i                bestNewCents=centroidMat                bestClustAss=splitClustAss.copy()                lowestSSE=sseSplit+sseNotSplit        bestClustAss[nonzero(bestClustAss[:,0].A==1)[0],0]=len(centList)        bestClustAss[nonzero(bestClustAss[:,0].A==0)[0],0]=bestCentToSplit        print "the bestCentToSplit is: ", bestCentToSplit        print 'the len of bestClustAss is: ', len(bestClustAss)        centList[bestCentToSplit]=bestNewCents[0,:]        centList.append(bestNewCents[1,:])        clusterAssment[nonzero(clusterAssment[:,0].A==bestCentToSplit)[0],:]=bestClustAss    return centList, clusterAssmentdef plotCluRe(dataSet, centroid, clusterAss):    if len(centroid)==3:        x1=dataSet[nonzero(clusterAss[:,0].A==0)[0],0]        y1=dataSet[nonzero(clusterAss[:,0].A==0)[0],1]        x2=dataSet[nonzero(clusterAss[:,0].A==1)[0],0]        y2=dataSet[nonzero(clusterAss[:,0].A==1)[0],1]        x3=dataSet[nonzero(clusterAss[:,0].A==2)[0],0]        y3=dataSet[nonzero(clusterAss[:,0].A==2)[0],1]        cen1x,cen1y=centroid[0][0,0],centroid[0][0,1]        cen2x,cen2y=centroid[1][0,0],centroid[1][0,1]        cen3x,cen3y=centroid[2][0,0],centroid[2][0,1]        plt.plot(x1,y1,'bo',x2,y2,'r+',x3,y3,'ys')        plt.plot(cen1x,cen1y,'b+',cen2x,cen2y,'r+',cen3x,cen3y,'y+', markersize=20)        plt.show()datMat=mat(loadDataSet(r"testSet.txt"))k=4myCentroids, clusterAss=kMeans(datMat,k)for i in range(1,k+1):    print "Centroid "+str(i)+": "+str(myCentroids[i-1,:])    print "Numbers: "+str(len(nonzero(clusterAss[:,0].A==(i-1))[0]))datMat2=mat(loadDataSet(r"testSet2.txt"))cenList, myNewAss=biKmeans(datMat2,3)print cenListplotCluRe(datMat2, cenList, myNewAss)


             
原创粉丝点击