Weka算法Classifier-tree-J48源码分析(二)ClassifierTree

来源:互联网 发布:大数据分析常用模型 编辑:程序博客网 时间:2024/05/21 05:41

一、问题

主要带着四个问题去研究J48的实现。

1、如何控制分类树的精度。

2、如何处理缺失的值(MissingValue)

3、如何对连续值进行离散化。

4、如何进行分类树的剪枝。


二、BuildClassifier

每一个分类器都会实现这个方法,传入一个Instances对象,在这个对象基础上进行来构建分类树。核心代码如下:

[java] view plaincopy
  1. public void buildClassifier(Instances instances)   
  2.        throws Exception {  
  3.   
  4.     ModelSelection modSelection;       
  5.   
  6.     if (m_binarySplits)  
  7.       modSelection = new BinC45ModelSelection(m_minNumObj, instances);  
  8.     else  
  9.       modSelection = new C45ModelSelection(m_minNumObj, instances);  
  10.     if (!m_reducedErrorPruning)  
  11.       m_root = new C45PruneableClassifierTree(modSelection, !m_unpruned, m_CF,  
  12.                         m_subtreeRaising, !m_noCleanup);  
  13.     else  
  14.       m_root = new PruneableClassifierTree(modSelection, !m_unpruned, m_numFolds,  
  15.                        !m_noCleanup, m_Seed);  
  16.     m_root.buildClassifier(instances);  
  17.     if (m_binarySplits) {  
  18.       ((BinC45ModelSelection)modSelection).cleanup();  
  19.     } else {  
  20.       ((C45ModelSelection)modSelection).cleanup();  
  21.     }  
  22.   }  
可以看到这段代码逻辑非常清楚,首先根据是否是一个二分树(即每个节点只有是否两种选择)来构造一个ModelSelection,随后根据是否有m_reduceErrorPruning标志来构造相应的ClassifierTree,在这个tree上真正的构建模型,最后清理数据(主要是做释放指针的工作,防止Tree持有Instances指针导致GC不能在上层调用者想释放Instances的时候进行释放)。


三、C45PruneableClassifierTree

(1)该类也实现了BuildCClassifier方法来构建分类器,先看一下这个方法的主逻辑,代码如下:

[java] view plaincopy
  1. public void buildClassifier(Instances data) throws Exception {  
  2.   
  3.   // can classifier tree handle the data?  
  4.   getCapabilities().testWithFail(data);  
  5.   
  6.   // remove instances with missing class  
  7.   data = new Instances(data);  
  8.   data.deleteWithMissingClass();  
  9.     
  10.  buildTree(data, m_subtreeRaising || !m_cleanup);  
  11.  collapse();  
  12.  if (m_pruneTheTree) {  
  13.    prune();  
  14.  }  
  15.  if (m_cleanup) {  
  16.    cleanup(new Instances(data, 0));  
  17.  }  
  18. }  
首先testWithFail是检测一下传入的data是否能用该分类器进行分类,比如C45只能对要分类的属性的取值是离散值的Instances进行分类,这个test就是检测诸如此类的逻辑。

接着清理一下instances里面的无效行(相应分类属性为空的行)。

在此数据上调用buildTree进行构建分类树。

调用collapse()进行树的“坍塌”(这里我不太知道学名应该怎么翻译)

如果有需要,则进行prune()剪枝。

最后清理数据。

(2)按照这个顺序首先来看buildTree函数

[html] view plaincopy
  1. public void buildTree(Instances data, boolean keepData) throws Exception {  
  2.      
  3.    Instances [] localInstances;  
  4.   
  5.    if (keepData) {  
  6.      m_train = data;  
  7.    }  
  8.    m_test = null;  
  9.    m_isLeaf = false;  
  10.    m_isEmpty = false;  
  11.    m_sons = null;  
  12.    m_localModel = m_toSelectModel.selectModel(data);  
  13.    if (m_localModel.numSubsets() > 1) {  
  14.      localInstances = m_localModel.split(data);  
  15.      data = null;  
  16.      m_sons = new ClassifierTree [m_localModel.numSubsets()];  
  17.      for (int i = 0; i < m_sons.length; i++) {  
  18. m_sons[i] = getNewTree(localInstances[i]);  
  19. localInstances[i] = null;  
  20.      }  
  21.    }else{  
  22.      m_isLeaf = true;  
  23.      if (Utils.eq(data.sumOfWeights(), 0))  
  24. m_isEmpty = true;  
  25.      data = null;  
  26.    }  
  27.  }  
该函数逻辑也比较简单(怎么都比较简单?!),首先根据传入参数来判断是否应该持有数据。

然后根据m_toSelectModel来选择一个模型并把传入的数据集按相应的规则分成不同的subSet,这个selectModel是构造函数传入的,参见刚才描述的主流程。这一步如果对应上篇博客的算法描述,得到的subSet就是第10行的dv。

接着判断subSet的数量,如果只有一个,那么就是一个叶子节点,什么都不用做就返回了。

否则根据localModel将data分成不同的subInstances,接着为每一个subInstances建立新的ClassifierTree节点作为自己的孩子节点,并调用getNewTree函数来为每一个subInstances构造新的tree。

(3)采用DFS的方式接着去看一下getNewTree的逻辑

[java] view plaincopy
  1. protected ClassifierTree getNewTree(Instances data) throws Exception {  
  2.   
  3.   ClassifierTree newTree = new ClassifierTree(m_toSelectModel);  
  4.   newTree.buildTree(data, false);  
  5.     
  6.   return newTree;  
  7. }  
很简单,就是一个递归调用。

(4)重新回到C45PruneableClassifierTree.buildClassifier方法,来研究一下其中的collapse函数。

[java] view plaincopy
  1. /** 
  2.    * Collapses a tree to a node if training error doesn't increase. 
  3.    */  
  4.   public final void collapse(){  
  5.   
  6.     double errorsOfSubtree;  
  7.     double errorsOfTree;  
  8.     int i;  
  9.   
  10.     if (!m_isLeaf){  
  11.       errorsOfSubtree = getTrainingErrors();  
  12.       errorsOfTree = localModel().distribution().numIncorrect();  
  13.       if (errorsOfSubtree >= errorsOfTree-1E-3){  
  14.   
  15.     // Free adjacent trees  
  16.     m_sons = null;  
  17.     m_isLeaf = true;  
  18.               
  19.     // Get NoSplit Model for tree.  
  20.     m_localModel = new NoSplit(localModel().distribution());  
  21.       }else  
  22.     for (i=0;i<m_sons.length;i++)  
  23.       son(i).collapse();  
  24.     }  
  25.   }  
通过注释也可以看出,如果该节点的存在很多孩子节点,但这些孩子节点并不能提高这颗分类树的准确度,则把这些孩子节点删除。否则在每个孩子上递归的坍塌。通过collapse方法可以在不减少精度的前提下减少决策树的深度,进而提高效率。

简单说一下如何估计当前的节点的错误,也就是localModel().distribution().numIncorrect();

首先获得当前训练集上的一个分布,然后找出该分布里数量最多的那个属性的数量,认为是“正确的”,则其余的就是错误的。

getTrainingError就是对每个孩子节点做上述操作,然后结果相加。

(5)再来看看prune()方法,也是C45PruneableClassifierTree的BuildClassifier中的最后一个步骤。

该函数比较长,我就直接把对这个函数的分析写在注释里了。

[java] view plaincopy
  1. public void prune() throws Exception {  
  2.     double errorsLargestBranch;//这个树节点的孩子节点中,肯定有一个分到的数据最多,该值记录该孩子节点分类错误的用例数  
  3.     double errorsLeaf;//如果该节点成为了叶子节点,则分类错误的用例数量  
  4.     double errorsTree;//<span style="font-family: Arial, Helvetica, sans-serif;">该节点目前情况下,错误用例数量</span>  
  5.     int indexOfLargestBranch;//那个分到最多数据的孩子节点在son数组中的index  
  6.     C45PruneableClassifierTree largestBranch;//son[indexOfLargestBranch]  
  7.     int i;  
  8.   
  9.     if (!m_isLeaf){  
[java] view plaincopy
  1. //首先,如果是叶子节点,则先递归的队所有孩子几点进行prune()。  
  2.       for (i=0;i<m_sons.length;i++)  
  3.     son(i).prune();  
[java] view plaincopy
  1. //通过数据集的分布,很容易能找到indexOfLargetBranch  
  2.       indexOfLargestBranch = localModel().distribution().maxBag();  
  3.       if (m_subtreeRaising) {  
[java] view plaincopy
  1. //m_subtreeRaising是一个标志,代表可否使用该树的子树去替代该树,如果有了这个标志,就去计算最大的子树的错误数量  
[java] view plaincopy
  1. //否则就简单的标Double.Max_Value  
[java] view plaincopy
  1. //对于错误数量的估计不展开说了,简单来说依然是根据分布做一个统计(还要加一个基于m_CF的修正),如果不是叶子节点则递  
[java] view plaincopy
  1. //归的进行统计。  
  2.     errorsLargestBranch = son(indexOfLargestBranch).  
  3.       getEstimatedErrorsForBranch((Instances)m_train);  
  4.       } else {  
  5.     errorsLargestBranch = Double.MAX_VALUE;  
  6.       }  
  7.   
  8.       //估计一下如果该节点成为了叶子节点,则错误数量大概有多少  
  9.       errorsLeaf =   
  10.     getEstimatedErrorsForDistribution(localModel().distribution());  
[java] view plaincopy
  1. //估计该节点目前情况下,错误用例数量。  
  2.       errorsTree = getEstimatedErrors();  
  3.   
  4.      //Utils.smOrEq是smaller or equal即<=的意思  
  5.       if (Utils.smOrEq(errorsLeaf,errorsTree+0.1) &&  
  6.       Utils.smOrEq(errorsLeaf,errorsLargestBranch+0.1)){  
  7.   
  8.     // 如果当前节点作为叶子节点的错误量比整棵树都要低,并且当前节点比最大的子树的错误量也低,那么就把当前节点作//为叶子节点一定是一个最优的选择。  
  9.     m_sons = null;  
  10.     m_isLeaf = true;  
  11.           
  12.     // Get NoSplit Model for node.  
  13.     m_localModel = new NoSplit(localModel().distribution());  
  14.     return;//直接返回  
  15.       }  
  16.   
  17.       // Decide if largest branch is better choice  
  18.       // than whole subtree.  
  19.       if (Utils.smOrEq(errorsLargestBranch,errorsTree+0.1)){  
[java] view plaincopy
  1. //如果当前节点的错误用例数大于最大子树,则用最大子树替代当前节点。  
  2.     largestBranch = son(indexOfLargestBranch);  
  3.     m_sons = largestBranch.m_sons;  
  4.     m_localModel = largestBranch.localModel();  
  5.     m_isLeaf = largestBranch.m_isLeaf;  
  6.     newDistribution(m_train);  
  7.     prune();  
  8.       }  
  9.     }  
  10.   }  

一句话总结collapse和prune:prune或许会影响精度,collapse不会。


四、PruneableClassifierTree

在J48主流程里,根据m_reducedErrorPruning的不同会选择两个不同的ClassifierTree,刚才已经分析了一个,另外一个则是PruneeableClassifierTree。

(1)buildClassifier

[java] view plaincopy
  1. public void buildClassifier(Instances data)   
  2.      throws Exception {  
  3.   
  4.   // can classifier tree handle the data?  
  5.   getCapabilities().testWithFail(data);  
  6.   
  7.   // remove instances with missing class  
  8.   data = new Instances(data);  
  9.   data.deleteWithMissingClass();  
  10.     
  11.  Random random = new Random(m_seed);  
  12.  data.stratify(numSets);  
  13.  buildTree(data.trainCV(numSets, numSets - 1, random),  
  14.     data.testCV(numSets, numSets - 1), !m_cleanup);  
  15.  if (pruneTheTree) {  
  16.    prune();  
  17.  }  
  18.  if (m_cleanup) {  
  19.    cleanup(new Instances(data, 0));  
  20.  }  
  21. }  
和C45PruneableClassifierTree不同的是,buildTree的时候除了传入训练集,还传入了测试集,除此之外,少了Collapse步骤,其余都一样。

下面就看看传入了测试集的build和之前分析的build有什么不同之处。

(2)buildTree

[java] view plaincopy
  1. public void buildTree(Instances train, Instances test, boolean keepData)  
  2.        throws Exception {  
  3.       
  4.     Instances [] localTrain, localTest;  
  5.     int i;  
  6.       
  7.     if (keepData) {  
  8.       m_train = train;  
  9.     }  
  10.     m_isLeaf = false;  
  11.     m_isEmpty = false;  
  12.     m_sons = null;  
  13.     m_localModel = m_toSelectModel.selectModel(train, test);  
  14.     m_test = new Distribution(test, m_localModel);  
  15.     if (m_localModel.numSubsets() > 1) {  
  16.       localTrain = m_localModel.split(train);  
  17.       localTest = m_localModel.split(test);  
  18.       train = test = null;  
  19.       m_sons = new ClassifierTree [m_localModel.numSubsets()];  
  20.       for (i=0;i<m_sons.length;i++) {  
  21.     m_sons[i] = getNewTree(localTrain[i], localTest[i]);  
  22.     localTrain[i] = null;  
  23.     localTest[i] = null;  
  24.       }  
  25.     }else{  
  26.       m_isLeaf = true;  
  27.       if (Utils.eq(train.sumOfWeights(), 0))  
  28.     m_isEmpty = true;  
  29.       train = test = null;  
  30.     }  
  31.   }  
可以看到,代码基本一样,唯一不同的地方就是selectModel的时候会把test传进去,对于Model的实现会具体放到下篇博客中去讲述。

而prune也更为简单,去掉了subTreeRasing的特性。

[java] view plaincopy
  1.  public void prune() throws Exception {  
  2.    
  3.    if (!m_isLeaf) {  
  4.        
  5.      // Prune all subtrees.  
  6.      for (int i = 0; i < m_sons.length; i++)  
  7. son(i).prune();  
  8.        
  9.      // Decide if leaf is best choice.  
  10.      if (Utils.smOrEq(errorsForLeaf(),errorsForTree())) {  
  11.   
  12. // Free son Trees  
  13. m_sons = null;  
  14. m_isLeaf = true;  
  15.   
  16. // Get NoSplit Model for node.  
  17. m_localModel = new NoSplit(localModel().distribution());  
  18.      }  
  19.    }  
  20.  }  


五、总结

至此,对两种ClassifierTree的buildClassifier的分析差不多就结束了,总体上来讲,ClassifierTree是通过传入的Model来构建并维护分类树的结构,除此之外在构建完毕后会按照不同的逻辑进行剪枝。


对于篇开头提出的问题,目前可以回答问题4,简而言之就是根据已有数据集的分布,判断该树、该树的最大子树、以及该树作为叶子节点时的正确率,在此基础上进行剪枝。

下篇文章主要分析Model的实现,也就是如何根据属性把已有的数据集分解subInstances





0 0
原创粉丝点击