K-means算法的Python实现

来源:互联网 发布:知世故而不世故 下联 编辑:程序博客网 时间:2024/05/18 00:15

简介

通过本文你可以了解到
- K-means算法的基本思想
- 利用Python来实现K-means算法
- 利用Python进行简单的绘图

准备

为了顺利完成该程序,需要配置
- Python3
- Numpy
- matplotlib.pyplot(绘图库)
- sklearn(数据集来源)

K-means算法详解

简介

K-means算法是一种应用于数据挖掘领域的聚类算法,其目的是将N个元素通过聚类得到K个类,不需要制定类的具体含义,只需要设定K的个数即可。

具体实现

初始化

根据K的个数随机的选择几个点作为中心点

迭代

以下两步交替迭代执行
1. 分配:通过已经确定的几个中心点,对所有点计算到该点的距离,并将其划分到最近的点的集合里面。
2. 更新:对于第一步已经得到的几个集合,重新计算他们的中心点。

完成条件

如果我们发现两代之间的中心点差距小于一定值,那么就认为聚类完成,对于小数据我们可以认为当两代中心相同时聚类便已经完成。

Python代码实现

 数据来源

我们采用sklearn中提供的鸢尾花数据库的前两维度来进行测试,通过如下方式加载即可。

iris_data=sklearn.datasets.load_iris()data=iris_data.data[:,1:3]

K-means算法的初始化

由上面的介绍可知,k-means算法的参数较为简单,只需要一个k值来指定有几个类即可。

def __init__(self,data,k):    self.data=data    self.k=k

初始化中心

我们采用随机的方法选择几个点作为初始的中心点

def init_center(self):    m=np.shape(self.data)[0]    n=np.shape(self.data)[1]    list=[]    for i in range(self.k):        temp=random.randint(0,m)        list.append(self.data[temp])    return list

计算距离

计算距离的方法有很多,在本次实验中我选用了欧氏距离

def cal_distance(self,p1,p2):    result=np.sqrt(np.sum((p2-p1)**2))    return result

迭代的框架

迭代的结构很简单,分配节点-计算中心-判断与上次中心差异即可

def K_Means_cal(self):    center_list=self.init_center()    while True:        assort_result=self.assort_node(center_list)        this_center_list=self.cal_center(center_list,assort_result)        if self.compare_center(this_center_list,center_list):            return center_list        else:            center_list=this_center_list

分配节点

遍历整个数据组,对每一个节点找到与其最近的节点,加入到List中即可

def assort_node(self,center_list):    templist = []    for i in range(np.shape(self.data)[0]):        mindis = None        tempcenter = None        for j in center_list:            tempdis = self.cal_distance(self.data[i], j)            if mindis == None:                mindis = tempdis                tempcenter = j            elif mindis > tempdis:                mindis = tempdis                tempcenter = j        templist.append((tempcenter, data[i]))    return templist

计算中心

对于已经归类的节点,对每一类计算中心并汇总即可。中心点(x,y)的计算公式为

x=(x1+x2++xn)n,y=(y1+y2++yn)n

def cal_center(self,center_list,assort_result):    temp_center=[]    for i in center_list:        x=0        y=0        count=0        for j in assort_result:            if str(i)==str(j[0]):                x+=j[1][0]                y+=j[1][1]                count+=1        x=x/count        y=y/count        temp_center.append((x,y))    return temp_center

判断是否已经完成

判断两个list是否相同即可

def compare_center(self,list1,list2):    bool=False    for i in list1:        bool=False        for j in list2:            if str(i)==str(j):                bool=True                break    return bool

可视化呈现

对于数据的可视化呈现,我们采用了python的matplotlib.pyplot库来进行绘图,通过方法(此处已经import matplotlib.pyplot as plt)

plt.scatter(x,y,color)

来绘制散点图,其中x,y是坐标,既可以是一个也可以是一组,color是颜色也可以对于每一个点单独指定颜色来传入一个列表或者是同意颜色,例如

plt.scatter(self.data[:,0],self.data[:,1],c='g')# 或者plt.scatter(i[0],i[1],c='b')

然后采用

plt.show()

来将图展示出来。

结语

为了完成数据挖掘老师留下的实验作业为了更好地掌握数据挖掘课上所讲授的算法,才有了这篇文章,python庞大的第三方库的支持对于数据挖掘来说确实减少了很多的工作量,本以为比较复杂的代码实际写下了还是比较清晰的,plt所提供的数据可视化也对我判断结果好坏有了很好的帮助。之后还有一篇关于Apriori算法的实验代码已经完成了,过两天应该也会写了传上来吧,马上要考试了啊……

0 0
原创粉丝点击