Kmeans原理及实现

来源:互联网 发布:淘宝上误点了确认收货 编辑:程序博客网 时间:2024/06/08 00:34

在聚类分析中,最简单、基本的版本是划分,它把对象组织成多个互斥的组或簇。
这些簇的形成旨在优化一个客观划分准则,使得同一簇中的对象是相似的,不同簇的对象是相异的。
最常用的划分方法包括,k-means和k-medoids。

一. k-means算法

k-means算法是一种很常见的聚类算法,它的基本思想是:通过迭代寻找k个聚类的一种划分方案,使得用这k个聚类的均值来代表相应各类样本时所得的总体误差最小。

算法

算法: k-均值。用于划分的k-均值算法,其中每个簇的中心都用簇中所有对象的均值来表示。输入:   k: 簇的数目   D: 包含n个对象的数据集输出:    k个簇的集合方法:    1、随机选取 k个聚类质心点    2、重复下面过程直到收敛  {      对于每一个样例 i,计算其应该属于的类:      对于每一个类 j,重新计算该类的质心:    3.停止条件:        1.超过最大迭代        2.J超过阈值}

二.实现

from __future__ import with_statementimport randomimport numpy as npfrom scipy.linalg import  normimport numpy.matlib as mlimport cPickle as picklefrom matplotlib import pyplotfrom numpy import zeros, array, tiledef Kmeans(X, k, observer=None, threshold=1e-15, maxiter=300):    N = len(X)    labels = np.zeros(N,dtype=int)    centers = np.array(random.sample(X,k))    iter = 0    def calc_J():        sum = 0        for i in xrange(N):            sum += norm(X[i] - centers[labels[i]])        return  sum    def  distmat(X,Y):        n = len(X)        m = len(Y)        xx = ml.sum(X*X,axis=1) # #axis=1 是按行求和        print "xx:{}".format(xx)        yy = ml.sum(Y*Y,axis=1)        print "yy:{}".format(yy)        xy = ml.dot(X,Y.T) #dot矩阵相乘        return tile(xx,(m,1)).T + tile(yy,(n,1)) - 2*xy #tile矩阵复制    Jprev = calc_J()    while True:        # notify the observer        if observer is not None:            observer(iter, labels, centers)         # calculate distance from x to each center         # distance_matrix is only available in scipy newer than 0.7         # dist = distance_matrix(X, centers)        dist = distmat(X, centers)         # assign x to nearst center        labels = dist.argmin(axis=1)         # re-calculate each center        for j in range(k):            idx_j = (labels == j).nonzero()            centers[j] = X[idx_j].mean(axis=0)        J = calc_J()        iter += 1        if Jprev-J < threshold:            break        Jprev = J        if iter >= maxiter:            break     # final notification    if observer is not None:        observer(iter, labels, centers)if __name__ == '__main__':     # load previously generated points    with open('cluster.pkl') as inf:        samples = pickle.load(inf)    N = 0    for smp in samples:        N += len(smp[0])    X = zeros((N, 2))    idxfrm = 0    for i in range(len(samples)):        idxto = idxfrm + len(samples[i][0])        X[idxfrm:idxto, 0] = samples[i][0]        X[idxfrm:idxto, 1] = samples[i][1]        idxfrm = idxto    def observer(iter, labels, centers):        print "iter %d." % iter        colors = array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])        pyplot.plot(hold=False)  # clear previous plot        pyplot.hold(True)          # draw points        data_colors=[colors[lbl] for lbl in labels]        pyplot.scatter(X[:, 0], X[:, 1], c=data_colors, alpha=0.5)         # draw centers        pyplot.scatter(centers[:, 0], centers[:, 1], s=200, c=colors)        pyplot.savefig('kmeans/iter_%02d.png' % iter, format='png')    Kmeans(X, 3, observer=observer)

三.结果

这里写图片描述

0 0