weka[3] - J48(二)

来源:互联网 发布:淘宝上日本代购哪家好 编辑:程序博客网 时间:2024/06/07 03:11

J48(一)中,主要分析了分裂的策略:二叉和多叉

这一节,主要看看源码中,关于剪枝的部分。主要看PruneableClassifierTree。

buildClassify:

 public void buildClassifier(Instances data)        throws Exception {    // can classifier tree handle the data?    getCapabilities().testWithFail(data);    // remove instances with missing class    data = new Instances(data);    data.deleteWithMissingClass();       Random random = new Random(m_seed);   //这个看的不是太懂啊!!感觉就是随机化一下   data.stratify(numSets);   //trainCV:表示分成numSets后,返回前numSets-1个,然后样本随机排序   //testCV: 表示分成numSets后,返回第numSets-1个   buildTree(data.trainCV(numSets, numSets - 1, random),     data.testCV(numSets, numSets - 1), !m_cleanup);   if (pruneTheTree) {     prune();   }   if (m_cleanup) {     cleanup(new Instances(data, 0));   }  }
这里主要就是buildTree,这个函数在classifierTree中。(PruneableClassifierTree继承自ClassifierTree)

还有一个是prune(),待会会分析,就是如何剪枝。

clearnup就是节省内存,把不需要的删掉。这里只是删掉数据的信息。

先看prune,再看buildTree。

prune:

public void prune() throws Exception {      if (!m_isLeaf) {            // Prune all subtrees.      for (int i = 0; i < m_sons.length; i++)son(i).prune();            // Decide if leaf is best choice.      if (Utils.smOrEq(errorsForLeaf(),errorsForTree())) {// Free son Treesm_sons = null;m_isLeaf = true;// Get NoSplit Model for node.m_localModel = new NoSplit(localModel().distribution());      }    }  }

prune的函数代码很简洁,实际上确很复杂。我可能也无法在这里完全说清楚,但是尽力吧。

首先我们知道,如果是此时的node是leaf,那么自然不用剪枝。

如果不是leaf,那么需要检验他的子节点是否需要剪枝(递归下去)。

这里主要要分析下errorsForleaf()和errorsForTree。这是剪枝的关键。这里是说如果我叶子的错误率,小于本节点的错误率,那么我把本节点当成叶子!

来看看errorsForleaf和errorsForTree:

  private double errorsForLeaf() throws Exception {    return m_test.total()-      m_test.perClass(localModel().distribution().maxClass());  }
localModel()是ClassifysplitModel这个类,distribution其实就是表示这个节点的一些统计量,包括perClass,perBag,perClassperBag。

然后这个return的意思其实很简单,就是说,我在这个节点,全部的weight - 正确的weight = 错误的weight。就这么简单。

  private double errorsForTree() throws Exception {    double errors = 0;    if (m_isLeaf)      return errorsForLeaf();    else{      for (int i = 0; i < m_sons.length; i++)        //如果training中没有样本落入bag(i)中if (Utils.eq(localModel().distribution().perBag(i), 0)) {          //那么剪枝的错误率=testing中bag(i)的权重-落入bag(i)的样本最大class的权重!          //这个意思就是说 全部 - 正确 = 错误的部分  errors += m_test.perBag(i)-    m_test.perClassPerBag(i,localModel().distribution().maxClass());} else  errors += son(i).errorsForTree();      return errors;    }  }
这里其实也差不多,他是递归地计算,如果不是叶子,那么errors是所有子节点错误权重的累计。记住,这个时候,节点保存的m_test是testing data的数据,而localModel都是training的数据。(这个后面关于buildTree的时候,会提到)

prune基本就是这些。接下来看看重头戏,buildTree这个函数(在ClassifierTree中)

buildTree:

public void buildTree(Instances train, Instances test, boolean keepData)       throws Exception {        Instances [] localTrain, localTest;    int i;        if (keepData) {      m_train = train;    }    m_isLeaf = false;    m_isEmpty = false;    m_sons = null;    //返回的是splitModel    m_localModel = m_toSelectModel.selectModel(train, test);    //返回distribution    m_test = new Distribution(test, m_localModel);    if (m_localModel.numSubsets() > 1) {      //将train和test分成num_Subsets份      localTrain = m_localModel.split(train);      localTest = m_localModel.split(test);      train = test = null;      m_sons = new ClassifierTree [m_localModel.numSubsets()];      for (i=0;i<m_sons.length;i++) {        //生成一个子树m_sons[i] = getNewTree(localTrain[i], localTest[i]);localTrain[i] = null;localTest[i] = null;      }    }else{      m_isLeaf = true;      if (Utils.eq(train.sumOfWeights(), 0))m_isEmpty = true;      train = test = null;    }  }
这段代码有些地方让我挺困惑的。主要还是不熟悉整个类的调用、继承情况吧。慢慢分析!

前面都好懂,这行可能会卡一下。 

 m_localModel = m_toSelectModel.selectModel(train, test);

m_toSelectModel.selectModel就是返回一个splitModel(Bin或者C45)。但是要注意,这里调用的是具体的某个modelSelection(比如Bin)的seletModel()函数,实际返回的只有train的信息,test过滤了!

下面这行要具体分析:

m_test = new Distribution(test, m_localModel);

public Distribution(Instances source,       ClassifierSplitModel modelToUse)       throws Exception {    int index;    Instance instance;    double[] weights;        //weight per class per bag。    m_perClassPerBag = new double [modelToUse.numSubsets()][0];    m_perBag = new double [modelToUse.numSubsets()];    totaL = 0;    m_perClass = new double [source.numClasses()];    for (int i = 0; i < modelToUse.numSubsets(); i++)      m_perClassPerBag[i] = new double [source.numClasses()];    Enumeration enu = source.enumerateInstances();    while (enu.hasMoreElements()) {      instance = (Instance) enu.nextElement();      //判断instance属于哪个bag。这里bag是根据属性个数分的      index = modelToUse.whichSubset(instance);      if (index != -1)add(index, instance);      else {weights = modelToUse.weights(instance);addWeights(instance, weights);      }    }  }

这里主要是根据train的结果,然后计算test数据的distribution。train的结果主要就是分多少个bag,split point是哪些。

buildTree后面的代码都很自然了~ 

总的来说,这个PruneableClassifierTree,就是将数据分成train和test(用于计算错误率),然后自顶向下递归计算每个节点的错误率,和如果此时把此节点砍成叶子节点的错误率,来比较大小。如果前者小,那么就进行剪枝!


还有一种C45PruneableClassifierTree。这个和上面的剪枝稍微有些区别。他没有分成train和test去计算误差。而是直接用train的数据计算一个estimated error(加入了置信度)。具体代码限于篇幅就不贴了。


0 0
原创粉丝点击