AdaBoost装袋提升算法

来源:互联网 发布:吉林大学网络教育查询 编辑:程序博客网 时间:2024/04/29 03:55

参开资料:http://blog.csdn.net/haidao2009/article/details/7514787
更多挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm

介绍

在介绍AdaBoost算法之前,需要了解一个类似的算法,装袋算法(bagging),bagging是一种提高分类准确率的算法,通过给定组合投票的方式,获得最优解。比如你生病了,去n个医院看了n个医生,每个医生给你开了药方,最后的结果中,哪个药方的出现的次数多,那就说明这个药方就越有可能性是最由解,这个很好理解。而bagging算法就是这个思想。

算法原理

而AdaBoost算法的核心思想还是基于bagging算法,但是他又一点点的改进,上面的每个医生的投票结果都是一样的,说明地位平等,如果在这里加上一个权重,大城市的医生权重高点,小县城的医生权重低,这样通过最终计算权重和的方式,会更加的合理,这就是AdaBoost算法。AdaBoost算法是一种迭代算法,只有最终分类误差率小于阈值算法才能停止,针对同一训练集数据训练不同的分类器,我们称弱分类器,最后按照权重和的形式组合起来,构成一个组合分类器,就是一个强分类器了。算法的只要过程:

1、对D训练集数据训练处一个分类器Ci

2、通过分类器Ci对数据进行分类,计算此时误差率

3、把上步骤中的分错的数据的权重提高,分对的权重降低,以此凸显了分错的数据。为什么这么做呢,后面会做出解释。

完整的adaboost算法如下


最后的sign函数是符号函数,如果最后的值为正,则分为+1类,否则即使-1类。

我们举个例子代入上面的过程,这样能够更好的理解。

adaboost的实现过程:

  图中,“+”和“-”分别表示两种类别,在这个过程中,我们使用水平或者垂直的直线作为分类器,来进行分类。

  第一步:

  根据分类的正确率,得到一个新的样本分布D,一个子分类器h1

  其中划圈的样本表示被分错的。在右边的途中,比较大的“+”表示对该样本做了加权。

算法最开始给了一个均匀分布 D 。所以h1 里的每个点的值是0.1。ok,当划分后,有三个点划分错了,根据算法误差表达式得到 误差为分错了的三个点的值之和,所以ɛ1=(0.1+0.1+0.1)=0.3,而ɑ1 根据表达式 的可以算出来为0.42. 然后就根据算法 把分错的点权值变大。如此迭代,最终完成adaboost算法。

  第二步:

  根据分类的正确率,得到一个新的样本分布D3,一个子分类器h2

  第三步:

  得到一个子分类器h3

  整合所有子分类器:

  因此可以得到整合的结果,从结果中看,及时简单的分类器,组合起来也能获得很好的分类效果,在例子中所有的。后面的代码实现时,举出的也是这个例子,可以做对比,这里有一点比较重要,就是点的权重经过大小变化之后,需要进行归一化,确保总和为1.0,这个容易遗忘。

算法的代码实现

输入测试数据,与上图的例子相对应(数据格式:x坐标 y坐标 已分类结果):

[java] view plaincopyprint?
  1. 1 5 1  
  2. 2 3 1  
  3. 3 1 -1  
  4. 4 5 -1  
  5. 5 6 1  
  6. 6 4 -1  
  7. 6 7 1  
  8. 7 6 1  
  9. 8 7 -1  
  10. 8 2 -1  

Point.java

[java] view plaincopyprint?
  1. package DataMining_AdaBoost;  
  2.   
  3. /** 
  4.  * 坐标点类 
  5.  *  
  6.  * @author lyq 
  7.  *  
  8.  */  
  9. public class Point {  
  10.     // 坐标点x坐标  
  11.     private int x;  
  12.     // 坐标点y坐标  
  13.     private int y;  
  14.     // 坐标点的分类类别  
  15.     private int classType;  
  16.     //如果此节点被划错,他的误差率,不能用个数除以总数,因为不同坐标点的权重不一定相等  
  17.     private double probably;  
  18.       
  19.     public Point(int x, int y, int classType){  
  20.         this.x = x;  
  21.         this.y = y;  
  22.         this.classType = classType;  
  23.     }  
  24.       
  25.     public Point(String x, String y, String classType){  
  26.         this.x = Integer.parseInt(x);  
  27.         this.y = Integer.parseInt(y);  
  28.         this.classType = Integer.parseInt(classType);  
  29.     }  
  30.   
  31.     public int getX() {  
  32.         return x;  
  33.     }  
  34.   
  35.     public void setX(int x) {  
  36.         this.x = x;  
  37.     }  
  38.   
  39.     public int getY() {  
  40.         return y;  
  41.     }  
  42.   
  43.     public void setY(int y) {  
  44.         this.y = y;  
  45.     }  
  46.   
  47.     public int getClassType() {  
  48.         return classType;  
  49.     }  
  50.   
  51.     public void setClassType(int classType) {  
  52.         this.classType = classType;  
  53.     }  
  54.   
  55.     public double getProbably() {  
  56.         return probably;  
  57.     }  
  58.   
  59.     public void setProbably(double probably) {  
  60.         this.probably = probably;  
  61.     }  
  62. }  
AdaBoost.java

[java] view plaincopyprint?
  1. package DataMining_AdaBoost;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.IOException;  
  7. import java.text.MessageFormat;  
  8. import java.util.ArrayList;  
  9. import java.util.HashMap;  
  10. import java.util.Map;  
  11.   
  12. /** 
  13.  * AdaBoost提升算法工具类 
  14.  *  
  15.  * @author lyq 
  16.  *  
  17.  */  
  18. public class AdaBoostTool {  
  19.     // 分类的类别,程序默认为正类1和负类-1  
  20.     public static final int CLASS_POSITIVE = 1;  
  21.     public static final int CLASS_NEGTIVE = -1;  
  22.   
  23.     // 事先假设的3个分类器(理论上应该重新对数据集进行训练得到)  
  24.     public static final String CLASSIFICATION1 = "X=2.5";  
  25.     public static final String CLASSIFICATION2 = "X=7.5";  
  26.     public static final String CLASSIFICATION3 = "Y=5.5";  
  27.   
  28.     // 分类器组  
  29.     public static final String[] ClASSIFICATION = new String[] {  
  30.             CLASSIFICATION1, CLASSIFICATION2, CLASSIFICATION3 };  
  31.     // 分类权重组  
  32.     private double[] CLASSIFICATION_WEIGHT;  
  33.   
  34.     // 测试数据文件地址  
  35.     private String filePath;  
  36.     // 误差率阈值  
  37.     private double errorValue;  
  38.     // 所有的数据点  
  39.     private ArrayList<Point> totalPoint;  
  40.   
  41.     public AdaBoostTool(String filePath, double errorValue) {  
  42.         this.filePath = filePath;  
  43.         this.errorValue = errorValue;  
  44.         readDataFile();  
  45.     }  
  46.   
  47.     /** 
  48.      * 从文件中读取数据 
  49.      */  
  50.     private void readDataFile() {  
  51.         File file = new File(filePath);  
  52.         ArrayList<String[]> dataArray = new ArrayList<String[]>();  
  53.   
  54.         try {  
  55.             BufferedReader in = new BufferedReader(new FileReader(file));  
  56.             String str;  
  57.             String[] tempArray;  
  58.             while ((str = in.readLine()) != null) {  
  59.                 tempArray = str.split(" ");  
  60.                 dataArray.add(tempArray);  
  61.             }  
  62.             in.close();  
  63.         } catch (IOException e) {  
  64.             e.getStackTrace();  
  65.         }  
  66.   
  67.         Point temp;  
  68.         totalPoint = new ArrayList<>();  
  69.         for (String[] array : dataArray) {  
  70.             temp = new Point(array[0], array[1], array[2]);  
  71.             temp.setProbably(1.0 / dataArray.size());  
  72.             totalPoint.add(temp);  
  73.         }  
  74.     }  
  75.   
  76.     /** 
  77.      * 根据当前的误差值算出所得的权重 
  78.      *  
  79.      * @param errorValue 
  80.      *            当前划分的坐标点误差率 
  81.      * @return 
  82.      */  
  83.     private double calculateWeight(double errorValue) {  
  84.         double alpha = 0;  
  85.         double temp = 0;  
  86.   
  87.         temp = (1 - errorValue) / errorValue;  
  88.         alpha = 0.5 * Math.log(temp);  
  89.   
  90.         return alpha;  
  91.     }  
  92.   
  93.     /** 
  94.      * 计算当前划分的误差率 
  95.      *  
  96.      * @param pointMap 
  97.      *            划分之后的点集 
  98.      * @param weight 
  99.      *            本次划分得到的分类器权重 
  100.      * @return 
  101.      */  
  102.     private double calculateErrorValue(  
  103.             HashMap<Integer, ArrayList<Point>> pointMap) {  
  104.         double resultValue = 0;  
  105.         double temp = 0;  
  106.         double weight = 0;  
  107.         int tempClassType;  
  108.         ArrayList<Point> pList;  
  109.         for (Map.Entry entry : pointMap.entrySet()) {  
  110.             tempClassType = (int) entry.getKey();  
  111.   
  112.             pList = (ArrayList<Point>) entry.getValue();  
  113.             for (Point p : pList) {  
  114.                 temp = p.getProbably();  
  115.                 // 如果划分类型不相等,代表划错了  
  116.                 if (tempClassType != p.getClassType()) {  
  117.                     resultValue += temp;  
  118.                 }  
  119.             }  
  120.         }  
  121.   
  122.         weight = calculateWeight(resultValue);  
  123.         for (Map.Entry entry : pointMap.entrySet()) {  
  124.             tempClassType = (int) entry.getKey();  
  125.   
  126.             pList = (ArrayList<Point>) entry.getValue();  
  127.             for (Point p : pList) {  
  128.                 temp = p.getProbably();  
  129.                 // 如果划分类型不相等,代表划错了  
  130.                 if (tempClassType != p.getClassType()) {  
  131.                     // 划错的点的权重比例变大  
  132.                     temp *= Math.exp(weight);  
  133.                     p.setProbably(temp);  
  134.                 } else {  
  135.                     // 划对的点的权重比减小  
  136.                     temp *= Math.exp(-weight);  
  137.                     p.setProbably(temp);  
  138.                 }  
  139.             }  
  140.         }  
  141.   
  142.         // 如果误差率没有小于阈值,继续处理  
  143.         dataNormalized();  
  144.   
  145.         return resultValue;  
  146.     }  
  147.   
  148.     /** 
  149.      * 概率做归一化处理 
  150.      */  
  151.     private void dataNormalized() {  
  152.         double sumProbably = 0;  
  153.         double temp = 0;  
  154.   
  155.         for (Point p : totalPoint) {  
  156.             sumProbably += p.getProbably();  
  157.         }  
  158.   
  159.         // 归一化处理  
  160.         for (Point p : totalPoint) {  
  161.             temp = p.getProbably();  
  162.             p.setProbably(temp / sumProbably);  
  163.         }  
  164.     }  
  165.   
  166.     /** 
  167.      * 用AdaBoost算法得到的组合分类器对数据进行分类 
  168.      *  
  169.      */  
  170.     public void adaBoostClassify() {  
  171.         double value = 0;  
  172.         Point p;  
  173.   
  174.         calculateWeightArray();  
  175.         for (int i = 0; i < ClASSIFICATION.length; i++) {  
  176.             System.out.println(MessageFormat.format("分类器{0}权重为:{1}", (i+1), CLASSIFICATION_WEIGHT[i]));  
  177.         }  
  178.           
  179.         for (int j = 0; j < totalPoint.size(); j++) {  
  180.             p = totalPoint.get(j);  
  181.             value = 0;  
  182.   
  183.             for (int i = 0; i < ClASSIFICATION.length; i++) {  
  184.                 value += 1.0 * classifyData(ClASSIFICATION[i], p)  
  185.                         * CLASSIFICATION_WEIGHT[i];  
  186.             }  
  187.               
  188.             //进行符号判断  
  189.             if (value > 0) {  
  190.                 System.out  
  191.                         .println(MessageFormat.format(  
  192.                                 "点({0}, {1})的组合分类结果为:1,该点的实际分类为{2}", p.getX(), p.getY(),  
  193.                                 p.getClassType()));  
  194.             } else {  
  195.                 System.out.println(MessageFormat.format(  
  196.                         "点({0}, {1})的组合分类结果为:-1,该点的实际分类为{2}", p.getX(), p.getY(),  
  197.                         p.getClassType()));  
  198.             }  
  199.         }  
  200.     }  
  201.   
  202.     /** 
  203.      * 计算分类器权重数组 
  204.      */  
  205.     private void calculateWeightArray() {  
  206.         int tempClassType = 0;  
  207.         double errorValue = 0;  
  208.         ArrayList<Point> posPointList;  
  209.         ArrayList<Point> negPointList;  
  210.         HashMap<Integer, ArrayList<Point>> mapList;  
  211.         CLASSIFICATION_WEIGHT = new double[ClASSIFICATION.length];  
  212.   
  213.         for (int i = 0; i < CLASSIFICATION_WEIGHT.length; i++) {  
  214.             mapList = new HashMap<>();  
  215.             posPointList = new ArrayList<>();  
  216.             negPointList = new ArrayList<>();  
  217.   
  218.             for (Point p : totalPoint) {  
  219.                 tempClassType = classifyData(ClASSIFICATION[i], p);  
  220.   
  221.                 if (tempClassType == CLASS_POSITIVE) {  
  222.                     posPointList.add(p);  
  223.                 } else {  
  224.                     negPointList.add(p);  
  225.                 }  
  226.             }  
  227.   
  228.             mapList.put(CLASS_POSITIVE, posPointList);  
  229.             mapList.put(CLASS_NEGTIVE, negPointList);  
  230.   
  231.             if (i == 0) {  
  232.                 // 最开始的各个点的权重一样,所以传入0,使得e的0次方等于1  
  233.                 errorValue = calculateErrorValue(mapList);  
  234.             } else {  
  235.                 // 每次把上次计算所得的权重代入,进行概率的扩大或缩小  
  236.                 errorValue = calculateErrorValue(mapList);  
  237.             }  
  238.   
  239.             // 计算当前分类器的所得权重  
  240.             CLASSIFICATION_WEIGHT[i] = calculateWeight(errorValue);  
  241.         }  
  242.     }  
  243.   
  244.     /** 
  245.      * 用各个子分类器进行分类 
  246.      *  
  247.      * @param classification 
  248.      *            分类器名称 
  249.      * @param p 
  250.      *            待划分坐标点 
  251.      * @return 
  252.      */  
  253.     private int classifyData(String classification, Point p) {  
  254.         // 分割线所属坐标轴  
  255.         String position;  
  256.         // 分割线的值  
  257.         double value = 0;  
  258.         double posProbably = 0;  
  259.         double negProbably = 0;  
  260.         // 划分是否是大于一边的划分  
  261.         boolean isLarger = false;  
  262.         String[] array;  
  263.         ArrayList<Point> pList = new ArrayList<>();  
  264.   
  265.         array = classification.split("=");  
  266.         position = array[0];  
  267.         value = Double.parseDouble(array[1]);  
  268.   
  269.         if (position.equals("X")) {  
  270.             if (p.getX() > value) {  
  271.                 isLarger = true;  
  272.             }  
  273.   
  274.             // 将训练数据中所有属于这边的点加入  
  275.             for (Point point : totalPoint) {  
  276.                 if (isLarger && point.getX() > value) {  
  277.                     pList.add(point);  
  278.                 } else if (!isLarger && point.getX() < value) {  
  279.                     pList.add(point);  
  280.                 }  
  281.             }  
  282.         } else if (position.equals("Y")) {  
  283.             if (p.getY() > value) {  
  284.                 isLarger = true;  
  285.             }  
  286.   
  287.             // 将训练数据中所有属于这边的点加入  
  288.             for (Point point : totalPoint) {  
  289.                 if (isLarger && point.getY() > value) {  
  290.                     pList.add(point);  
  291.                 } else if (!isLarger && point.getY() < value) {  
  292.                     pList.add(point);  
  293.                 }  
  294.             }  
  295.         }  
  296.   
  297.         for (Point p2 : pList) {  
  298.             if (p2.getClassType() == CLASS_POSITIVE) {  
  299.                 posProbably++;  
  300.             } else {  
  301.                 negProbably++;  
  302.             }  
  303.         }  
  304.           
  305.         //分类按正负类数量进行划分  
  306.         if (posProbably > negProbably) {  
  307.             return CLASS_POSITIVE;  
  308.         } else {  
  309.             return CLASS_NEGTIVE;  
  310.         }  
  311.     }  
  312.   
  313. }  
调用类Client.java:

[java] view plaincopyprint?
  1. /** 
  2.  * AdaBoost提升算法调用类 
  3.  * @author lyq 
  4.  * 
  5.  */  
  6. public class Client {  
  7.     public static void main(String[] agrs){  
  8.         String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";  
  9.         //误差率阈值  
  10.         double errorValue = 0.2;  
  11.           
  12.         AdaBoostTool tool = new AdaBoostTool(filePath, errorValue);  
  13.         tool.adaBoostClassify();  
  14.     }  
  15. }  

输出结果:

[java] view plaincopyprint?
  1. 分类器1权重为:0.424  
  2. 分类器2权重为:0.65  
  3. 分类器3权重为:0.923  
  4. 点(15)的组合分类结果为:1,该点的实际分类为1  
  5. 点(23)的组合分类结果为:1,该点的实际分类为1  
  6. 点(31)的组合分类结果为:-1,该点的实际分类为-1  
  7. 点(45)的组合分类结果为:-1,该点的实际分类为-1  
  8. 点(56)的组合分类结果为:1,该点的实际分类为1  
  9. 点(64)的组合分类结果为:-1,该点的实际分类为-1  
  10. 点(67)的组合分类结果为:1,该点的实际分类为1  
  11. 点(76)的组合分类结果为:1,该点的实际分类为1  
  12. 点(87)的组合分类结果为:-1,该点的实际分类为-1  
  13. 点(82)的组合分类结果为:-1,该点的实际分类为-1  

我们可以看到,如果3个分类单独分类,都没有百分百分对,而尽管组合结果之后,全部分类正确。

我对AdaBoost算法的理解

到了算法的末尾,有必要解释一下每次分类自后需要把错的点的权重增大,正确的减少的理由了,加入上次分类之后,(1,5)已经分错了,如果这次又分错,由于上次的权重已经提升,所以误差率更大,则代入公式ln(1-误差率/误差率)所得的权重越小,也就是说,如果同个数据,你分类的次数越多,你的权重越小,所以这就造成整体好的分类器的权重会越大,内部就会同时有各种权重的分类器,形成了一种互补的结果,如果好的分类器结果分错 ,可以由若干弱一点的分类器进行弥补。

AdaBoost算法的应用

可以运用在诸如特征识别,二分类的一些应用上,与单个模型相比,组合的形式能显著的提高准确率。

0 0
原创粉丝点击