机器学习入门算法及其java实现-Kmeans(K均值)算法

来源:互联网 发布:js jsonarray清空 编辑:程序博客网 时间:2024/05/21 14:08

1、算法基本原理:

  • 对于K个类别的数据选取K个质心
  • 距离第 个质心最近的点归为 类

2、算法具体步骤:
- 选取K个随机点,将其标注为K个类别
- 计算样本点到这K个随机点的距离,根据距离最近的第i个点将其分类i类
- 根据分类的结果,计算新的质心,质心计算公式如下:

xinew=1nj=1nxj

- 根据新的质心重复(1)(2)(3)步骤直到达到停止条件:
||xnewixoldi||<ξ,ξ

下列数据使用R软件生成,Java处理

R程序语言:X<-matrix(1:200,nrow=100,ncol=2)Y<-matrix(0,nrow=100,ncol=1)for (i in 1:100){   if (runif(1)<0.5){   X[i,1]=exp(runif(1))*1.3     X[i,2]=exp(runif(1))*3   Y[i]=0}   else{   X[i,1]=exp(runif(1))*3   X[i,2]=exp(runif(1))*3   Y[i]=1}}data<-cbind(Y,X)write.table(data,"C:/Users/CJH/Desktop/R程序运行/Kmeanstest.txt",row.names=FALSE,col.names=FALSE)#生成数据Eclipse程序语言:package kmeans;import java.io.*;import java.util.*;public class InputData{    public void loadData(double [][]x,double[]y,String trainfile)throws IOException{       File file = new File("C:\\Users\\CJH\\Desktop\\R程序运行",trainfile);       RandomAccessFile raf= new RandomAccessFile(file,"r");       StringTokenizer tokenizer;          int i=0,j=0;       while(true){           String line = raf.readLine();           if(line==null)break;           tokenizer= new StringTokenizer(line);           y[i]=Double.parseDouble(tokenizer.nextToken());           while(tokenizer.hasMoreTokens()){           x[i][j]=Double.parseDouble(tokenizer.nextToken());           j++;           }           j=0;i++;       }       raf.close();    }}//读入数据package kmeans;public class Kmean {    public static double[] classfy(double[] u0,double[] u1,double[][] X){        double[] y1=new double[100];        double[] u=new double[2];        for (int i=0;i<100;i++){        y1[i]=0;        u[0]=X[i][0];        u[1]=X[i][1];        if(euclid(minus(u,u0))>euclid(minus(u,u1))){            y1[i]=1;        }    }        for (int i=0;i<100;i++){            System.out.println(y1[i]+" ");        }        return y1;}    public static double[] newu(double[] y1,int k,double[][] X){        double[] u=new double[2];        double[] u1=new double[2];        double[] u2=new double[2];        double[] a=new double[2];        u1[0]=0;        u1[1]=0;        u2[0]=0;        u2[1]=0;        int t1=0;        int t2=0;        for (int i=0;i<100;i++){                if(k==1){                    if(y1[i]==1){                        a[0]=X[i][0];                        a[1]=X[i][1];                        u1=plus(u1,a);                        t1=t1+1;                 }             u=dividV(u1,t1);                }                if(k==0){                   if(y1[i]==0){                        a[0]=X[i][0];                        a[1]=X[i][1];                        u2=plus(u2,a);                        t2=t2+1;                }             u=dividV(u2,t2);              }        }            return u;    }    private static double[] dividV(double[] u, int t) {        double[] temp=new double[2];        for (int i=0;i<2;i++){            temp[i]=u[i]/t;        }        return temp;    }    private static double[] plus(double[] u, double[] ds) {        double[] temp=new double[2];        for (int i=0;i<2;i++){            temp[i]=u[i]+ds[i];        }        return temp;    }    private static double euclid(double[] x) {        double sum=0;        for (int i=0;i<2;i++){            sum=sum+Math.pow(x[i], 2);        }        double euclid=Math.sqrt(sum);        return euclid;    }    private static double[] minus(double[] x,double[] y) {        double[] temp=new double[2];        for (int i=0;i<2;i++){            temp[i]=x[i]-y[i];        }        return temp;//Kmeans程序package kmeans;import java.io.FileWriter;import java.io.IOException;public class Kmain {public static void main(String[]arg) throws IOException{    double[][] X=new double[100][2];    double[] Y=new double[100];    InputData d=new InputData();    try {        d.loadData(X, Y, "data.txt");    } catch (IOException e) {        // TODO 自动生成的 catch 块        e.printStackTrace();    }    for(int i=0;i<100;i++){        for (int j=0;j<2;j++){            System.out.print(X[i][j]+" ");        }        System.out.println(Y[i]);    }    int[] U0=new int[2];    U0[0]=(int)(Math.random()*X.length);    U0[1]=(int)(Math.random()*X.length);    while(U0[1]==U0[0]){        U0[1]=(int)(Math.random()*X.length);    }    System.out.println(X[U0[0]][0]+" "+X[U0[0]][1]+"\n"+" "+X[U0[1]][0]+" "+X[U0[1]][1]);    double[] u0=new double[2];    double[] u1=new double[2];    double[] u2=new double[2];    double[] u3=new double[2];    double[] y1=new double[100];    u0[0]=X[U0[0]][0];    u0[1]=X[U0[0]][1];    u1[0]=X[U0[1]][0];    u1[1]=X[U0[1]][1];    System.out.println(u0[0]+" "+u0[1]);    System.out.println(u1[0]+" "+u1[1]);    y1=Kmean.classfy(u0,u1,X);    u2=Kmean.newu(y1,0,X);    u3=Kmean.newu(y1,1,X);    for (int i=0;i<100;i++){        System.out.println(y1[i]+" ");    }    while(euclid(minus(u2,u0))+euclid(minus(u3,u1))>0.000001){        u0=u2;        u1=u3;        y1=Kmean.classfy(u0,u1,X);        u2=Kmean.newu(y1,0,X);        u3=Kmean.newu(y1, 1, X);    }    System.out.println(u2[0]+" "+u2[1]);    System.out.println(u3[0]+" "+u3[1]);    FileWriter fw=new FileWriter("C:\\Users\\CJH\\Desktop\\R程序运行\\class.txt");    for(int i=0;i<100;i++){        fw.write((int) y1[i]+" ");    }    fw.close();}private static double euclid(double[] x) {    double sum=0;    for (int i=0;i<2;i++){        sum=sum+Math.pow(x[i], 2);    }    double euclid=Math.sqrt(sum);    return euclid;}private static double[] minus(double[] x,double[] y) {    double[] temp=new double[2];    for (int i=0;i<2;i++){        temp[i]=x[i]-y[i];    }    return temp;  }}//主程序部分

这里写图片描述

原始数据

这里写图片描述

分类结果

如上图所示,点2和点3为质心点的位置。上图显示Kmeans对该组数据分类结果十分好,达到了百分之百。

原创粉丝点击