KNN分类算法的Java实现

来源:互联网 发布:女朋友 文艺青年 知乎 编辑:程序博客网 时间:2024/05/22 10:32

K最近邻（KNN）分类算法思想

KNN算法的思想可以总结为：在训练集中数据和标签均已知的情况下，输入测试数据，将测试数据的特征与训练集中对应的特征逐一比较，找出训练集中与之最相似的前K个数据，则该测试数据的类别就是这K个数据中出现次数最多的那个类别。其算法描述如下：

1)计算测试数据与各个训练数据之间的距离;

2)按照距离的递增关系进行排序;

3)选取距离最小的K个点;

4)确定前K个点所在类别的出现频率;

5)返回前K个点中出现频率最高的类别作为测试数据的预测分类。

Java代码实现

KNN.java代码

public class KNN {

          public static void main(String[] args) {

                

                // 一、输入所有已知点

                List<Point>dataList =creatDataSet();

                // 二、输入未知点

                Point x = new Point(5, 1.2, 1.2);

                // 三、计算所有已知点到未知点的欧式距离,并根据距离对所有已知点排序

                CompareClass compare = new CompareClass();

                Set<Distance> distanceSet = new TreeSet<Distance>(compare);

                for (Pointpoint :dataList) {

                     distanceSet.add(new Distance(point.getId(),x.getId(), oudistance(point,

                             x)));

                }

                // 四、选取最近的k个点

                double k = 5;

                

                /**

                 * 五、计算k个点所在分类出现的频率

                 */

                // 1、计算每个分类所包含的点的个数

                List<Distance> distanceList= new ArrayList<Distance>(distanceSet);

                Map<String, Integer> map = getNumberOfType(distanceList, dataList, k);

                

                // 2、计算频率

                Map<String, Double> p = computeP(map, k);

                

                x.setType(maxP(p));

                System.out.println("未知点的类型为:"+x.getType());

            }

 

            // 欧式距离计算

            public static double oudistance(Pointpoint1, Pointpoint2) {

                double temp = Math.pow(point1.getX() -point2.getX(), 2)

                         + Math.pow(point1.getY() -point2.getY(), 2);

                return Math.sqrt(temp);

            }

 

            // 找出最大频率

            public static String maxP(Map<String,Double> map) {

                String key = null;

                double value = 0.0;

                for (Map.Entry<String, Double> entry :map.entrySet()) {

                     if (entry.getValue() >value) {

                         key = entry.getKey();

                         value = entry.getValue();

                     }

                }

                return key;

            }

 

            // 计算频率

            public static Map<String,Double> computeP(Map<String, Integer> map,

                     double k) {

                Map<String, Double> p = new HashMap<String, Double>();

                for (Map.Entry<String, Integer> entry :map.entrySet()) {

                     p.put(entry.getKey(),entry.getValue() / k);

                }

                return p;

            }

 

            // 计算每个分类包含的点的个数

            public static Map<String,Integer> getNumberOfType(

                     List<Distance> listDistance, List<Point>listPoint, double k) {

                Map<String, Integer> map = new HashMap<String, Integer>();

                int i = 0;

                System.out.println("选取的k个点,由近及远依次为:");

                for (Distance distance : listDistance) {

                     System.out.println("id" +distance.getId() + ",距离为:"

                             + distance.getDisatance());

                     long id = distance.getId();

                     // 通过id找到所属类型,并存储到HashMap

                     for (Point point : listPoint) {

                         if (point.getId() ==id) {

                             if (map.get(point.getType()) != null)

                                map.put(point.getType(),map.get(point.getType()) + 1);

                             else {

                                 map.put(point.getType(), 1);

                             }

                         }

                     }

                     i++;

                     if (i >= k)

                         break;

                }

                return map;

            }

            

            public static ArrayList<Point> creatDataSet(){

                

                Point point1 = new Point(1, 1.0, 1.1, "A");

                Point point2 = new Point(2, 1.0, 1.0, "A");

                Point point3 = new Point(3, 1.0, 1.2, "A");

                Point point4 = new Point(4, 0, 0, "B");

                Point point5 = new Point(5, 0, 0.1, "B");

                Point point6 = new Point(6, 0, 0.2, "B");

                

                ArrayList<Point>dataList = new ArrayList<Point>();

                dataList.add(point1);

                dataList.add(point2);

                dataList.add(point3);

                dataList.add(point4);

                dataList.add(point5);

                dataList.add(point6);

                

                return dataList;

            }

}

 

KNN类中用到的Point类、Distance类以及比较器CompareClass类如下：

Point

 

/**
 * A point in the plane with a numeric id and an optional type label.
 */
public class Point {

    private long id;

    private double x;

    private double y;

    // Type label; stays null for points built with the 3-argument constructor until setType is called.
    private String type;

    /**
     * Creates a point without a type.
     *
     * @param id identifier of the point
     * @param x x coordinate
     * @param y y coordinate
     */
    public Point(long id, double x, double y) {
        this(id, x, y, null);
    }

    /**
     * Creates a point with a type.
     *
     * @param id identifier of the point
     * @param x x coordinate
     * @param y y coordinate
     * @param type type label, may be null
     */
    public Point(long id, double x, double y, String type) {
        this.id = id;
        this.x = x;
        this.y = y;
        this.type = type;
    }

    public long getId() {
        return id;
    }

    public void setId(long id) {
        this.id = id;
    }

    public double getX() {
        return x;
    }

    public void setX(double x) {
        this.x = x;
    }

    public double getY() {
        return y;
    }

    public void setY(double y) {
        this.y = y;
    }

    public String getType() {
        return type;
    }

    public void setType(String type) {
        this.type = type;
    }
}

 

Distance

/**
 * The distance between one known (training) point and one unknown (query) point.
 */
public class Distance {

    // id of the known point
    private long id;

    // id of the unknown point
    private long nid;

    // distance between the two points
    private double distance;

    /**
     * @param id id of the known point
     * @param nid id of the unknown point
     * @param disatance distance between the two points
     */
    public Distance(long id, long nid, double disatance) {
        this.id = id;
        this.nid = nid;
        this.distance = disatance;
    }

    public long getId() {
        return id;
    }

    public void setId(long id) {
        this.id = id;
    }

    public long getNid() {
        return nid;
    }

    public void setNid(long nid) {
        this.nid = nid;
    }

    /**
     * Returns the stored distance. The misspelled name is kept for existing callers; see
     * {@link #getDistance()}.
     */
    public double getDisatance() {
        return distance;
    }

    /** Returns the stored distance (correctly spelled alias of {@link #getDisatance()}). */
    public double getDistance() {
        return distance;
    }

    public void setDisatance(double disatance) {
        this.distance = disatance;
    }
}

比较器CompareClass

import java.util.Comparator;

//比较器类

public class CompareClass implements Comparator<Distance>{

 

    public int compare(Distance d1, Distance d2) {

        return d1.getDisatance()>d2.getDisatance()?20 : -1;

    }

 

}

 

如果需要把一个 Map<String,Double>（例如类型到距离的映射 typeAndDistance）按照 distance 排序，也就是按照 Map 的 value 排序，可以参考如下写法：

1.  public class Testing {  

2.   

3.      public static void main(String[] args) {  

4.   

5.          HashMap<String,Double> map = new HashMap<String,Double>();  

6.         ValueComparator bvc =  new ValueComparator(map);  

7.          TreeMap<String,Double> sorted_map = new TreeMap<String,Double>(bvc);  

8.   

9.          map.put("A",99.5);  

10.        map.put("B",67.4);  

11.         map.put("C",67.4);  

12.        map.put("D",67.3);  

13.   

14.        System.out.println("unsorted map: "+map);  

15.   

16.        sorted_map.putAll(map);  

17.   

18.        System.out.println("results: "+sorted_map);  

19.     }  

20.}  

/**
 * Orders keys by their value in {@code base}, largest value first.
 *
 * <p>Keys with equal values fall back to the keys' natural (String) order, so {@code compare}
 * returns 0 only for equal keys. This keeps a {@code TreeMap} using it consistent with
 * {@code equals}: distinct keys with equal values are not merged, and {@code get} /
 * {@code containsKey} find existing keys.
 *
 * <p>Every key compared must be present in {@code base} with a non-null value; otherwise
 * unboxing throws {@code NullPointerException}.
 */
class ValueComparator implements Comparator<String> {

    Map<String, Double> base;

    public ValueComparator(Map<String, Double> base) {
        this.base = base;
    }

    @Override
    public int compare(String a, String b) {
        // Arguments are swapped so larger values sort first.
        int byValue = Double.compare(base.get(b), base.get(a));
        if (byValue != 0) {
            return byValue;
        }
        return a.compareTo(b);
    }
}

 

 

原创粉丝点击