KNN

来源:互联网 发布:蒙古表情 软件 编辑:程序博客网 时间:2024/04/30 00:37

KNN(K-NearestNeighbor)是分类算法中最简单的一种,用来计算特征的相似性。以电影评分系统为例,每个电影都会有一个评分向量,每部电影也都有一个类标签-动作、爱情等。通过KNN算法可以计算出不同电影之间的评分向量的距离,以此来判断不同电影间的相似性,当有一部新电影进来时,就可以将其归为最相似电影所属的那一类。本文首先介绍KNN(K-近邻)算法的原理,然后给出其实现的伪代码,最后结合具体实例,给出java实现代码。


  • KNN算法的原理:

这里写图片描述

抽象为如下图像,判断“?”属于哪个类标签。首先找离它最近的k个类标签,然后看这k各类标签中哪个类别出现的频率最高,根据少数服从多数的原则,“?”就属于哪个类别。如下图,当K取4时,离其最近的4的标签是一个蓝方块,一个绿圆和两个红三角,这4各类别中红三角出现频率最高,那么“?”就应该属于红三角类别。同理,K若取5,“?”应属于蓝色方块类别。由此可见,不同的K值,会有不同的结果,我们要谨慎选择K值,可以通过交叉验证选择效果最好的K值。(李航的《统计学导论》有说明,K取小了会产生过拟合,取大了会欠拟合)

这里写图片描述


  • 伪代码:

目标:判断测试数据集中的类标签。

  1. 计算已知类别数据集中的向量与当前向量之间的距离
  2. 按照距离递增次序排序
  3. 选取与当前点距离最小的K个点
  4. 确定前k个点所在类别的出现频率
  5. 返回前k各点中出现频率最高的类别作为当前点的预测分类。

  • 可运行java代码:
package knn;import java.util.*;public class KNN    {    // the data    static double[][] instances = {    {0.35,0.91,0.86,0.42,0.71},    {0.21,0.12,0.76,0.22,0.92},    {0.41,0.58,0.73,0.21,0.09},    {0.71,0.34,0.55,0.19,0.80},    {0.79,0.45,0.79,0.21,0.44},    {0.61,0.37,0.34,0.81,0.42},    {0.78,0.12,0.31,0.83,0.87},    {0.52,0.23,0.73,0.45,0.78},    {0.53,0.17,0.63,0.29,0.72},    };    private static String findMajorityClass(String[] array)    {    Set<String> h = new HashSet<String>(Arrays.asList(array));//ss原先是字符串数组    String[] uniqueValues = h.toArray(new String[0]);    int[] counts = new int[uniqueValues.length];    for (int i = 0; i < uniqueValues.length; i++) {    for (int j = 0; j < array.length; j++) {    if(array[j].equals(uniqueValues[i])){    counts[i]++;    }    }    }    for (int i = 0; i < uniqueValues.length; i++)    System.out.println(uniqueValues[i]);    for (int i = 0; i < counts.length; i++)    System.out.println(counts[i]);//考虑了出现多类别频率相同的情况,这部分是在找最大频率。    int max = counts[0];    for (int counter = 1; counter < counts.length; counter++) {    if (counts[counter] > max) {    max = counts[counter];    }    }    System.out.println("max # of occurences: "+max);    // how many times max appears    //we know that max will appear at least once in counts    //so the value of freq will be 1 at minimum after this loop    int freq = 0;    for (int counter = 0; counter < counts.length; counter++) {    if (counts[counter] == max) {    freq++;    }    }    int index = -1;    if(freq==1){    for (int counter = 0; counter < counts.length; counter++) {    if (counts[counter] == max) {    index = counter;    break;    }    }    return uniqueValues[index];//返回类别    }     else{//we have multiple modes    int[] ix = new int[freq];//array of indices of modes    System.out.println("multiple majority classes: "+freq+" classes");    int ixi = 0;    for (int counter = 0; counter < counts.length; counter++) {    if (counts[counter] == max) {    ix[ixi] = counter;//save index of each max count value    ixi++; // increase index of ix array    }    }    for (int counter = 0; counter < ix.length; counter++)    System.out.println("class index: "+ix[counter]);    //now choose one at random    Random generator = new Random();    //get random number 0 <= rIndex < size of ix    int rIndex = generator.nextInt(ix.length);    System.out.println("random index: "+rIndex);    int nIndex = ix[rIndex];    //return unique value at that index    return uniqueValues[nIndex];    }    }    public static void main(String args[]){    int k = 6;// # of neighbours    //list to save city data    List<City> cityList = new ArrayList<City>();    //list to save distance result    List<Result> resultList = new ArrayList<Result>();    // add city data to cityList    cityList.add(new City(instances[0],"London"));    cityList.add(new City(instances[1],"Leeds"));    cityList.add(new City(instances[2],"Liverpool"));    cityList.add(new City(instances[3],"London"));    cityList.add(new City(instances[4],"Liverpool"));    cityList.add(new City(instances[5],"Leeds"));    cityList.add(new City(instances[6],"London"));    cityList.add(new City(instances[7],"Liverpool"));    cityList.add(new City(instances[8],"Leeds"));    //data about unknown city    double[] query = {0.65,0.78,0.21,0.29,0.58};    //find disnaces for循环结束后,resultList里保存的是测试集到各标签的距离    for(City city : cityList){//for循环标签    double dist = 0.0;    for(int j = 0; j < city.cityAttributes.length; j++){//属性就是那些值    dist += Math.pow(city.cityAttributes[j] - query[j], 2) ;//pow的用法    //System.out.print(city.cityAttributes[j]+" ");    }    double distance = Math.sqrt( dist );    resultList.add(new Result(distance,city.cityName));//是result类型的,相当于python中的字典。java种也有字典吧    //System.out.println(distance);    }    //System.out.println(resultList);    Collections.sort(resultList, new DistanceComparator());//sort     String[] ss = new String[k];//取top-6    for(int x = 0; x < k; x++){    System.out.println(resultList.get(x).cityName+ " .... " + resultList.get(x).distance);    //get classes of k nearest instances (city names) from the list into an array    ss[x] = resultList.get(x).cityName;//要计算类别频率    }    String majClass = findMajorityClass(ss);//传递过去的是标签    System.out.println("Class of new instance is: "+majClass);    }//end main    //simple class to model instances (features + class)    static class City {    double[] cityAttributes;    String cityName;    public City(double[] cityAttributes, String cityName){    this.cityName = cityName;    this.cityAttributes = cityAttributes;    }    }    //simple class to model results (distance + class)    static class Result {    double distance;    String cityName;    public Result(double distance, String cityName){    this.cityName = cityName;    this.distance = distance;    }    }    //simple comparator class used to compare results via distances    static class DistanceComparator implements Comparator<Result> {    @Override    public int compare(Result a, Result b) {    return a.distance < b.distance ? -1 : a.distance == b.distance ? 0 : 1;    }    }    }
0 0
原创粉丝点击