Weka算法Classifier-trees-RandomTree源码分析

来源:互联网 发布:nba各项数据历史记录 编辑:程序博客网 时间:2024/06/05 08:38


一、RandomTree算法

在网上搜了一下,并没有找到RandomTree的严格意义上的算法描述,因此我觉得RandomTree充其量只是一种构建树的思路,和普通决策树相比,RandomTree会随机的选择若干属性来进行构建而不是选取所有的属性。

Weka在实现上,对于随机属性的选取、生成分裂点的过程是这样的:

1、设置一个要选取的属性的数量K

2、在全域属性中无放回的对属性进行抽样

3、算出该属性的信息增益(注意不是信息增益率)

4、重复K次,选出信息增益最大的当分裂节点。

5、构建该节点的孩子子树。


二、具体代码实现

(1)buildClassifier

 public void buildClassifier(Instances data) throws Exception {    // 如果传入的K不合理,把K放到一个合理的范围里    if (m_KValue > data.numAttributes() - 1)      m_KValue = data.numAttributes() - 1;    if (m_KValue < 1)      m_KValue = (int) Utils.log2(data.numAttributes()) + 1;//这个是K的默认值    // 判断一下该分类器是否有能力处理这个数据集,如果没能力直接就在testWithFail里抛异常退出了    getCapabilities().testWithFail(data);    // 删除掉missClass    data = new Instances(data);    data.deleteWithMissingClass();    // 如果只有一列,就build一个ZeroR模型,之后就结束了。ZeroR模型分类是这样的:如果是连续型,总是返回期望,如果离散型,总是返回训练集中出现最多的那个    if (data.numAttributes() == 1) {      System.err          .println("Cannot build model (only class attribute present in data!), "              + "using ZeroR model instead!");      m_zeroR = new weka.classifiers.rules.ZeroR();      m_zeroR.buildClassifier(data);      return;    } else {      m_zeroR = null;    }    // 如果m_NumFlods大于0,则会把数据集分为两部分,一部分用于train,一部分用于test,也就是backfit
    //分的方式和多折交叉验证是一样的,例如m_NumFlods是10的话,则train占90%,backfit占10%    Instances train = null;    Instances backfit = null;    Random rand = data.getRandomNumberGenerator(m_randomSeed);    if (m_NumFolds <= 0) {      train = data;    } else {      data.randomize(rand);      data.stratify(m_NumFolds);      train = data.trainCV(m_NumFolds, 1, rand);      backfit = data.testCV(m_NumFolds, 1);    }    // 生成所有的可选属性    int[] attIndicesWindow = new int[data.numAttributes() - 1];    int j = 0;    for (int i = 0; i < attIndicesWindow.length; i++) {      if (j == data.classIndex())        j++; // 忽略掉classIndex      attIndicesWindow[i] = j++;//这段代码有点奇怪,i和j是相等的,为啥不用attIndicesWindow=i?    }    // 算出每个class的频率,也就是每个分类出现的次数(更正确的说法应该是权重,但权重默认都是1)    double[] classProbs = new double[train.numClasses()];    for (int i = 0; i < train.numInstances(); i++) {      Instance inst = train.instance(i);      classProbs[(int) inst.classValue()] += inst.weight();    }    // Build tree    m_Tree = new Tree();    m_Info = new Instances(data, 0);    m_Tree.buildTree(train, classProbs, attIndicesWindow, rand, 0);//调用tree的build方法,在后面单独分析    // Backfit if required    if (backfit != null) {      m_Tree.backfitData(backfit);//在后面单独分析    }  }

这个Tree对象是RandomTree的一个子类,之前我还以为会复用其余的决策树模型(比如J48),但weka没这么做,很惊奇的是RandomTree和J48的作者还是同一个,不知道为啥这么设计。


(2)tree.buildTree

 protected void buildTree(Instances data, double[] classProbs,        int[] attIndicesWindow, Random random, int depth) throws Exception {      //首先判断一下是否有instance,如果没有的话直接就返回      if (data.numInstances() == 0) {        m_Attribute = -1;        m_ClassDistribution = null;        m_Prop = null;        return;      }      m_ClassDistribution = classProbs.clone();      if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum          || Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)],              Utils.sum(m_ClassDistribution))          || ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) {        // 递归结束的条件有3个 1、instance数量小于2*m_Minnum  2、instance都已经在同一个类中 3、达到最大的深度
       //前两个条件和j48的递归结束条件很相似,相关内容可参考我之前的几篇博客。        m_Attribute = -1;        m_Prop = null;        return;      }      double val = -Double.MAX_VALUE;      double split = -Double.MAX_VALUE;      double[][] bestDists = null;      double[] bestProps = null;      int bestIndex = 0;      double[][] props = new double[1][0];      double[][][] dists = new double[1][0][0];//这个数组第一列只有下标为0的被用到,不知道为啥这么设计      int attIndex = 0;//存储被选择到的属性      int windowSize = attIndicesWindow.length;//存储目前可选择的属性的数量      int k = m_KValue;//k代表还能选择的属性的数量      boolean gainFound = false;//是否发现了一个有信息增益的节点      while ((windowSize > 0) && (k-- > 0 || !gainFound)) {//此循环退出条件有2个 1、没有节点可以选了 2、已经选了k个属性了并且找到了一个有用的属性 换句话说,如果K次迭代没有找到可以分裂的随机节点,循环也会继续下去
        int chosenIndex = random.nextInt(windowSize);//随机选一个,生成下标        attIndex = attIndicesWindow[chosenIndex];//得到该属性的index       //下面三行把选择的属性放到attIndicesWindow的末尾,然后把windowSize-1这样下个循环就不会选到了,也就是实现了无放回的抽取        attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];        attIndicesWindow[windowSize - 1] = attIndex;        windowSize--;        double currSplit = distribution(props, dists, attIndex, data);//这个函数计算了在使用attIndex进行分裂所产生的分布,如果classIndex是连续值的话,还计算了分裂点,原理和J48的split一样,不在赘述。        double currVal = gain(dists[0], priorVal(dists[0]));//这个计算了信息增益        if (Utils.gr(currVal, 0))          gainFound = true;//如果信息增益大于0的话,说明节点有效,设置gainFound        if ((currVal > val) || ((currVal == val) && (attIndex < bestIndex))) {          val = currVal; //如果信息增益大的话,则更新把attIndex更新为bestIndex,这是为了选取最优的节点(ID3)的方法          bestIndex = attIndex;          split = currSplit;          bestProps = props[0];          bestDists = dists[0];        }      }      m_Attribute = bestIndex;      // Any useful split found?      if (Utils.gr(val, 0)) { <span style="white-space:pre"></span>//如果找到了一个分裂点,则在该分裂点的基础上构建子树
        m_SplitPoint = split;        m_Prop = bestProps;        Instances[] subsets = splitData(data);        m_Successors = new Tree[bestDists.length];        for (int i = 0; i < bestDists.length; i++) {          m_Successors[i] = new Tree();          m_Successors[i].buildTree(subsets[i], bestDists[i], attIndicesWindow,              random, depth + 1);//注意这里传入的attIndicesWindow没有变,换句话说,每次迭代传入的可选属性集合是一样的,因此子节点在进行属性的random选择时,很有可能会选择到父节点已经选过的节点,但因为不产生信息增益,因此不会再次作为bestIndex,但会产生额外的计算量(我感觉还不少),这里还有一定的优化空间,同理j48也是这么实现的。        }        boolean emptySuccessor = false;        for (int i = 0; i < subsets.length; i++) {          if (m_Successors[i].m_ClassDistribution == null) {            emptySuccessor = true;            break;          }        }        if (!emptySuccessor) {          m_ClassDistribution = null;        }      } else {       //这个else是<span style="font-family: Arial, Helvetica, sans-serif;">Utils.gr(currVal, 0)这个条件的,代表没有选择到合适的分裂节点</span>        m_Attribute = -1;      }    }

(3)tree.backfit

什么是Backfit?Backfit将改变已有tree节点及其子节点的class分布,而class分布将直接被用于实例的预测。

直接使用RandomTree有时会出现过拟合的现象(通过代码可以看到,和J48相比没有剪枝过程),因此通过传入一个新的数据集来backfit已有节点是一个解决过拟合的方法。

  protected void backfitData(Instances data, double[] classProbs)        throws Exception {
<span style="white-space:pre"></span>//判断一下是否有数据      if (data.numInstances() == 0) {        m_Attribute = -1;        m_ClassDistribution = null;        m_Prop = null;        return;      }      m_ClassDistribution = classProbs.clone();      if (m_Attribute > -1) {        // m_Attribut>-1代表不是leaf,可以看上面的buildTree得出这个结论        m_Prop = new double[m_Successors.length];//子节点数组的length也就是分类的类的数量
<span style="white-space:pre"></span>//把传入的data用此节点算各类的频率        for (int i = 0; i < data.numInstances(); i++) {          Instance inst = data.instance(i);          if (!inst.isMissing(m_Attribute)) {            if (data.attribute(m_Attribute).isNominal()) {              m_Prop[(int) inst.value(m_Attribute)] += inst.weight();            } else {              m_Prop[(inst.value(m_Attribute) < m_SplitPoint) ? 0 : 1] += inst                  .weight();//连续型只会分两类,小于splitPoint一类,大于是一类,和J48采用的策略相同            }          }        }        if (Utils.sum(m_Prop) <= 0) {          m_Attribute = -1;//如果data全部都是missingValue,则把此节点变成leaf节点          m_Prop = null;          return;        }        // 归一化        Utils.normalize(m_Prop);        // 根据本节点算出在data上进行分类的subset        Instances[] subsets = splitData(data);        for (int i = 0; i < subsets.length; i++) {          // 递归的对孩子节点进行backfit          double[] dist = new double[data.numClasses()];          for (int j = 0; j < subsets[i].numInstances(); j++) {            dist[(int) subsets[i].instance(j).classValue()] += subsets[i]                .instance(j).weight();          }          m_Successors[i].backfitData(subsets[i], dist);        }<span style="white-space:pre"></span>        if (getAllowUnclassifiedInstances()) {          m_ClassDistribution = null;          return;        }<span style="white-space:pre"></span>//如果某个子节点的分布为空的话,则父节点要保存分布,否则不需要持有分布。
 <span style="white-space:pre"></span>//为什么呢?因为使用RandomTree进行预测时会遍历节点的分布并进行累加,得到分布最大的class作为预测class,在J48的那篇博客中有分析        boolean emptySuccessor = false;        for (int i = 0; i < subsets.length; i++) {          if (m_Successors[i].m_ClassDistribution == null) {            emptySuccessor = true;            return;          }        }        m_ClassDistribution = null;      }    }


三、总结

对RandomForest的分析到这里就结束了,首先分析了RandomForest,接着分析了Bagging,最后分析了RandomTree。



0 0
原创粉丝点击