机器学习算法——KNN分类算法介绍以及Java实现

来源:互联网 发布:太原科大网络 编辑:程序博客网 时间:2024/05/17 00:13

KNN分类算法介绍

一、什么是分类

分类是指通过对大量的训练样本进行提取和分析,训练出用来分类的规则,即分类器或者分类模型,最终判断未知样本的类别。常见的分类算法有:决策树(ID3和C4.5),朴素贝叶斯,人工神经网络 (Artificial Neural Networks,ANN),k-近邻(kNN),支持向量机(SVM),基于关联规则的分类,Adaboosting方法等等。这篇文章主要介绍KNN算法。

二、KNN算法原理

1 原理

KNN算法又称为K近邻算法,根据训练样本和样本类别,计算与待分类样本相似度最大的K个训练点,然后对这K个训练点进行投票并排序,选择投票数最高的样本类别作为待分类数据的类别。这里的相似性度量可采用欧氏距离、马氏距离,余弦相似度等等。K为人为设定,一般选择奇数。

2 算法优点

1、算法简单、有效,通常用于文本分类。
2、重新训练的代价较低(类别体系的变化和训练集的变化,在Web环境和电子商务应用中是很常见的)。
3、该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分。

3 算法缺点

1、K值得选择对算法精度影响较大。
2、依赖于相似性度量的优劣。
3、当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个 邻居中大容量类的样本占多数。
4、计算量较大。

4 KNN算法实例

下图中有两种类型的样本数据,一类是蓝色的正方形,另一类是红色的三角形,中间那个绿色的圆形是待分类数据;

如果K=3,那么离绿色点最近的有2个红色的三角形和1个蓝色的正方形,这三个点进行投票,于是绿色的待分类点就属于红色的三角形。

如果K=5,那么离绿色点最近的有2个红色的三角形和3个蓝色的正方形,这五个点进行投票,于是绿色的待分类点就属于蓝色的正方形。

三、KNN算法描述

KNN算法的步骤可以描述为:

1、计算出样本数据和待分类数据的距离;

2、为待分类数据选择K个与其距离最小的样本;

3、统计出K个样本中大多数样本所属的分类;

4、这个分类就是待分类数据所属的分类。

四、KNN算法Java实现

// TODO Auto-generated method stub  //首先读取训练样本和测试样本,用map<String,map<word,TF>>保存测试集和训练集,注意训练样本的类目信息也得保存,  //然后遍历测试样本,对于每一个测试样本去计算它与所有训练样本的相似度,相似度保存入map<String,double>有  //序map中去,然后取前K个样本,针对这k个样本来给它们所属的类目计算权重得分,对属于同一个类目的权重求和进而得到  //最大得分的类目,就可以判断测试样例属于该类目下,K值可以反复测试,找到分类准确率最高的那个值  //!注意要以"类目_文件名"作为每个文件的key,才能避免同名不同内容的文件出现  //!注意设置JM参数,否则会出现JAVA heap溢出错误  //!本程序用向量夹角余弦计算相似度  public static double doProcess(String trainFiles, String testFiles, String kNNResultFile) throws IOException {        System.out.println("开始训练模型:");        File trainSamples = new File(trainFiles);        BufferedReader trainSamplesBR = new BufferedReader(new FileReader(trainSamples));        String line;        String[] lineSplitBlock;        Map<String, TreeMap<String, Double>> trainFileNameWordTFMap = new TreeMap<String, TreeMap<String, Double>>();        TreeMap<String, Double> trainWordTFMap = new TreeMap<String, Double>();        int index1 = 0;        while ((line = trainSamplesBR.readLine()) != null) {            index1++;            lineSplitBlock = line.split(" ");            trainWordTFMap.clear();            for (int i = 1; i < lineSplitBlock.length; i = i + 2) {                trainWordTFMap.put(lineSplitBlock[i], Double.valueOf(lineSplitBlock[i + 1]));            }            TreeMap<String, Double> tempMap = new TreeMap<String, Double>();            tempMap.putAll(trainWordTFMap);            trainFileNameWordTFMap.put(lineSplitBlock[0] + "_" + index1, tempMap);        }        trainSamplesBR.close();        File testSamples = new File(testFiles);        BufferedReader testSamplesBR = new BufferedReader(new FileReader(testSamples));        Map<String, Map<String, Double>> testFileNameWordTFMap = new TreeMap<String, Map<String, Double>>();        Map<String, String> testClassifyCateMap = new TreeMap<String, String>();//分类形成的<文件名,类目>对          Map<String, Double> testWordTFMap = new TreeMap<String, Double>();        int index = 0;        while ((line = testSamplesBR.readLine()) != null) {            index++;            lineSplitBlock = line.split(" ");            testWordTFMap.clear();            for (int i = 1; i < lineSplitBlock.length; i = i + 2) {                testWordTFMap.put(lineSplitBlock[i], Double.valueOf(lineSplitBlock[i + 1]));            }            TreeMap<String, Double> tempMap = new TreeMap<String, Double>();            tempMap.putAll(testWordTFMap);            testFileNameWordTFMap.put(lineSplitBlock[0] + "_" + index, tempMap);        }        testSamplesBR.close();        //下面遍历每一个测试样例计算与所有训练样本的距离,做分类          String classifyResult;        FileWriter testYangliuWriter = new FileWriter(new File("D:\\DataMining\\Title\\yangliuTest.txt"));        FileWriter KNNClassifyResWriter = new FileWriter(kNNResultFile);        Set<Map.Entry<String, Map<String, Double>>> testFileNameWordTFMapSet = testFileNameWordTFMap.entrySet();        for (Iterator<Map.Entry<String, Map<String, Double>>> it = testFileNameWordTFMapSet.iterator(); it.hasNext();) {            Map.Entry<String, Map<String, Double>> me = it.next();            classifyResult = KNNComputeCate(me.getKey(), me.getValue(), trainFileNameWordTFMap, testYangliuWriter);            System.out.println("分类结果为:"+ classifyResult+";正确结果为:"+me.getKey());            KNNClassifyResWriter.append(me.getKey() + " " + classifyResult + "\n");            KNNClassifyResWriter.flush();            testClassifyCateMap.put(me.getKey(), classifyResult);        }        KNNClassifyResWriter.close();        //计算分类的准确率          double righteCount = 0;        Set<Map.Entry<String, String>> testClassifyCateMapSet = testClassifyCateMap.entrySet();        for (Iterator<Map.Entry<String, String>> it = testClassifyCateMapSet.iterator(); it.hasNext();) {            Map.Entry<String, String> me = it.next();            String rightCate = me.getKey().split("_")[0];            if (me.getValue().equals(rightCate)) {                righteCount++;            }        }        testYangliuWriter.close();        return righteCount / testClassifyCateMap.size();    }    /**     * 对于每一个测试样本去计算它与所有训练样本的向量夹角余弦相似度 相似度保存入map<String,double>有序map中去,然后取前K个样本,     * 针对这k个样本来给它们所属的类目计算权重得分,对属于同一个类 目的权重求和进而得到最大得分的类目,就可以判断测试样例属于该     * 类目下。K值可以反复测试,找到分类准确率最高的那个值     *     * @param testWordTFMap 当前测试文件的<单词,词频>向量     * @param trainFileNameWordTFMap 训练样本<类目_文件名,向量>Map     * @param testYangliuWriter     * @return String K个邻居权重得分最大的类目     * @throws IOException     */    public static String KNNComputeCate(            String testFileName,            Map<String, Double> testWordTFMap,            Map<String, TreeMap<String, Double>> trainFileNameWordTFMap, FileWriter testYangliuWriter) throws IOException {        // TODO Auto-generated method stub          HashMap<String, Double> simMap = new HashMap<String, Double>();//<类目_文件名,距离> 后面需要将该HashMap按照value排序          double similarity;        Set<Map.Entry<String, TreeMap<String, Double>>> trainFileNameWordTFMapSet = trainFileNameWordTFMap.entrySet();        for (Iterator<Map.Entry<String, TreeMap<String, Double>>> it = trainFileNameWordTFMapSet.iterator(); it.hasNext();) {            Map.Entry<String, TreeMap<String, Double>> me = it.next();            similarity = computeSim(testWordTFMap, me.getValue());            simMap.put(me.getKey(), similarity);        }        //下面对simMap按照value排序          ByValueComparator bvc = new ByValueComparator(simMap);        TreeMap<String, Double> sortedSimMap = new TreeMap<String, Double>(bvc);        sortedSimMap.putAll(simMap);        //在disMap中取前K个最近的训练样本对其类别计算距离之和,K的值通过反复试验而得          Map<String, Double> cateSimMap = new TreeMap<String, Double>();//K个最近训练样本所属类目的距离之和          double K = 15;        double count = 0;        double tempSim;        Set<Map.Entry<String, Double>> simMapSet = sortedSimMap.entrySet();        for (Iterator<Map.Entry<String, Double>> it = simMapSet.iterator(); it.hasNext();) {            Map.Entry<String, Double> me = it.next();            count++;            String categoryName = me.getKey().split("_")[0];            if (cateSimMap.containsKey(categoryName)) {                tempSim = cateSimMap.get(categoryName);                cateSimMap.put(categoryName, tempSim + me.getValue());            } else {                cateSimMap.put(categoryName, me.getValue());            }            if (count > K) {                break;            }        }        //下面到cateSimMap里面把sim最大的那个类目名称找出来          //testYangliuWriter.flush();          //testYangliuWriter.close();          double maxSim = 0;        String bestCate = null;        Set<Map.Entry<String, Double>> cateSimMapSet = cateSimMap.entrySet();        for (Iterator<Map.Entry<String, Double>> it = cateSimMapSet.iterator(); it.hasNext();) {            Map.Entry<String, Double> me = it.next();            if (me.getValue() > maxSim) {                bestCate = me.getKey();                maxSim = me.getValue();            }        }        return bestCate;    }    /**     * 计算测试样本向量和训练样本向量的相似度     *     * @param testWordTFMap 当前测试文件的<单词,词频>向量     * @param trainWordTFMap 当前训练样本<单词,词频>向量     * @return Double 向量之间的相似度 以向量夹角余弦计算     * @throws IOException     */    public static double computeSim(Map<String, Double> testWordTFMap,            Map<String, Double> trainWordTFMap) {        // TODO Auto-generated method stub          double mul = 0, testAbs = 0, trainAbs = 0;        Set<Map.Entry<String, Double>> testWordTFMapSet = testWordTFMap.entrySet();        for (Iterator<Map.Entry<String, Double>> it = testWordTFMapSet.iterator(); it.hasNext();) {            Map.Entry<String, Double> me = it.next();            if (trainWordTFMap.containsKey(me.getKey())) {                mul += me.getValue() * trainWordTFMap.get(me.getKey());            }            testAbs += me.getValue() * me.getValue();        }        testAbs = Math.sqrt(testAbs);        Set<Map.Entry<String, Double>> trainWordTFMapSet = trainWordTFMap.entrySet();        for (Iterator<Map.Entry<String, Double>> it = trainWordTFMapSet.iterator(); it.hasNext();) {            Map.Entry<String, Double> me = it.next();            trainAbs += me.getValue() * me.getValue();        }        trainAbs = Math.sqrt(trainAbs);        return mul / (testAbs * trainAbs);    }

五、结果展示

KNN算法用于对基建数据的行业分类,所以得到如下分类结果;分类正确率为80%左右,算法还需继续改进。后面的博客会完整介绍KNN用于文本分类的具体处理,敬请期待!!!