FuzzyKmeans聚类JAVA版本实现

来源:互联网 发布:细说php视频 百度云 编辑:程序博客网 时间:2024/05/05 18:45

对数据进行聚类时,最常用的方法应该是kmeans,但kmeans只能把每一条待聚类的数据划分到唯一一个类别,无法处理一条数据同时属于多个类别的情况。为此,人们提出了FuzzyKmeans(模糊C均值,Fuzzy C-Means)聚类方法。该方法衡量的是每一条数据属于各个类别的隶属度(可以理解为概率,同一条数据对所有类别的隶属度之和为1)。既然是概率,就不再是非1即0的取值,因此一条数据可以同时被划分到多个类别。

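对应FuzzyKmeans的聚类过程如下(设共有 $n$ 条数据 $x_1,\dots,x_n$、$c$ 个类别,模糊指数 $m>1$):

1. 初始化:选定 $c$ 个初始中心点 $v_1,\dots,v_c$ 以及阈值 $\varepsilon$。
2. 计算每条数据到每个中心点的距离,并据此计算隶属度:
   $$d_{ij} = \lVert x_i - v_j \rVert \tag{1}$$
   $$u_{ij} = \left( \sum_{k=1}^{c} \left( \frac{d_{ij}}{d_{ik}} \right)^{\frac{2}{m-1}} \right)^{-1} \tag{2}$$
   若某个 $d_{ij}=0$(数据恰好落在中心点上),则令
   $$u_{ij}=1,\qquad u_{ik}=0\ (k\neq j) \tag{3}$$
3. 更新中心点:
   $$v_j = \frac{\sum_{i=1}^{n} u_{ij}^{m}\, x_i}{\sum_{i=1}^{n} u_{ij}^{m}} \tag{4}$$
4. 若每个中心点更新前后的距离都小于 $\varepsilon$,即 $\max_j \lVert v_j^{(t+1)} - v_j^{(t)} \rVert < \varepsilon$,则停止;否则回到第2步。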

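其中 $d_{ij}$ 衡量的是数据 $i$ 到类别 $j$ 中心点的距离,$u_{ij}$ 就是数据 $i$ 属于类别 $j$ 的隶属度(概率)。注意下文代码中的 `distance()` 返回的是欧氏距离的平方 $d_{ij}^2$,所以代码里的指数写作 $1/(m-1)$,与(2)式中的 $2/(m-1)$ 等价。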

求得隶属度之后,需要按照(4)式更新每个类别的中心点,也就是以各条数据属于该类的隶属度的 $m$ 次方为权重,对所有数据的原始向量做加权平均。

至于结束条件,一种是达到设定的迭代次数,另一种是满足第4步的条件,即每个类别的中心点在相邻两次迭代之间的移动距离小于某个阈值。下文的实现只使用了迭代次数作为结束条件。

最重要的应该是 $m$ 值的选择。当每条数据到各个类别中心点的距离比较接近时,建议让 $1/(m-1)$ 取较大的值(此时 $m$ 接近于1),因为这样经过指数运算后,距离之间就能拉开较大差异。如果距离本来就有很大差异,$1/(m-1)$ 就可以取小一些。一般来说 $m$ 取1.5(即 $1/(m-1)=2$)就足够了。

最后要注意迭代次数不宜过多,一般两次足够,因为考虑的是概率,如果迭代次数过多,中心点偏移较大,很可能得到数据到各个类别的概率都相差不大。


下面是用Java实现的FuzzyKmeans。输入文件的每一行格式为 `词,v1,v2,...,vn`,示例数据中每一条数据都是一个200维的向量(代码本身并不限定维度)。使用时可以指定初始中心点:中心点必须是待聚类数据中出现过的词,其向量从待聚类数据中查找得到。

首先是处理输入的类:

package kmeans;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

/**
 * Loads word vectors from a comma-separated text file and keeps them in memory.
 *
 * <p>Each line of the vector file is {@code word,v1,v2,...,vn}. Every vector is scaled to unit
 * Euclidean length when it is loaded.
 */
public class Word2VEC {

    /** word -> unit-length vector, filled by {@link #loadVectorFile(String)}. */
    private HashMap<String, double[]> wordMap = new HashMap<String, double[]>();

    /**
     * Reads every line of the UTF-8 file at {@code path} into the word map.
     *
     * <p>The first comma-separated field is the word; the remaining fields are the vector
     * components. Lines without any component (for example blank lines) are skipped. Each vector
     * is divided by its Euclidean length. An all-zero vector is kept unchanged instead of being
     * divided by zero, which would turn every component into NaN. A later line with the same word
     * replaces an earlier one. Progress is printed every 100000 lines.
     *
     * @param path path of the vector file
     * @throws IOException if the file cannot be opened or read
     * @throws NumberFormatException if a vector component is not a number
     */
    public void loadVectorFile(String path) throws IOException {
        int size = 0;
        BufferedReader br = null;
        try {
            br = openUtf8(path);
            String line;
            int count = 0;
            while ((line = br.readLine()) != null) {
                if (count % 100000 == 0) {
                    System.out.println("read: " + count);
                }
                count++;
                String[] outline = line.split(",");
                if (outline.length < 2) {
                    continue; // no vector components: nothing that could be clustered
                }
                size = outline.length - 1;
                double[] vectors = new double[size];
                double len = 0;
                for (int j = 0; j < size; j++) {
                    // Double.parseDouble keeps full precision (Float.parseFloat truncated to float).
                    vectors[j] = Double.parseDouble(outline[j + 1]);
                    len += vectors[j] * vectors[j];
                }
                len = Math.sqrt(len);
                if (len > 0) {
                    for (int j = 0; j < size; j++) {
                        vectors[j] /= len;
                    }
                }
                wordMap.put(outline[0], vectors);
            }
        } finally {
            System.out.println("total word: " + wordMap.size() + " vector dimensions: " + size);
            // br is still null when the file could not be opened; closing it would mask that error.
            if (br != null) {
                br.close();
            }
        }
    }

    /**
     * Returns the live word map (not a copy).
     *
     * @return word -> unit-length vector for every loaded word
     */
    public HashMap<String, double[]> getWordMap() {
        return wordMap;
    }

    /**
     * Reads the candidate center words, one per line, from a UTF-8 file.
     *
     * <p>Only words that also appear in the loaded vectors are kept, in file order. Repeated words
     * are kept once, because two identical seeds would produce two identical clusters.
     *
     * @param point_path path of the center-word file
     * @return the usable center words; its size tells how many centers were found
     * @throws IOException if the file cannot be opened or read
     */
    public List<String> loadPointFile(String point_path) throws IOException {
        List<String> center = new ArrayList<String>();
        BufferedReader br = openUtf8(point_path);
        try {
            String line;
            while ((line = br.readLine()) != null) {
                if (wordMap.containsKey(line) && !center.contains(line)) {
                    center.add(line);
                }
            }
        } finally {
            br.close();
        }
        return center;
    }

    /** Opens {@code path} as a buffered reader that decodes UTF-8. */
    private static BufferedReader openUtf8(String path) throws IOException {
        return new BufferedReader(
                new InputStreamReader(new FileInputStream(new File(path)), "UTF-8"));
    }
}

然后是聚类的类:

package kmeans;import java.io.BufferedReader;import java.io.BufferedWriter;import java.io.File;import java.io.FileInputStream;import java.io.FileNotFoundException;import java.io.FileOutputStream;import java.io.IOException;import java.io.InputStreamReader;import java.io.OutputStreamWriter;import java.util.ArrayList;import java.util.Collections;import java.util.HashMap;import java.util.Iterator;import java.util.List;import java.util.Map;import java.util.Map.Entry;import org.apache.commons.cli.CommandLine;import org.apache.commons.cli.CommandLineParser;import org.apache.commons.cli.HelpFormatter;import org.apache.commons.cli.Options;import org.apache.commons.cli.ParseException;import org.apache.commons.cli.PosixParser;public class FuzzyKmeans { private HashMap<String, double[]> wordMap = null;    private int iter;    private Classes[] cArray = null;    public static HashMap<Integer,String> wordcenter=new HashMap<Integer,String>();    //total 659624 words each is a 200 vector//args[0] is the word vectors csv file//args[1] is the output file //args[2] is the cluster number//args[3] is the iterator number    public static void main(String[] args) throws IOException, ParseException {        String source_path;    String output_path;    int cluster_num = 10;    int iterator_num = 10;    double m=1.5;        String point_path = null;         Options options = new Options();           options.addOption("h", false, "help"); //参数不可用         options.addOption("i", true, "input file path"); //参数可用              options.addOption("o", true, "output file path"); //参数可用          options.addOption("c", true, "cluster number, default 10"); //参数可用          options.addOption("x", true, "iterator number, default 10"); //参数可用          options.addOption("p", true, "the center point"); //参数可用         options.addOption("m", true, "the parameter for fuzzy kmeans"); //参数可用                  CommandLineParser parser = new PosixParser();           CommandLine cmd = parser.parse(options, args);    
          if (cmd.hasOption("i"))           {           source_path = cmd.getOptionValue("i");           }else{         HelpFormatter formatter = new HelpFormatter();               formatter.printHelp( "help", options );              return;         }                  if (cmd.hasOption("o"))           {           output_path = cmd.getOptionValue("o");           }else{         HelpFormatter formatter = new HelpFormatter();               formatter.printHelp( "help", options );              return;         }            if (cmd.hasOption("c"))           {           cluster_num = Integer.parseInt(cmd.getOptionValue("c"));           }         if (cmd.hasOption("m"))           {           m = Double.parseDouble(cmd.getOptionValue("m"));           }                  if (cmd.hasOption("x"))           {           iterator_num = Integer.parseInt(cmd.getOptionValue("x"));           }         if (cmd.hasOption("p"))           {           point_path = cmd.getOptionValue("p");           }                  if (cmd.hasOption("h"))           {               HelpFormatter formatter = new HelpFormatter();               formatter.printHelp( "help", options );          }                 Word2VEC vec = new Word2VEC();        vec.loadVectorFile(source_path);        System.out.println("load data ok!");                        List<String> center=new ArrayList<String>();        if(point_path!=null){        center=vec.loadPointFile(point_path);        if(cluster_num<center.size()){        cluster_num=center.size();        }                }                        FuzzyKmeans fuzzyKmeans = new FuzzyKmeans(vec.getWordMap(), cluster_num,iterator_num);        Classes[] explain = fuzzyKmeans.explain(point_path,m,center);                File fw = new File(output_path);        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fw), "UTF-8"));            for (int i = 0; i < explain.length; i++) {            List<Entry<String, Double>> result=explain[i].getMember();     
       StringBuffer buf = new StringBuffer();            for (int j = 0; j < result.size(); j++) {            buf.append(i+"\t"+wordcenter.get(i)+"\t"+result.get(j).getKey()+"\t"+String.format("%.6f", result.get(j).getValue())+"\n");            }            bw.write(buf.toString());            bw.flush();        }        bw.close();                for(int i=0;i<wordcenter.size();i++){        System.out.println(i+"\t"+wordcenter.get(i));        }    }    public FuzzyKmeans(HashMap<String, double[]> wordMap, int clcn, int iter) {        this.wordMap = wordMap;        this.iter = iter;        cArray = new Classes[clcn];    }    public Classes[] explain(String point_path,double m,List<String> center) throws IOException, FileNotFoundException {    Iterator<Entry<String, double[]>> iterator = wordMap.entrySet().iterator();    //cluster number is the same as the center point number    if(cArray.length==center.size()){    String word="";    for (int i = 0; i < cArray.length; i++) {        word=center.get(i);        cArray[i] = new Classes(i, wordMap.get(word));        wordcenter.put(i, word);        System.out.println(new String(word.getBytes("UTF-8")));         }    }        else{    if(point_path==null){        for (int i = 0; i < cArray.length; i++) {            Entry<String, double[]> next = iterator.next();            cArray[i] = new Classes(i, next.getValue());        }    }    else{    String word="";    File f = new File(point_path);    BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(f), "UTF-8"));    for (int i = 0; i < cArray.length; i++) {        word=br.readLine();        if(wordMap.containsKey(word)){        cArray[i] = new Classes(i, wordMap.get(word));        wordcenter.put(i, word);        System.out.println(new String(word.getBytes("UTF-8")));        }        else{        Entry<String, double[]> next = iterator.next();        cArray[i] = new Classes(i, next.getValue());        wordcenter.put(i, next.getKey());        }      
   }    br.close();    }    }                iterator = wordMap.entrySet().iterator();    HashMap<Integer,String> num_wordmap=new HashMap<Integer,String>();            HashMap<Integer,double[]> num_vecmap=new HashMap<Integer,double[]>();            //put word to the map            int count=0;            while (iterator.hasNext()) {            Entry<String, double[]> next = iterator.next();            num_wordmap.put(count, next.getKey());            num_vecmap.put(count, next.getValue());            count++;            }            //begin iterator step        for (int i = 0; i < iter; i++) {            for (Classes classes : cArray) {                classes.clean();            }                        double u[][]=new double[cArray.length][count];                        int cnt = 0;            int num=0;                        while (num<count) {            if(cnt % 10000 ==0)            {            System.out.println("Iter: "+i+"\tword:"+(cnt));            }                            double tempScore;                double d_sum=0;                double temp_sum=0;                int newid=0;                int flag=0;                                //calculate the total distances                for (Classes classes : cArray) {                 tempScore = classes.distance(num_vecmap.get(num));                 if(tempScore==0.0){                 flag=1;                 newid=classes.id;                 break;                 }                 temp_sum=Math.pow(1/tempScore, 1/(m-1));                 d_sum+=temp_sum;                }                if(flag==1){                 for (Classes classes : cArray) {                 u[classes.id][num]=0;                 cArray[classes.id].putValue(num_wordmap.get(num), 0);                 }                 u[newid][num]=1;                 cArray[newid].putValue(num_wordmap.get(num), 1);                                }                else{                //cArray is the cluster center point                for (Classes 
classes : cArray) {                //calculate the distance between the point and the center                    tempScore = classes.distance(num_vecmap.get(num));                    u[classes.id][num]=1/(Math.pow(tempScore,1/(m-1))*d_sum);//                    System.out.println(num+" to "+classes.id+" distances:" +tempScore);                    //put the num and its probability                    cArray[classes.id].putValue(num_wordmap.get(num), u[classes.id][num]);                }                }            cnt++;            num++;                                                }            System.out.println("Iter:"+i+"\tfinished\tword:"+(cnt));            for (Classes classes : cArray) {            classes.updateCenter(num_vecmap,u,count,m);              }            System.out.println("iter " + i + " ok!");        }        return cArray;    }        public static class Classes {        private int id;        private double[] center;        public Classes(int id, double[] center) {            this.id = id;            this.center = center.clone();        }        Map<String, Double> values = new HashMap<String,Double>();        //calculate the distance between point and center        public double distance(double[] value) {            double sum = 0;            for (int i = 0; i < value.length; i++) {                sum += (center[i] - value[i])*(center[i] - value[i]) ;            }            return sum ;        }                //put word and its probability        public void putValue(String word, double score) {            values.put(word, score);        }        public void updateCenter(HashMap<Integer, double[]> num_vecmap,double[][] u,int count,double m) {            for (int i = 0; i < center.length; i++) {                center[i] = 0;            }            double[] value = null;                        for(int j=0;j<count;j++){            value = num_vecmap.get(j);            for (int i = 0; i < value.length; i++) {                    center[i] 
+=Math.pow(u[id][j],m) *value[i];            }            }            double sum=0;            for(int j=0;j<count;j++){            sum+=Math.pow(u[id][j],m);            }                        for (int i = 0; i < center.length; i++) {                center[i] = center[i] / sum;            }        }        public void clean() {            values.clear();        }                public List<Entry<String, Double>> getMember() {            List<Map.Entry<String, Double>> arrayList = new ArrayList<Map.Entry<String, Double>>(                values.entrySet());            int count=arrayList.size();            if(count<=0){            return Collections.emptyList() ;          }            return arrayList;        }    }}


0 0