机器学习实战——k-均值聚类算法

来源:互联网 发布:linux怎么重启服务器 编辑:程序博客网 时间:2024/05/20 06:25
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from numpy import *
# Read data; the result is a plain list of lists
def loadDataSet(filename):
    """Load a tab-separated file of numbers into a list of float rows.

    Args:
        filename: Path of a text file in which every non-blank line holds
            tab-separated numeric fields.

    Returns:
        A list with one entry per non-blank line; each entry is the list of
        that line's fields converted to ``float``.  The result is a plain
        list, so it must be converted to a matrix/array before clustering.

    Raises:
        ValueError: If a field cannot be parsed as a float.
        OSError: If the file cannot be opened.
    """
    dataMat = []
    # The context manager guarantees the handle is closed even when a field
    # fails to parse.
    with open(filename) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # A blank (often trailing) line would otherwise reach
                # float('') and raise.
                continue
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat
# Compute the distance between two points
def disEclud(vecA, vecB):
    """Return the Euclidean distance between ``vecA`` and ``vecB``.

    Args:
        vecA: First point, as a NumPy array or matrix.
        vecB: Second point, broadcast-compatible with ``vecA``.

    Returns:
        The square root of the sum of squared element-wise differences.
    """
    # Use the NumPy functions explicitly: the builtin ``sum`` would iterate
    # over the rows of a 1 x n matrix instead of reducing it to a scalar, so
    # the result must not depend on ``from numpy import *`` shadowing it.
    diff = vecA - vecB
    return np.sqrt(np.sum(np.power(diff, 2)))
# Randomly generate the initial centroids
def randCent(dataSet, k):
    """Create ``k`` random centroids inside the bounding box of ``dataSet``.

    Args:
        dataSet: 2-D NumPy array or matrix with one point per row.
        k: Number of centroids to generate.

    Returns:
        A ``k x n`` matrix (``n`` = number of columns of ``dataSet``) whose
        column ``j`` is drawn uniformly from the [min, max] range of column
        ``j`` of ``dataSet``.
    """
    n = np.shape(dataSet)[1]
    # np.asmatrix instead of the ``mat`` alias, which NumPy 2.0 removed.
    centroids = np.asmatrix(np.zeros((k, n)))
    for j in range(n):
        column = dataSet[:, j]
        # The .min()/.max() methods always reduce to a scalar; the bare
        # min()/max() names resolve to the builtins or to NumPy's versions
        # depending on what the star import exports.
        minJ = column.min()
        rangeJ = float(column.max() - minJ)
        centroids[:, j] = minJ + rangeJ * np.random.rand(k, 1)
    return centroids
# Plain k-means clustering
def kMeans(dataSet, k, distMeas=disEclud, createCent=randCent):
    """Cluster the rows of ``dataSet`` into ``k`` groups with k-means.

    Args:
        dataSet: 2-D NumPy array or matrix with one point per row.
        k: Number of clusters.
        distMeas: Callable ``(centroid_row, point_row) -> distance``.
        createCent: Callable ``(dataSet, k) -> k x n`` matrix of initial
            centroids.

    Returns:
        A tuple ``(centroids, clusterAssment)``: ``centroids`` is the matrix
        returned by ``createCent`` updated in place, and ``clusterAssment``
        is an ``m x 2`` matrix holding, for every point, the index of its
        centroid and the squared distance to it.
    """
    m = np.shape(dataSet)[0]
    # Column 0: index of the assigned centroid; column 1: squared distance.
    clusterAssment = np.asmatrix(np.zeros((m, 2)))
    centroids = createCent(dataSet, k)
    clusterChanged = True
    while clusterChanged:
        clusterChanged = False
        for i in range(m):
            minDist = np.inf
            minIndex = -1
            for j in range(k):
                distJI = distMeas(centroids[j, :], dataSet[i, :])
                if distJI < minDist:
                    minDist = distJI
                    minIndex = j
            if clusterAssment[i, 0] != minIndex:
                clusterChanged = True
            clusterAssment[i, :] = minIndex, minDist ** 2
        print(centroids)
        for cent in range(k):
            ptsInClust = dataSet[np.nonzero(clusterAssment[:, 0].A == cent)[0]]
            # Averaging an empty cluster would turn its centroid into NaN and
            # no point could ever be assigned to it again; keep the previous
            # position instead.
            if ptsInClust.shape[0] > 0:
                centroids[cent, :] = np.mean(ptsInClust, axis=0)
    return centroids, clusterAssment


# Bisecting k-means clustering; mitigates the local-optimum problem
def biKmeans(dataSet, k, distMeas=disEclud):
    """Cluster ``dataSet`` by repeatedly splitting one cluster in two.

    Starting from a single cluster, every round tries a 2-means split of each
    existing cluster and keeps the split that gives the lowest total squared
    error, until ``k`` clusters exist.

    Args:
        dataSet: 2-D NumPy array or matrix with one point per row.
        k: Desired number of clusters.
        distMeas: Callable ``(centroid_row, point_row) -> distance``.

    Returns:
        A tuple ``(centroids, clusterAssment)``: ``centroids`` is a matrix
        with one centroid per row and ``clusterAssment`` is an ``m x 2``
        matrix of (cluster index, squared distance to that centroid).  Fewer
        than ``k`` centroids are returned when no remaining cluster has at
        least two points to split.
    """
    # Work on a matrix so that ``mean(...).tolist()[0]`` below is a full row;
    # for a plain ndarray it would be only the first coordinate.
    dataSet = np.asmatrix(dataSet)
    m = dataSet.shape[0]
    clusterAssment = np.asmatrix(np.zeros((m, 2)))
    centroid0 = np.mean(dataSet, axis=0).tolist()[0]
    centList = [centroid0]
    for j in range(m):
        clusterAssment[j, 1] = distMeas(np.asmatrix(centroid0), dataSet[j, :]) ** 2
    while len(centList) < k:
        lowestSSE = np.inf
        bestCentToSplit = None
        for i in range(len(centList)):
            ptsInCurrCluster = dataSet[np.nonzero(clusterAssment[:, 0].A == i)[0], :]
            if ptsInCurrCluster.shape[0] < 2:
                # An empty or single-point cluster cannot be split in two.
                continue
            centroidMat, splitClustAss = kMeans(ptsInCurrCluster, 2, distMeas)
            sseSplit = np.sum(splitClustAss[:, 1])
            sseNotSplit = np.sum(clusterAssment[np.nonzero(clusterAssment[:, 0].A != i)[0], 1])
            print('sseSplit,and notSplit:', sseSplit, sseNotSplit)
            if sseNotSplit + sseSplit < lowestSSE:
                bestCentToSplit = i
                bestNewCents = centroidMat
                bestClustAss = splitClustAss.copy()
                lowestSSE = sseSplit + sseNotSplit
        if bestCentToSplit is None:
            # Nothing left to split; stop instead of failing on an unbound
            # name below.
            break
        # Relabel the two halves.  Order matters: sub-cluster 1 receives the
        # brand-new index before sub-cluster 0 takes over the index of the
        # cluster that was split.
        bestClustAss[np.nonzero(bestClustAss[:, 0].A == 1)[0], 0] = len(centList)
        bestClustAss[np.nonzero(bestClustAss[:, 0].A == 0)[0], 0] = bestCentToSplit
        print('the bestCentToSplit is:', bestCentToSplit)
        print('the len of bestClustAss is:', len(bestClustAss))
        # Store centroids as flat lists so the final conversion always gives a
        # 2-D (clusters x features) matrix.
        centList[bestCentToSplit] = bestNewCents[0, :].tolist()[0]
        centList.append(bestNewCents[1, :].tolist()[0])
        clusterAssment[np.nonzero(clusterAssment[:, 0].A == bestCentToSplit)[0], :] = bestClustAss
    return np.asmatrix(np.array(centList)), clusterAssment
# Improved algorithm: derives the number of clusters from the data, so the
# caller does not have to supply k
def nonk(datSet, k=2):
    """Run bisecting k-means with growing ``k`` until clusters are compact.

    The threshold is ``(x_range**2 + y_range**2) / 30`` computed from the
    first two columns of ``datSet``.  Clustering is repeated with ``k + 1``
    clusters for as long as some point's squared distance to its centroid
    exceeds that threshold.

    Args:
        datSet: 2-D NumPy array or matrix with at least two columns, one
            point per row.
        k: Number of clusters to start with.

    Returns:
        A tuple ``(cenmat, splitclu, k)``: the centroids and assignment
        matrix returned by the last ``biKmeans`` call, and the ``k`` used
        for that call.
    """
    x = datSet[:, 0].max() - datSet[:, 0].min()
    y = datSet[:, 1].max() - datSet[:, 1].min()
    distance = (x ** 2 + y ** 2) / 30
    m = np.shape(datSet)[0]
    while True:
        cenmat, splitclu = biKmeans(datSet, k)
        # The largest per-cluster maximum equals the overall maximum of the
        # squared-distance column; taking it directly also avoids calling
        # .max() on the empty slice of a cluster that received no points.
        tooSpread = splitclu[:, 1].max() > distance
        # Never ask for more clusters than there are points.
        keepGoing = bool(tooSpread) and k < m
        if keepGoing:
            k += 1
        print('k=%d,dis=%f' % (k, distance))
        if not keepGoing:
            return cenmat, splitclu, k
# Data visualisation, for 2-D data only:
def pict():
    """Cluster the points stored in ``result.txt`` and plot the result.

    The data is loaded with ``file2matrix``, clustered with ``nonk`` (which
    chooses the number of clusters), the assignment matrix is printed, and
    every cluster is drawn as a scatter series whose marker size grows with
    the cluster index; centroids are drawn with ``+`` markers.

    Returns:
        A tuple ``(myce, clu)`` of plain ndarrays: the centroids and the
        (cluster index, squared distance) assignment table.
    """
    # NOTE(review): file2matrix is not defined in the visible part of this
    # file -- confirm where it comes from and that it returns 2-D numeric
    # data (loadDataSet('testSet.txt') was the earlier data source here).
    # np.asmatrix replaces the ``mat`` alias, which NumPy 2.0 removed.
    datMat = np.asmatrix(file2matrix('result.txt')).getA()
    myce, clu, k = nonk(datMat)
    print(clu)
    myce = myce.getA()
    clu = clu.getA()
    fig = plt.figure()
    ax = fig.add_subplot(111)
    for i in range(k):
        indx = np.where(clu[:, 0] == i)[0]
        # The third positional argument of scatter is the marker size.
        ax.scatter(datMat[indx, 0], datMat[indx, 1], 20 * (i + 1))
    ax.scatter(myce[:, 0], myce[:, 1], marker='+')
    plt.show()
    return myce, clu

# Generate `a` random 2-D points and run the k-means clustering on them:
def Mypict(a):
    """Cluster ``a`` uniformly random 2-D points and plot the clusters.

    The points are drawn with ``rand(a, 2)``, clustered with ``nonk`` (which
    chooses the number of clusters), and every cluster is drawn as a scatter
    series whose marker size grows with the cluster index; centroids are
    drawn with ``+`` markers.  Finally the chosen ``k`` and ``a`` are printed.

    Args:
        a: Number of random points to generate.
    """
    points = np.random.rand(a, 2)
    centers, assignment, k = nonk(points)
    centers = centers.getA()
    assignment = assignment.getA()

    axes = plt.figure().add_subplot(111)
    labels = assignment[:, 0]
    for label in range(k):
        members = labels == label
        # The third positional argument of scatter is the marker size.
        axes.scatter(points[members, 0], points[members, 1], 20 * (label + 1))
    axes.scatter(centers[:, 0], centers[:, 1], marker='+')
    plt.show()
    print('k=%d,a=%d' % (k, a))

 
阅读全文
0 0
原创粉丝点击