weka[6] - Random Forest

来源:互联网 发布:观网络知识防诈骗有感 编辑:程序博客网 时间:2024/04/28 07:57

终于来到Random Forests啦。随机森林应该不难理解,算法本身就不细说了,直接进入代码!

buildClassifer:

  public void buildClassifier(Instances data) throws Exception {    // can classifier handle the data?    getCapabilities().testWithFail(data);    // remove instances with missing class    data = new Instances(data);    data.deleteWithMissingClass();    m_bagger = new Bagging();    RandomTree rTree = new RandomTree();    // set up the random tree options    m_KValue = m_numFeatures;    if (m_KValue < 1)      m_KValue = (int) Utils.log2(data.numAttributes()) + 1;    rTree.setKValue(m_KValue);    rTree.setMaxDepth(getMaxDepth());    // set up the bagger and build the forest    m_bagger.setClassifier(rTree);    m_bagger.setSeed(m_randomSeed);    m_bagger.setNumIterations(m_numTrees);    m_bagger.setCalcOutOfBag(true);    m_bagger.buildClassifier(data);  }
前三行再熟悉不过了。第四行, m_bagger初始化一个bagging类(其实random forests跟bagging区别的区别是base learner)。

RandomTree就是一棵随机树,后面讲(清楚随机森林的同学,已经大致猜到了这是棵怎么样的树)。

后面几部就是设置下参数而已。其实就跟bagging一模一样,只不过我们增加一些参数,并且把base learner换一换。

下面来看看随机森林的base learner - RandomTree。

buildClassifier:

  public void buildClassifier(Instances data) throws Exception {    // Make sure K value is in range    // m_KValue: number of instances for spliting    if (m_KValue > data.numAttributes() - 1)      m_KValue = data.numAttributes() - 1;    if (m_KValue < 1)      m_KValue = (int) Utils.log2(data.numAttributes()) + 1;
m_KValue就是每次分裂考虑的属性集大小。这里也很好理解,一般就是取log2(m)+1.
    // can classifier handle the data?    getCapabilities().testWithFail(data);    // remove instances with missing class    data = new Instances(data);    data.deleteWithMissingClass();    // only class? -> build ZeroR model    if (data.numAttributes() == 1) {      System.err          .println("Cannot build model (only class attribute present in data!), "              + "using ZeroR model instead!");      m_zeroR = new weka.classifiers.rules.ZeroR();      m_zeroR.buildClassifier(data);      return;    } else {      m_zeroR = null;    }
ZeroR是一个最简单的分类器,核心思想就是哪个类最多就输出那个类的标签(如果是回归,就输出average),这里因为只有1个属性,所以就用他了。
    // Figure out appropriate datasets    Instances train = null;    Instances backfit = null;    Random rand = data.getRandomNumberGenerator(m_randomSeed);    if (m_NumFolds <= 0) {      train = data;    } else {      data.randomize(rand);      data.stratify(m_NumFolds);      train = data.trainCV(m_NumFolds, 1, rand);      backfit = data.testCV(m_NumFolds, 1);    }
这段应该很熟悉。j48里有提到过,简单的说,就是把一个数据集随机排序,然后分成train和test2个子集(这里都是只取1份!)。
    // Create the attribute indices window    int[] attIndicesWindow = new int[data.numAttributes() - 1];    int j = 0;    for (int i = 0; i < attIndicesWindow.length; i++) {      if (j == data.classIndex())        j++; // do not include the class      attIndicesWindow[i] = j++;    }    // Compute initial class counts    double[] classProbs = new double[train.numClasses()];    for (int i = 0; i < train.numInstances(); i++) {      Instance inst = train.instance(i);      classProbs[(int) inst.classValue()] += inst.weight();    }
attIndices就是存储属性的index(数据集的class可能不在最后).后面部分就是计算每个class的权重和。
    // Build tree    m_Tree = new Tree();    m_Info = new Instances(data, 0);    m_Tree.buildTree(train, classProbs, attIndicesWindow, rand, 0);    // Backfit if required    if (backfit != null) {      m_Tree.backfitData(backfit);    }  }
上面这个Tree是在RandomTree里的一个内部类。
下面把Tree的buildClassifer贴上来,然后简单说明一下(跟之前看的ID3、J48很多地方一样的)
    protected void buildTree(Instances data, double[] classProbs,        int[] attIndicesWindow, Random random, int depth) throws Exception {      // Make leaf if there are no training instances      if (data.numInstances() == 0) {        m_Attribute = -1;        m_ClassDistribution = null;        m_Prop = null;        return;      }      // Check if node doesn't contain enough instances or is pure      // or maximum depth reached      m_ClassDistribution = classProbs.clone();      if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum          || Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)],              Utils.sum(m_ClassDistribution))          || ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) {        // Make leaf        m_Attribute = -1;        m_Prop = null;        return;      }      // Compute class distributions and value of splitting      // criterion for each attribute      double val = -Double.MAX_VALUE;      double split = -Double.MAX_VALUE;      double[][] bestDists = null;      double[] bestProps = null;      int bestIndex = 0;      // Handles to get arrays out of distribution method      double[][] props = new double[1][0];      double[][][] dists = new double[1][0][0];      // Investigate K random attributes      int attIndex = 0;      int windowSize = attIndicesWindow.length;      int k = m_KValue;      boolean gainFound = false;      while ((windowSize > 0) && (k-- > 0 || !gainFound)) {        int chosenIndex = random.nextInt(windowSize);        attIndex = attIndicesWindow[chosenIndex];        // shift chosen attIndex out of window        attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];        attIndicesWindow[windowSize - 1] = attIndex;        windowSize--;        double currSplit = distribution(props, dists, attIndex, data);        double currVal = gain(dists[0], priorVal(dists[0]));        if (Utils.gr(currVal, 0))          gainFound = true;        if ((currVal > val) || ((currVal == val) && (attIndex < bestIndex))) {          val = currVal;          bestIndex = attIndex;          split = currSplit;          bestProps = props[0];          bestDists = dists[0];        }      }      // Find best attribute      m_Attribute = bestIndex;      // Any useful split found?      if (Utils.gr(val, 0)) {        // Build subtrees        m_SplitPoint = split;        m_Prop = bestProps;        Instances[] subsets = splitData(data);        m_Successors = new Tree[bestDists.length];        for (int i = 0; i < bestDists.length; i++) {          m_Successors[i] = new Tree();          m_Successors[i].buildTree(subsets[i], bestDists[i], attIndicesWindow,              random, depth + 1);        }        // If all successors are non-empty, we don't need to store the class        // distribution        boolean emptySuccessor = false;        for (int i = 0; i < subsets.length; i++) {          if (m_Successors[i].m_ClassDistribution == null) {            emptySuccessor = true;            break;          }        }        if (!emptySuccessor) {          m_ClassDistribution = null;        }      } else {        // Make leaf        m_Attribute = -1;      }    }

先判断递归出口,满足比如样本数目过少或者到达最大深度之类的,就停止。接着是随机选择k个属性,记录下Gain最大的属性。从这里我们也可以看出,随机森林选择特征集并不是对每棵树固定的,而是每个节点都是不一样的!

然后如果这个属性的Gain>0,那么意味着还要继续split。ID3那套熟悉的事情再做一遍,递归建立子树!

总的看下来,我似乎没有看到随机抽取部分样本这一步,因为默认的m_NumFolds=0,所以似乎weka里的RF每棵树用的是所有样本+部分特征。(有待考证)

Random Forests基本就是这些了。

关于树这块,感觉其实看懂了j48,其他的都很像了,就是分裂和剪枝策略的变化而已。

下篇应该是关于DecisionStump以及他的组合器adaboost。

0 0