KNN分类器 适合数值型分类 java

来源:互联网 发布:淘宝的宝贝详情怎么弄 编辑:程序博客网 时间:2024/06/06 04:10
/**
 * KNN分类器
 * @author ysh 1208706282
 *
 */
public class KNN {
    Set<Integer> labelSet;
    List<Sample> samples;
    /**
     * 样本
     * @author Administrator
     *
     */
    static class Sample{
        int label;
        List<Double> feature;
    }
    static class SortSample implements java.lang.Comparable<SortSample> {
        int index;
        double distance;
        
        @Override
        public int compareTo(SortSample o) {
            // TODO Auto-generated method stub
            if(this.distance<o.distance){
                return 1;
            }
            return -1;
        }
    }
    //加载数据集
    public  void loadData(String path,String regex) throws Exception{
        labelSet = new HashSet<Integer>();
        samples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        String line = null;
        String splits[] = null;
        Sample sample = null;
        while(null != (line=reader.readLine())){
            splits = line.split(regex);
            sample = new Sample();
            sample.feature = new ArrayList<Double>(splits.length-1);
            for(int i=0;i<splits.length-1;i++){
                sample.feature.add(new Double(splits[i]));
            }
            sample.label = Integer.valueOf(splits[splits.length-1]);
            labelSet.add(sample.label);
            samples.add(sample);
        }
        reader.close();
    }
    /**
     * 欧氏距离
     * @param src
     * @param dst
     * @return
     */
    public double euclideanDistance(Sample src,Sample dst){
        double distance = 0;
        for(int i=0;i<src.feature.size();i++){
            distance += (src.feature.get(i)-dst.feature.get(i))*(src.feature.get(i)-dst.feature.get(i));
        }
        return Math.sqrt(distance);
    }
    /**
     * 多数投票
     * @param list
     * @return
     */
    public int moreVote(List<Sample> list){
        int label = -1;
        Map<Integer,Integer> labelMap = new HashMap<Integer,Integer>();
        for(Sample sample:list){
            if(null == labelMap.get(sample.label)){
                labelMap.put(sample.label, 1);
            }else{
                labelMap.put(sample.label, labelMap.get(sample.label)+1);
            }
        }
        int max = -1;
        Iterator<Entry<Integer,Integer>> iter = labelMap.entrySet().iterator();
        Entry<Integer,Integer> entry = null;
        while(iter.hasNext()){
            entry = iter.next();
            if(entry.getValue()>max){
                max = entry.getValue();
                label = entry.getKey();
            }
        }
        return label;
    }
    /**
     * 样本分类     可做优化
     * @param sample
     * @param k
     * @return
     */
    public double classify(Sample sample,int k){
        if(samples.size()<k){
            return moreVote(samples);
        }
        
        Queue<SortSample> qnear = new PriorityQueue<SortSample>(k);
        SortSample s = null;
        for(int i=0;i<k;i++){
            s = new SortSample();
            s.index = i;
            s.distance = euclideanDistance(sample,samples.get(i));
            qnear.add(s);
        }
        for(int i=k;i<samples.size();i++){
            s = new SortSample();
            s.index = i;
            s.distance = euclideanDistance(sample,samples.get(i));
            qnear.add(s);
            qnear.poll();                 //保持队列只有k个近邻
        }
        
        List<Sample> listNear = new ArrayList<Sample>(k);
        for(SortSample temp:qnear){
            listNear.add(samples.get(temp.index));
        }
        return moreVote(listNear);
    }
    /**
     * @param args
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        // TODO Auto-generated method stub
        KNN knn = new KNN();
        knn.loadData("F:/2016-contest/abalone.csv",",");
        int count = 0;
        for(KNN.Sample sample:knn.samples){
            int ret = (int)knn.classify(sample, 11);
            if(ret == sample.label){
                count++;
            }
            System.out.println(ret+" "+sample.label);
        }
        System.out.println("right rate: "+(count*1.0/knn.samples.size()));
    }

}

0 0
原创粉丝点击