Kmeans 图像分割 by python

来源:互联网 发布:淘宝店铺监控插件 编辑:程序博客网 时间:2024/05/17 12:55

依然是神奇的numpy boardcast!!!
于是代码只有28行!!!

>
输入:k,data[n];
(1) 选择 k 个初始中心点,例如 c[0] = data[0], ……,c[k-1]=data[k-1];
(2) 对于 data[0],……,data[n],分别与 c[0],……,c[k-1]比较,假设 c[i]差值最少,就标记为 i;
(3) 对于所有标记为 i 点,重新计算 c[i]={所有标记为 i 的 data[i]z 之和}/标记为 i 的个数;
(4) 重复(2)(3),直至所有 c[j]值的变化小于给定阈值。

from scipy.misc import imread,imshow,imsaveimport numpy as npfrom functools import partialdef kmeans(img,K,epsilon):    img = img.astype(np.float64)    randpos = partial(np.random.randint,0,min(img.shape[0],img.shape[1]))    cx,cy = [randpos(K) for i in range(2)]    center = img[cx,cy]    img = img.reshape(1, img.shape[0], img.shape[1], -1)    center = center.reshape(K, 1, 1, 3)    # ite = 0    diff = np.inf    pre_center = np.sum(center)    while(diff>epsilon):        dis = (img - center) ** 2        pos_label = np.sum(dis, axis=3).argmin(axis=0)        for i in range(K): center[i] = np.mean(img[0,pos_label == i],axis=0)        diff = np.abs(np.sum(center)-pre_center)        pre_center = np.sum(center)        # ite+=1        # print(ite,diff)    for i in range(K): img[0,pos_label == i] = center[i]    return np.squeeze(img).astype(np.float16)if __name__ == '__main__':    img = np.floor(imread("/home/ryan/Desktop/cat.jpg"))    img = kmeans(img,5,0.05)    imshow(img)
0 0
原创粉丝点击