【Java实现KNN算法】KNN(k邻近)详解与java实现

来源:互联网 发布:淘宝面膜代理 编辑:程序博客网 时间:2024/05/01 00:17

            1.KNN算法

       1.KNN算法

        KNN法最初由Cover 和Hart 于1968 年提出, 是一个理论上比较成熟的方法。KNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的K(K=1,2,3…,n,其中,n<=D)个样本的类别来决定待分样本所属的类别。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合

     2.步骤

      KNN算法是比较简单的,理解起来也不难,具体步骤如下图:

      上图中D表示训练集,其中P表示特征值Ti所描述的的对象(或者类型);Xik表示Xi的第k个特征值(Wk类似);GroupBy()对p进行分类统计(小编借用SQL语言里面的关键字,哈哈)。哦!对了,k的取值是有学问,具体如何取,取多少,就要你针对你研究的问题多试验几次,或者你是个老司机,早已摸清待研究问题的套路。大概思路就是上面这5步,大家可以一步两步,一步两步,是魔鬼的步伐浪起来~~~~~

        3.java实现

       根据上述的五个步骤,分步实现。在没开始前,首先要对研究对象进行抽象,具体如下:
public class KNNnode implements Comparable<KNNnode>{/** * 实现comparable接口重写compareTo()方法 * 目的:方便存放KNNnode对象的List进行排序,排序的目标属性为l(即与待测点距离) */float x1,x2;  //特征值String type;  //特征值对应的类型double l ;     //与待预测点的距离public float getX1() {return x1;}public void setX1(float x1) {this.x1 = x1;}public float getX2() {return x2;}public void setX2(float x2) {this.x2 = x2;}public String getType() {return type;}public void setType(String type) {this.type = type;}public double getL() {return l;}public void setL(double l) {this.l = l;}@Overridepublic String toString() {return "KNNnode [x1=" + x1 + ", x2=" + x2 + ", type=" + type + ", l="+ l + "]";}//从大到小排列@Overridepublic int compareTo(KNNnode o) {// TODO Auto-generated method stubif(l>o.getL())return -1;if(l<o.getL())return 1;return 0;}}
        上面抽象的类,描述的对象有两个特征值,即X1,X2,type是特征值对应的类型。现在就可以用KNNnode来描述或者用历史数据进行生成对象了。基础已经打好了,直接进入五步法了:
          ①加载样本数据进行训练
         /** * 从txt文本中读取KNNnode所需数据并存放在List中 * @param url  txt文本存放路径 * @return   */public List<KNNnode> ReadKNNnodeFromFile(String url){List<KNNnode> node = new ArrayList<KNNnode>();String st = "";File file = new File(url);if(file.isFile()){try{BufferedReader reader = new BufferedReader(new FileReader(file));while((st=reader.readLine())!=null){//用空格对字符串进行分割,s+可以匹配多个空格String val[] = st.split("\\s+");if(val.length==3){KNNnode knNnode = new KNNnode();knNnode.setX1(Float.parseFloat(val[0]));knNnode.setX2(Float.parseFloat(val[1]));knNnode.setType(val[2]);node.add(knNnode);}}reader.close();}catch(IOException e){e.printStackTrace();}catch (Exception e) {// TODO: handle exceptione.printStackTrace();}}else{System.out.println("文件不存在!");}return node;}
        txt文本存放的数据是以每个对象里面的属性为一行,如下图:

        使用IO流从文本中获取数据来生成对象非常方便,只需要按照上面这种模式,一个循环就可以依次读出数据生成KNNnode类并存放在List里面。
        ②确定K以及待分类(确定/预测)的对象。这个比较简单。
                final int K = 3;                KNNnode kn1 = new KNNnode();kn1.setX1(22);kn1.setX2(17);
        ③计算距离(相似度)
/** * 按照欧式距离公式,计算历史数据与待预测对象之间的距离 * @param node l;训练集样本对象 * @param kn1    待预测 * @return */public List<KNNnode> calcul(List<KNNnode> node,KNNnode kn1){for(int i=0;i<node.size();i++){KNNnode kn2 = node.get(i);kn2.setL(Math.sqrt(Math.pow(kn1.getX1()-kn2.getX1(), 2)+Math.pow(kn1.getX2()-kn2.getX2(), 2)));}return node;}
         在抽象对象时,有一个属性l是专门为了存放训练样本与带预测对象之间的距离或者相似度。
         ④返回距离(相似度)最近(高)的K个对象
/** * 对k个KNNnode的类型进行分类统计 * 使用Map,借助map键值对的存储方式,所以非常方便 * @param node * @param k * @return */public Map<String,List<KNNnode>> result(List<KNNnode> node,int k){Map<String,List<KNNnode>> knnmap = new HashMap<String,List<KNNnode>>();System.out.println("---------------------K个最小的KNNnode对象-------------------");for(int i=0;i<node.size();i++){System.out.println(node.get(i).toString());}for(int i=0;i<k;i++){String type = node.get(i).getType().trim();if(knnmap.containsKey(type)){knnmap.get(type).add(node.get(i));Collections.sort(knnmap.get(type));}else{knnmap.put(type, new ArrayList<KNNnode>());knnmap.get(type).add(node.get(i));}}return knnmap;}
          其次,待预测的type值就等于频数高的type。
public static void main(String args[]) {final String file_url = "D:\\knn.txt";final int K = 5;String type = "";List<KNNnode> node = new ArrayList<KNNnode>();// 待归类的nodeKNNnode kn1 = new KNNnode();kn1.setX1(33);kn1.setX2(12);KNNprocess knNprocess = new KNNprocess();node = knNprocess.ReadKNNnodeFromFile(file_url);node = knNprocess.calcul(node, kn1);node = knNprocess.getnodeDESC(node, K);double l = node.get(0).getL();int s = 0;Map<String, List<KNNnode>> knn = knNprocess.result(node, K);for(Map.Entry<String, List<KNNnode>> en:knn.entrySet()){int s1= en.getValue().size();if(s1>s){l = ((KNNnode)(en.getValue().get(s1-1))).getL();s = s1;type = en.getKey();}elseif(s1==s&l>((KNNnode)(en.getValue().get(s1-1))).getL()){l = ((KNNnode)(en.getValue().get(s1-1))).getL();s = s1;type = en.getKey();}}System.out.println("---------------------------预测结果-------------------------");kn1.setType(type);System.out.println(kn1.toString());}
        综合上述5步,已经实现了待预测对象的归类了问题了。来看下我跑程序的结果:


          OK!!KNN算法使用java已经实现。

         4.KNN可以解决的问题

        当然,你看该算法肯定是要用的,怎么用你也知道。那先来看下学术界用KNN算法都解决什么问题。



         上图是我从知网上搜索截的图(冰山一角),KNN算法应用十分的广泛,而且还对KNN进行了改进,使用了权重进行算法优化,或者和其他算法进行了组合实用,是目前比较简单而且应用比较广的有监督的机器学习方法。
demo下载:KNN算法java实现demo

1 0
原创粉丝点击