无监督学习第一步:K-均值聚类

来源:互联网 发布:会计记账软件免费版 编辑:程序博客网 时间:2024/06/11 00:05

简介

k均值聚类是无监督学习方法里面的一种,所谓聚类(clustering),就是将相似的对象归到一个簇中。簇中的对象越相似,聚类的效果越好。K均值聚类就是在数据集中发现k个簇的算法。

算法描述

k-means算法如下:

创建k个随机质心
当质心改变时
对每个点
对每个质心
计算点到质心的距离
将点归到到距离最近的质心
重新计算质心

很简单的一个算法。其中距离的计算有很多方式,我习惯选取最简单的欧几里得距离。注意k-means算法在大型数据集上收敛较慢。

对聚类效果的评估

聚类的效果用SSE(sum of squared error,误差平方和)来进行评估。SSE指得是数据集上所有数据点到其被分配的质心的距离的平方和。SSE越小,表明数据点距离质心越近,说明聚类效果越好。

2分k均值聚类

k-means聚类效果比较差的原因是算法会收敛于局部最小值而不是全局最小值,为了克服这个问题,有人提出2分k均值聚类的方法。算法将一个SSE大的簇不断划分为2个簇,知道得到k的簇为止,伪代码如下:

将所有点看做一个簇
当簇的个数小于k时
对于每个簇
计算总误差
将其利用k-means分为2个簇
计算新总误差
选择使新总误差最小的簇进行划分
还有一种做法是选择使得SSE降低最多的划分作为迭代,过程和上面差不多。

最后直接上代码

from numpy import *;from matplotlib.pyplot import *;def loadData(filename):    dataMat=[];    fr = open(filename);    for line in fr.readlines():        cur = line.strip().split('\t');        #python3 map返回iterators,需转换为list        flt = list(map(float,cur));        dataMat.append(flt);    return dataMat;def dist(a,b):    return sqrt(sum(power(a-b,2)));def randCent(dataSet,k):    n = shape(dataSet)[1];    cent = mat(zeros((k,n)));    for j in range (n):        minJ = min(dataSet[:,j]);        #print(dataSet);        rangeJ = float(max(dataSet[:,j])-minJ);        cent[:,j] = minJ +rangeJ*random.rand(k,1);    return cent;def plotkMeans(dataSet,cent):    scatter(dataSet[:,0],dataSet[:,1]);    plot(cent[:,0],cent[:,1],'r+',);    show();def kMeans(dataSet,k):    m=shape(dataSet)[0];    clusterAss = mat(zeros((m,2)));    cent =  randCent(dataSet,k);    clusterChanged = True;    while(clusterChanged):        clusterChanged=False;        for i in range(m):            mindist = inf;minIndex=-1;              for j in range(k):                distJ = dist(cent[j,:],dataSet[i,:]);                if(distJ < mindist):                    mindist = distJ;                    minIndex = j;            if (clusterAss[i,0]!=minIndex):                clusterChanged=True;            clusterAss[i,:]= minIndex,mindist**2        for tcent in range(k):            ptsIncluts = dataSet[nonzero(clusterAss[:,0].A==tcent)[0]];            cent[tcent,:] = mean(ptsIncluts,axis=0);    return cent,clusterAss;def bikMeans(dataSet,k):    m=shape(dataSet)[0];    clusterAss = mat(zeros((m,2)));    cent =  mean(dataSet,axis=0).tolist()[0];    centlist=[cent];    for j in range(m):        clusterAss[j,1]=dist(mat(cent),dataSet[j,:])**2;    while(len(centlist)<k):        minsse = inf;        for i in range(len(centlist)):            ptsIncluts = dataSet[nonzero(clusterAss[:,0].A==i)[0],:];            tcent,sselist = kMeans(ptsIncluts,2);            seesplit = sum(sselist[:,1]);            seenotsplit=sum(clusterAss[nonzero(clusterAss[:,0].A!=i)[0],1]);            if((seesplit+seenotsplit)<minsse):                minsse =seesplit+seenotsplit;                index=i;                splitcent=tcent;                splitAss = sselist.copy();        splitAss[nonzero(splitAss[:,0].A==1)[0],0]=len(centlist);        splitAss[nonzero(splitAss[:,0].A==0)[0],0]=index;        print ('the bestCentToSplit is: ',i)        print ('the len of bestClustAss is: ', len(splitAss))        centlist[index] = splitcent[0,:].tolist()[0];        centlist.append(splitcent[1,:].tolist()[0]);        clusterAss[nonzero(clusterAss[:,0].A == index)[0],:]= splitAss;        #print(centlist);    return centlist,clusterAss;if __name__ =='__main__':    dataSet = loadData("testSet2.txt");    dataSet = mat(dataSet);    cent,sse = bikMeans(dataSet,3);    print(mat(cent));    plotkMeans(dataSet,mat(cent));

完整代码地址:github

0 0
原创粉丝点击