C45算法代码实现及其详解

来源:互联网 发布:java 自定义报表 编辑:程序博客网 时间:2024/05/22 08:13

1.概述

C45算法在weka已经有具体的实现,即weka中的J48.java。不过J48.java中的具体代码牵扯到较多的类和其他东西,直接看源代码比较容易混乱,且需要了解的东西较多,有比较多和C45算法本身没有太大关系而是为了方便代码实现的类、变量和方法等。
本文是基于C45算法思想和对J48源代码的详细解读,自己写了一个C45算法的代码(之后均称为MyC45)。该代码只含有两个类(99%的代码只在一个类中实现),需要了解的结构相对简单,算法的实现过程相对清晰,算法的效果和J48.java相差无几,写在这里权作参考。
weka中的每个分类器的类文件都实现了buildClassifier和distributionForInstance两个算法,前者是构建分类器,后者是根据构建的分类器对每个实例进行类标记预测。所以,MyC45算法中主要分为建树(buildClassifier)和预测(distributionForInstance)两部分,其中建树(buildClassifier)是主体部分,预测(distributionForInstance)只是简单的利用前面构建好的树对每个实例进行预测。
特别说明,本文所使用的实例集均已进行数据预处理,所以在MyC45中没有J48中对缺失值的处理、判断J48能否处理该实例集等等对实例集的处理。

2.类及其变量说明

MyC45包括两个类,一个是MyC45,另一个是MyClassifierTree,它们的具体功能和变量说明如下:
1.MyC45:实现C45算法的类。
变量:
m_Root:C45决策树的根节点。
2.MyClassifierTree:具体实现C45算法中的建树和预测实例类标记两大部分的功能。
变量:
int[] m_AttributeList: 当前节点的可选分裂属性集合,以整数形式保存.-1表示当前节点不能选择该属性作为分裂属性。
int m_SplitAttribute:当前节点的分裂属性,以整数的形式表示。
int m_NumAttributes:训练实例集中属性的个数(包括类属性)。
MyCLassifierTree[] m_Sons:当前节点的儿子节点。
Instances m_Instances:当前节点对应的实例集。
int m_MinInstances:最小实例数,若当前节点对应的实例集的实例数小于该值,则当前节点只能作为叶节点。

注:以上是关键变量说明,类中其他变量并不是很重要就没有特别说明,在后面的代码中出现时会有说明。

3.伪代码

输入:训练实例集D,可选属性列表m_AttributeList,最小实例数m_MinInstances。
1.创建根节点m_root,m_AttributeList初始化为全0.
2.if D中的所有实例类标记都一样,则将m_root标记为叶节点并返回m_root。
3.if m_AttributeList为空,则将m_root标记为叶节点并返回m_root。
4.if m_Instances的实例数小于m_MinInstances,则将m_root标记为叶节点并返回m_root。
5.根据getBestSplitAttribute方法,在m_AttributeList中找出增益率最大的属性作为当前节点的分裂属性m_SplitAttribute。
6.根据m_SplitAttribute的属性值将当前节点的实例集m_Instances划分成多个子集,也就是当前节点的儿子节点 m_Sons。
7.对m_Sons中的每个节点递归地循环2-6的步骤以构建子树。
8.决策树构建好后,需要对其进行collapse处理和prune处理。前者是折叠子树过程,后者是后剪枝过程,具体说明在后面。

4.函数调用流程图

以下是各个部分的主要函数调用流程图,后面将对这些主要函数进行详细说明。


图1. buildClassifier部分的主要函数调用流程图

在MyC45类中的buildClassifier函数中,先初始化可选属性列表m_AttributeList和最小实例数m_MinInstances,然后用这两个参数初始化m_Root,之后再用m_Root调用MyClassifierTree类中的buildeClassifier函数即可。

这里写图片描述
图2. distributionForInstance部分的主要函数调用流程图

在MyC45类中的distributionForInstance函数中,直接用m_Root调用MyClassifierTree类中的distributionForInstance函数即可。

这里写图片描述
图3. 建树(buildTree)的主要函数调用流程图

建树首先根据isAllTheSame和canSplit两个函数和其他条件判断当前节点是否为叶节点;若可以分裂则调用getBestSplitAttribute函数求出最佳的分裂属性(也就是增益率最大的属性),getBestSplitAttribute函数中对每个可选属性分别调用beforeSplit、afterSplit和computeSplitInfo函数以求出该可选属性的增益率;随后将当前节点对应的实例集根据最佳分裂属性(即当前节点的分裂属性)分裂出多个子集,这些子集即为当前节点的子节点的实例集;然后对每个子集调用getNewTree函数以构建相应的子树,getNewTree函数中即调用buildTree函数,从而递归地构建决策树。

这里写图片描述
图4. 折叠子树(collapse)的主要函数调用流程图

collapse是对非叶节点进行操作,对于非叶节点,调用getCurrentTraningErrors求出当前节点的实例集上误分实例数,再调用getSubTraningErrors求出所有子树上实例集上误分实例数;如果后者大于前者,说明这些子树并不能提高这颗树的准确度,则把这些子树删除。否则在每个子树上递归的collapse。

这里写图片描述
图5. 后剪枝(prune)的主要函数调用流程图

prune是后剪枝操作,所以需要先递归找到叶节点的上一层的第一个子树开始剪枝。首先,调用getLargetBranch求出当前节点的最大树枝;随后,调用getEstimatedErrorsForBranch求出当前结点实例集在最大树枝上的误分实例数,标记为a;调用getEstimatedErrorsForDistribution求出当前节点的实例集在当前节点上的误分实例数,标记为b;调用getEstimatedErrors求出当前节点的所有子树上的误分实例数,标记为c。接着,判断是否应该把当前结点设置为叶节点,即第一个if语句。若不成立,则判断是否用最大树枝代替当前结点,即第二个if语句。若是,则将最大树枝上的变量信息覆盖当前结点的变量信息,并调用restInstances根据更改后的当前节点对当前实例集进行调整树的结构,并递归地对更改后的当前节点进行prune操作。

5.buildTree

isAllTheSame和canSplit函数比较简单,split就是根据属性值对实例集进行划分,getNewTree只是简单的初始化节点并调用buildTree形成递归,这些都比较简单,略过。

1.buildTree代码

    public void buildTree(Instances instances) throws Exception    {        initializePara(instances);//初始化一些变量        if (instances.numInstances()  <= m_MinInstances|| isAllTheSame(instances))         //小于m_MinInstances的实例集只能做叶节点,或当前实例集的类标记都一样也做叶节点,        {            m_IsLeaf = true;                   // 是则该节点为叶节点            m_Instances = instances;// 叶节点对应的实例集为instances            return;        }        if (!canSplit(m_AttributeList)) // 若可选属性集为空,则实例集不能继续分裂,所以当前节点是叶节点        {            m_IsLeaf = true;            m_Instances = instances;            return;        } else        {            int[] sonAttributeList = new int[m_NumAttributes];  //子节点的可选属性列表            for (int i = 0; i < sonAttributeList.length; i++)            {                if (i == m_ClassIndex)                    sonAttributeList[i] = -1;                else                    sonAttributeList[i] = m_AttributeList[i];            }            m_SplitAttribute = getBestSplitAttribute(instances, m_AttributeList); // 从当前可选的分裂属性集合中获取最佳的分裂属性            m_NameOfCurrentNode = instances.attribute(m_SplitAttribute).name(); // 获取当前分裂属性的名称            if (m_SplitAttribute != -1) // 当前实例集可以划分,且求得最佳分裂属性时            {                sonAttributeList[m_SplitAttribute] = -1; // 当前分裂属性在所有子节点上是不可选的,所以这里进行标记一下                int numOfSubTree = m_NumAttsValues[m_SplitAttribute];// 该节点对应的子节点数量,等于当前分裂属性的属性值个数                Instances[] localInstances;                localInstances = split(instances, numOfSubTree, m_SplitAttribute);// 根据分裂属性的属性值个数将实例集进行划分                m_NameOfLineToSon = new String[numOfSubTree];//每个子节点对应的属性值                for (int i = 0; i < numOfSubTree; i++)                    m_NameOfLineToSon[i] = m_Instances.attribute(m_SplitAttribute).value(i);                m_Sons = new MyCLassifierTree[numOfSubTree];                for (int i = 0; i < m_Sons.length; i++)                {// 接着为每一个localInstances构建子树                    m_Sons[i] = getNewTree(localInstances[i], sonAttributeList,m_MinInstances);                    localInstances[i] = null;                    if (m_Sons[i].m_IsLeaf)  //统计当前结点叶节点数                        m_NumLeaf ++;                }            } else// 当前实例集不可以划分,说明该节点是叶节点            {                m_IsLeaf = true;                m_Instances = instances;                return;            }        }    }

2.getBestSplitAttribute

    /**     * 从当前可选属性列表中求出最佳分裂属性,即增益率最大的属性     * @param instances     * @param attributesList     * @return     */    public int getBestSplitAttribute(Instances instances, int[] attributesList)    {        int bestSplitAttribute = 0; //标记最佳分裂属性        boolean canSplit = false; //判断该实例集是否可以继续划分        double[] gainRatio = new double[m_NumAttributes ]; //增益率,        double[] infoGain =  new double[m_NumAttributes ]; //信息增益,        double[] splitInfo =  new double[m_NumAttributes ]; //分裂信息,        for (int i = 0; i < m_NumAttributes ; i++) //遍历属性,计算每个属性增益率        {            if (i != m_ClassIndex && attributesList[i] != -1)//对可选的非类属性进行计算增益率            {                infoGain[i] = beforeSplit(instances) - afterSplit(instances,i);                splitInfo[i] = computeSplitInfo(instances,i);                gainRatio[i] = infoGain[i] / splitInfo[i];                canSplit = true; //当进入增益率计算时,说明该实例集可以进行划分            }            else             {                gainRatio[i] = 0;                infoGain[i] = 0;                splitInfo[i] = 0;            }        }        if (canSplit)        {  //若可以分裂,则找出gainRatio数组中最大且attributesList数组中不等于-1的下标,即为最佳分裂属性            bestSplitAttribute = getMaxIndex(gainRatio,attributesList);        }        else {//若当前实例集无法继续分裂,则返回-1作为没有找到最佳分裂属性的标记            bestSplitAttribute = -1;        }        return bestSplitAttribute;    }

getBestSplitAttribute是通过计算增益率求出的,以下下是计算增益率的一些公式:
(1)GainRito(A) = Gain(A)/SplitInfo(A)
GainRito(A):属性A的增益率。
Gain(A):属性A 的信息增益。
SplitInfo(A):属性A的分裂信息量。

(2)Gain(A) = Info(Insts) - Info(Insts,A)
Info(Insts):实例集Insts分裂前的信息量。
Info(Insts,A):实例集Insts根据属性A分裂后的信息量。

(3)这里写图片描述
C:实例集Insts中的类标记个数。
Pi:第i种类标记对应的实例数与实例总数的比值。

(4)这里写图片描述
nA:属性A的属性值个数。
n:实例集的实例总数。
ni:属性A的第i种属性值对应的实例数。
Instsi:属性A的第i种属性值对应的实例集。
Info(Instsi):根据公式(3)计算属性A的第i种属性值对应的实例集的信息量。

(5)这里写图片描述

3.beforeSplit

    /**     * 计算分裂前实例集的信息值     * @param instances     * @return     */    public double beforeSplit(Instances instances)    {        double infoBeforeSplit = 0;        int numClasses = instances.numClasses();        int numInstances = instances.numInstances(); //实例总数        double allWeight = 0;                                                 //实例集instances中的实例数(权重之和)        double[] numInstancesInClass = new double[numClasses];//每个类标记对应的实例数(权重)        for (int i = 0; i < numInstances; i++) //遍历每个实例,统计出每个类标记对应的实例数(权重)        {            int classLable = (int)instances.instance(i).classValue();            numInstancesInClass[classLable] += instances.instance(i).weight();             allWeight += instances.instance(i).weight();        }        if (onlyOneNotZero(numInstancesInClass,allWeight))//1.如果只有一个类标记的实例数(权重)不为0(其他类标记的实例数(权重)为0),则信息值为0            return 0.0;        if (eachEqualAve(numInstancesInClass))//2.如果所有类标记对应的实例数(权重)相等,则信息值最大,这里设置为1            return 1.0;        for (int i = 0; i < numClasses; i++) //3.根据公式(3)求出实例集的信息值        {            double pi = numInstancesInClass[i] / allWeight;            if (pi != 0.0) //注意pi为0时不要纳入计算,因为log0是一个无效值,这会导致整个infoBeforeSplit值无效(NaN)。反正pi等于0时 pi * log2(pi)即为0,所以不纳入计算即可                infoBeforeSplit = infoBeforeSplit + pi * log2(pi) ;        }        return - infoBeforeSplit; //注意加个负号    }

4.afterSplit

    /**     * 计算实例集根据属性attribute进行分裂后的信息量     * @param instances     * @param attribute     * @return     */    public double afterSplit(Instances instances, int attribute)    {        double infoAfterSplit = 0;        int numAttributeValue = instances.attribute(attribute).numValues(); //属性attribute的属性值个数        int numInstances = instances.numInstances(); // 实例总数        double allWeight = 0; ////实例集instances中的实例数权重之和        double[] weightInAttValue = new double[numAttributeValue];//每个属性值对应的实例数(权重之和)        Instances[] instsOfValue = new Instances[numAttributeValue];//每个属性值对应的实例子集        for (int i = 0; i < instsOfValue.length; i++)//初始化            instsOfValue[i] = new Instances(instances, 0);        for (int i = 0; i < numInstances; i++)//遍历实例集,将实例集instances根据属性值划分实例集,统计每个实例集的实例数(权重)        {            int attValue = (int)instances.instance(i).value(attribute); //获取第i个实例在属性attribute中的属性值            weightInAttValue[attValue] += instances.instance(i).weight(); //计算每个属性值对应的实例数(权重之和)            allWeight +=  instances.instance(i).weight();  //计算实例集instances中的实例权重之和            instsOfValue[attValue].add(instances.instance(i)); //将实例i放入对应的实例子集之中        }        for (int i = 0; i < numAttributeValue; i++)//根据公式(4)计算根据属性i分裂后的实例集的信息值        {            double value = weightInAttValue[i]/allWeight * beforeSplit(instsOfValue[i]);            infoAfterSplit = infoAfterSplit + value;        }        return infoAfterSplit;    }}

5.computeSplitInfo

    /**     * 计算分裂信息量     * @param instances     * @param attribute     * @return     */    public double computeSplitInfo(Instances instances, int attribute)    {        double splitInfo = 0;        double allWeight = 0; ////实例集instances中的实例数(权重之和)        int numAttributeValue = instances.attribute(attribute).numValues(); //属性attribute的属性值个数        double[] weightInEachValue = new double[numAttributeValue]; //每个属性值对应的实例数(权重之和)        for (int i = 0; i < instances.numInstances(); i++)        {            int  attValue = (int)instances.instance(i).value(attribute); //获取第i个实例在属性attribute中的属性值            weightInEachValue[attValue] += instances.instance(i).weight(); //计算每个属性值对应的实例数(权重之和)            allWeight += instances.instance(i).weight(); //计算实例集instances中的实例数(权重之和)        }        for (int i = 0; i < numAttributeValue; i++) //根据公式(5)计算分裂信息值        {            double pi = weightInEachValue[i] / allWeight;            if (pi != 0)//注意pi为0时不要纳入计算,因为log0是一个无效值,这会导致整个splitInfo值无效。                splitInfo = splitInfo +pi* log2(pi);        }        return  - splitInfo; //注意加个负号    }

6.collapse

1.collapse

    /**     * Collapses a tree to a node if training error doesn't increase.     * 如果当前节点存在很多子节点,但这些子节点并不能提高这颗分类树的准确度,则把这些孩子节点删除。否则在每个孩子上递归的collapse。     * 通过collapse方法可以在不减少精度的前提下减少决策树的深度,进而提高效率。     */    public final void collapse( )    {        double errorsOfTree;       // 当前节点上训练实例集误分的实例数(权重)        double errorsOfSubtree; // 当前节点的所有子树上训练实例集误分的实例数(权重)        int i;        if (!m_IsLeaf)//只有对非叶节点才进行折叠子树操作        {               errorsOfTree = getCurrentTrainingErrors();              errorsOfSubtree = getSubTrainingErrors();            if (errorsOfSubtree >= errorsOfTree - 1E-3)                //所有子树上误分实例数(权重)大于当前节点误分实例数(权重)时,说明这些孩子节点不好,将他们删除。                //删除的方式是将当前节点 的子树变量设置为空并将该节点设置为叶节点。1E-3是10的-3次方,即0.001             {                m_Sons = null;                m_IsLeaf = true;            }            else                for (i = 0; i < m_Sons.length; i++) // 在每个孩子上递归地进行折叠子树操作                    m_Sons[i].collapse();        }    }

2.getCurrentTrainingErrors

    /**     * 计算当前结点的误分实例数     * @return     */    private double getCurrentTrainingErrors()    {        double wrongWeight = 0;        int majorityClassLable = majorityClassLable(m_Instances); //获取当前实例集中的多数类        for (int i = 0; i < m_Instances.numInstances(); i++)  //遍历当前实例集,求出总误分实例数(权重)        {            int classValue = (int)m_Instances.instance(i).classValue();            if (classValue != majorityClassLable)            {                m_Predictions[i] = -1;                wrongWeight += m_Instances.instance(i).weight();            }        }        return wrongWeight;    }

3.getSubTrainingErrors

    /**     * 计算当前结点的所有子节点中误分的实例数(权重)     * @return     */    private double getSubTrainingErrors()    {        double errors = 0;        if (m_IsLeaf)   //对叶节点,直接调用getCurrentTrainingErrors函数求出该叶节点上的误分实例数(权重)            return getCurrentTrainingErrors();        else //对非叶节点,递归调用getSubTrainingErrors以求出所有子节点上的误分实例数(权重)        {            for (int i = 0; i < m_Sons.length; i++)                errors = errors + m_Sons[i].getSubTrainingErrors();            return errors;        }    }

7.prune

1.prune

    /**     * 后剪枝操作     * @throws Exception     */    public final void prune( ) throws Exception    {        double errorsLargestBranch; //当前节点实例集在最大树枝上的误分实例数        double errorsLeaf;                     //假设当前节点是叶节点时,该节点对应的实例集在该节点上的误分实例数        double errorsTree;                     //计算当前节点的所有子树上的误分实例数        int indexOfLargestBranch;     //最大树枝的下标        MyCLassifierTree largestBranch;  //临时保存最大树枝        int i;        if (!m_IsLeaf) //对非叶节点均进行剪枝        {            for (i = 0; i < m_Sons.length; i++)// 对当前节点的子节点递归地进行剪枝,由于是后剪枝,所以从树的最底层开始往上                m_Sons[i].prune();            // 求出当前树上的最大树枝,即当前节点的所有子集中实例数最大的子集下标,            indexOfLargestBranch = getLargetBranch();            // 计算当前节点实例集在最大树枝上的误分实例数            errorsLargestBranch = m_Sons[indexOfLargestBranch].getEstimatedErrorsForBranch(m_Instances);            //计算当前节点对应的实例集在当前节点上的误分实例数            errorsLeaf = getEstimatedErrorsForDistribution(m_Instances);            // 计算当前节点的所有子树上的误分实例数            errorsTree = getEstimatedErrors();            // 判断将该节点设置为叶节点是不是最好的选择,            if (Utils.smOrEq(errorsLeaf, errorsTree + 0.1) && Utils.smOrEq(errorsLeaf, errorsLargestBranch + 0.1))            {                m_Sons = null;                m_IsLeaf = true;                return;            }            // 判断用最大树枝代替当前节点是不是最好的选择            if (Utils.smOrEq(errorsLargestBranch, errorsTree + 0.1))            {                largestBranch = m_Sons[indexOfLargestBranch];  //获取最大树枝                m_Sons = largestBranch.m_Sons;                   m_AttributeList = largestBranch.m_AttributeList;   //将最大树枝的可选分裂属性列表覆盖当前节点的可选分裂属性列表                m_AttributeList[m_SplitAttribute] = 0;                        //由于会用最大树枝的分裂属性代替了原先节点的分裂属性,所以原先节点的分裂属性处于可选状态                m_SplitAttribute = largestBranch.m_SplitAttribute;  //用最大树枝的分裂属性代替了原先节点的分裂属性                m_IsLeaf = largestBranch.m_IsLeaf;                resetInstances(m_Instances);  //将当前实例集根据修改后的分裂属性进行划分                prune(); //递归地对修改后的节点进行剪枝            }        }    }

2.getEstimatedErrorsForDistribution

    /**     * 求出testInstances实例集在以m_Instances为根据的分类器中的误分实例数(权重)     * @param theDistribution     *            the distribution to use     * @return the estimated errors     */    private double getEstimatedErrorsForDistribution(Instances testInstances)    {        if (Utils.eq(testInstances.numInstances(), 0)) //若testInstances实例数为0,则误分实例数只能为0            return 0;        else        {            double inCorrectWeight = 0;            double allWeight = 0.0;            int majorityClassLable ;            majorityClassLable = majorityClassLable(m_Instances);//求出当前实例集m_Instances中的多数类            for (int i = 0; i < testInstances.numInstances(); i++)            {                allWeight += testInstances.instance(i).weight(); //统计测试实例集testInstances中的实例总数(权重)                int classVlaue = (int)testInstances.instance(i).classValue();                if (classVlaue != majorityClassLable)                    inCorrectWeight += testInstances.instance(i).weight(); //统计测试实例集testInstances中误分实例的实例总数(权重)            }            return inCorrectWeight +  Stats.addErrs(allWeight,inCorrectWeight, 0.25f);         }    }

3.getEstimatedErrorsForBranch

    /**     *  求出testInstances实例集在以m_Instances为根据的分类器中的误分实例数(权重)     * @param data     *            the data to work with     * @return the estimated errors     * @throws Exception     *             if something goes wrong     */    private double getEstimatedErrorsForBranch(Instances testInstances) throws Exception    {        double errors = 0;        int i;        if (m_IsLeaf) //若当前节点是叶节点,则调用getEstimatedErrorsForDistribution求出当前节点上的误分实例数            return getEstimatedErrorsForDistribution(testInstances);        else //若当前节点不是叶节点,则计算testInstances在其所有子节点上的误分实例数之和        {            //将testInstances根据当前节点的分裂属性的属性值,划分成不同的测试实例集            int numSubset = testInstances.attribute(m_SplitAttribute).numValues();            Instances[] localInstances = split(testInstances, numSubset, m_SplitAttribute);//测试实例子集            for (i = 0; i < m_Sons.length; i++) //计算每个测试实例子集在对应子节点上的误分实例数(权重)                errors = errors + m_Sons[i].getEstimatedErrorsForBranch(localInstances[i]);            return errors;        }    }

4.getEstimatedErrors

    /**     *计算当前结点的所有子树上的误分实例数(权重)     * @return the estimated errors     */    private double getEstimatedErrors()    {        double errors = 0;        int i;        if (m_IsLeaf)  //若当前结点是叶节点,则直接计算误分实例树(权重)            return getEstimatedErrorsForDistribution(m_Instances);        else        {            for (i = 0; i < m_Sons.length; i++)  //若当前节点不是叶节点,则递归地计算其所有子树上的误分实数(权重)                errors = errors + m_Sons[i].getEstimatedErrors();            return errors;        }    }

5.resetInstances

    /**     * 将当前实例集根据修改后的分裂属性进行划分,并修改预测结果数组m_Prediction[]     * @param instances     * @throws Exception      */    private void resetInstances(Instances instances) throws Exception    {        m_Instances = instances;        if (!m_IsLeaf) //若当前节点不是叶子节点,则递归地对其及其子树进行重新划分实例集        {            int numSubset = (int)instances.attribute(m_SplitAttribute).numValues();            Instances[] localInstances = split(instances, numSubset, m_SplitAttribute);            for (int i = 0; i < m_Sons.length; i++)//递归地对其子树进行重新划分实例集            {                m_Sons[i].m_Instances = localInstances[i];                   m_Sons[i].resetInstances(localInstances[i]);            }        }        else        {//由于实例集发生了变化,所以需要根据新实例集和新的可选分裂属性列表构建树            m_AttributeList[m_SplitAttribute] = -1;  //在当前节点上建树,则当前分裂属性在其子树的构建过程中不可选            m_IsLeaf = false;            int numSubset = (int)instances.attribute(m_SplitAttribute).numValues();   //此时是在各个子节点上建树            Instances[] localInstances = split(instances, numSubset, m_SplitAttribute);            m_Sons = new MyCLassifierTree[numSubset];            for (int i = 0; i < localInstances.length; i++)            {                m_Sons[i] = getNewTree(localInstances[i], m_AttributeList, m_MinInstances);            }        }    }