weka[6] - Random Forest
Source: Internet · Editor: 程序博客网 · Date: 2024/04/28 07:57
We finally arrive at Random Forests. The random forest itself should not be hard to understand, so I won't go over the algorithm in detail; straight to the code!
buildClassifier:
public void buildClassifier(Instances data) throws Exception {

  // can classifier handle the data?
  getCapabilities().testWithFail(data);

  // remove instances with missing class
  data = new Instances(data);
  data.deleteWithMissingClass();

  m_bagger = new Bagging();
  RandomTree rTree = new RandomTree();

  // set up the random tree options
  m_KValue = m_numFeatures;
  if (m_KValue < 1)
    m_KValue = (int) Utils.log2(data.numAttributes()) + 1;
  rTree.setKValue(m_KValue);
  rTree.setMaxDepth(getMaxDepth());

  // set up the bagger and build the forest
  m_bagger.setClassifier(rTree);
  m_bagger.setSeed(m_randomSeed);
  m_bagger.setNumIterations(m_numTrees);
  m_bagger.setCalcOutOfBag(true);
  m_bagger.buildClassifier(data);
}

The first few lines could not be more familiar. Then m_bagger is initialized as a Bagging instance (in fact, the difference between random forests and bagging comes down to the base learner).
RandomTree is a single random tree, covered below (if you already know how random forests work, you can roughly guess what kind of tree it is).
The remaining statements just set parameters. It really is identical to bagging; we only add a few parameters and swap in a different base learner.
Next, let's look at the random forest's base learner, RandomTree.
buildClassifier:
public void buildClassifier(Instances data) throws Exception {

  // Make sure K value is in range
  // m_KValue: number of attributes to consider at each split
  if (m_KValue > data.numAttributes() - 1)
    m_KValue = data.numAttributes() - 1;
  if (m_KValue < 1)
    m_KValue = (int) Utils.log2(data.numAttributes()) + 1;

m_KValue is the size of the attribute subset considered at each split. This is easy to understand: the usual default is log2(m) + 1.
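To make the default concrete, here is a standalone sketch in plain Java (no Weka on the classpath; `log2` simply mirrors what Weka's `Utils.log2` computes, and the class name is mine):

```java
public class KValueDefault {
    // log base 2, as Weka's Utils.log2 computes it
    static double log2(double a) {
        return Math.log(a) / Math.log(2);
    }

    // Default attribute-subset size: floor(log2(numAttributes)) + 1
    static int defaultK(int numAttributes) {
        return (int) log2(numAttributes) + 1;
    }

    public static void main(String[] args) {
        // a dataset with 16 attributes (class included) -> K = 5
        System.out.println(defaultK(16));  // prints 5
        // log2(100) ≈ 6.64, truncated to 6, plus 1
        System.out.println(defaultK(100)); // prints 7
    }
}
```

So for typical attribute counts, each node only ever looks at a handful of candidate attributes, which is where much of the speed and the decorrelation between trees comes from.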
  // can classifier handle the data?
  getCapabilities().testWithFail(data);

  // remove instances with missing class
  data = new Instances(data);
  data.deleteWithMissingClass();

  // only class? -> build ZeroR model
  if (data.numAttributes() == 1) {
    System.err.println("Cannot build model (only class attribute present in data!), "
        + "using ZeroR model instead!");
    m_zeroR = new weka.classifiers.rules.ZeroR();
    m_zeroR.buildClassifier(data);
    return;
  } else {
    m_zeroR = null;
  }

ZeroR is the simplest possible classifier: it always predicts the majority class (or, for regression, the mean). It is used here because the data contains only the class attribute, so there is nothing else to learn from.
  // Figure out appropriate datasets
  Instances train = null;
  Instances backfit = null;
  Random rand = data.getRandomNumberGenerator(m_randomSeed);
  if (m_NumFolds <= 0) {
    train = data;
  } else {
    data.randomize(rand);
    data.stratify(m_NumFolds);
    train = data.trainCV(m_NumFolds, 1, rand);
    backfit = data.testCV(m_NumFolds, 1);
  }

This should look familiar; it came up in the J48 post. In short: shuffle the dataset, then split it into two subsets, train and test (here only fold 1 of each is taken!).
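A minimal plain-Java sketch of the shuffle-then-split idea (ignoring stratification, which Weka's trainCV/testCV also handles; the class and method names here are mine):

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class FoldSplit {
    // Split data into (train, test) for fold `fold` of `numFolds`
    // (1-based fold index, as in Weka's trainCV/testCV), after
    // shuffling with a fixed seed.
    static <T> List<List<T>> split(List<T> data, int numFolds, int fold, long seed) {
        List<T> shuffled = new ArrayList<>(data);
        Collections.shuffle(shuffled, new Random(seed));
        int n = shuffled.size();
        int foldSize = n / numFolds;
        int start = (fold - 1) * foldSize;
        int end = (fold == numFolds) ? n : start + foldSize;
        List<T> test = new ArrayList<>(shuffled.subList(start, end));
        List<T> train = new ArrayList<>(shuffled.subList(0, start));
        train.addAll(shuffled.subList(end, n));
        return List.of(train, test);
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 10; i++) data.add(i);
        List<List<Integer>> parts = split(data, 5, 1, 42L);
        // 10 instances, 5 folds -> 8 train, 2 test
        System.out.println(parts.get(0).size() + " train, " + parts.get(1).size() + " test");
    }
}
```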
  // Create the attribute indices window
  int[] attIndicesWindow = new int[data.numAttributes() - 1];
  int j = 0;
  for (int i = 0; i < attIndicesWindow.length; i++) {
    if (j == data.classIndex())
      j++; // do not include the class
    attIndicesWindow[i] = j++;
  }

  // Compute initial class counts
  double[] classProbs = new double[train.numClasses()];
  for (int i = 0; i < train.numInstances(); i++) {
    Instance inst = train.instance(i);
    classProbs[(int) inst.classValue()] += inst.weight();
  }

attIndicesWindow stores the attribute indices, skipping the class attribute (which need not be the last column in the dataset). The second loop computes the total instance weight per class.
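The index-window construction is easy to get wrong off by one, so here it is isolated as a standalone sketch (plain Java, class name mine):

```java
import java.util.Arrays;

public class AttWindow {
    // Build the list of attribute indices, skipping the class index
    // (the class attribute is not necessarily the last column).
    static int[] attIndicesWindow(int numAttributes, int classIndex) {
        int[] window = new int[numAttributes - 1];
        int j = 0;
        for (int i = 0; i < window.length; i++) {
            if (j == classIndex) j++; // do not include the class
            window[i] = j++;
        }
        return window;
    }

    public static void main(String[] args) {
        // 5 attributes, class at index 2 -> [0, 1, 3, 4]
        System.out.println(Arrays.toString(attIndicesWindow(5, 2)));
        // class last (the common case) -> [0, 1, 2]
        System.out.println(Arrays.toString(attIndicesWindow(4, 3)));
    }
}
```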
  // Build tree
  m_Tree = new Tree();
  m_Info = new Instances(data, 0);
  m_Tree.buildTree(train, classProbs, attIndicesWindow, rand, 0);

  // Backfit if required
  if (backfit != null) {
    m_Tree.backfitData(backfit);
  }
}

The Tree here is an inner class of RandomTree.
Below is Tree's buildTree, with brief notes (it shares a lot with the ID3 and J48 code we have already seen).
protected void buildTree(Instances data, double[] classProbs,
    int[] attIndicesWindow, Random random, int depth) throws Exception {

  // Make leaf if there are no training instances
  if (data.numInstances() == 0) {
    m_Attribute = -1;
    m_ClassDistribution = null;
    m_Prop = null;
    return;
  }

  // Check if node doesn't contain enough instances or is pure
  // or maximum depth reached
  m_ClassDistribution = classProbs.clone();
  if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum
      || Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)],
          Utils.sum(m_ClassDistribution))
      || ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) {
    // Make leaf
    m_Attribute = -1;
    m_Prop = null;
    return;
  }

  // Compute class distributions and value of splitting
  // criterion for each attribute
  double val = -Double.MAX_VALUE;
  double split = -Double.MAX_VALUE;
  double[][] bestDists = null;
  double[] bestProps = null;
  int bestIndex = 0;

  // Handles to get arrays out of distribution method
  double[][] props = new double[1][0];
  double[][][] dists = new double[1][0][0];

  // Investigate K random attributes
  int attIndex = 0;
  int windowSize = attIndicesWindow.length;
  int k = m_KValue;
  boolean gainFound = false;
  while ((windowSize > 0) && (k-- > 0 || !gainFound)) {
    int chosenIndex = random.nextInt(windowSize);
    attIndex = attIndicesWindow[chosenIndex];

    // shift chosen attIndex out of window
    attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];
    attIndicesWindow[windowSize - 1] = attIndex;
    windowSize--;

    double currSplit = distribution(props, dists, attIndex, data);
    double currVal = gain(dists[0], priorVal(dists[0]));
    if (Utils.gr(currVal, 0))
      gainFound = true;
    if ((currVal > val) || ((currVal == val) && (attIndex < bestIndex))) {
      val = currVal;
      bestIndex = attIndex;
      split = currSplit;
      bestProps = props[0];
      bestDists = dists[0];
    }
  }

  // Find best attribute
  m_Attribute = bestIndex;

  // Any useful split found?
  if (Utils.gr(val, 0)) {
    // Build subtrees
    m_SplitPoint = split;
    m_Prop = bestProps;
    Instances[] subsets = splitData(data);
    m_Successors = new Tree[bestDists.length];
    for (int i = 0; i < bestDists.length; i++) {
      m_Successors[i] = new Tree();
      m_Successors[i].buildTree(subsets[i], bestDists[i], attIndicesWindow,
          random, depth + 1);
    }

    // If all successors are non-empty, we don't need to store the class
    // distribution
    boolean emptySuccessor = false;
    for (int i = 0; i < subsets.length; i++) {
      if (m_Successors[i].m_ClassDistribution == null) {
        emptySuccessor = true;
        break;
      }
    }
    if (!emptySuccessor) {
      m_ClassDistribution = null;
    }
  } else {
    // Make leaf
    m_Attribute = -1;
  }
}
First come the recursion exits: if the node has too few instances, is pure, or has reached the maximum depth, we stop and make a leaf. Next, k attributes are drawn at random and the one with the largest gain is recorded. Note what this tells us: the random feature subset is not fixed once per tree, it is drawn anew at every node!
Then, if that best attribute's gain is greater than 0, we keep splitting: the familiar ID3 routine runs again and the subtrees are built recursively.
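The window trick in the loop above is sampling without replacement, done in place: swap the chosen index to the end of the window, then shrink the window by one. A standalone sketch of just that mechanism (plain Java; `pickK` and the class name are mine):

```java
import java.util.Arrays;
import java.util.Random;

public class RandomKAttributes {
    // Pick up to k distinct indices from `window` without replacement,
    // using the same in-place swap-to-end trick as RandomTree.
    static int[] pickK(int[] window, int k, Random random) {
        window = window.clone(); // don't mutate the caller's array
        int windowSize = window.length;
        int[] chosen = new int[Math.min(k, window.length)];
        int picked = 0;
        while (windowSize > 0 && picked < chosen.length) {
            int chosenIndex = random.nextInt(windowSize);
            chosen[picked] = window[chosenIndex];
            // shift the chosen index out of the active window
            window[chosenIndex] = window[windowSize - 1];
            window[windowSize - 1] = chosen[picked];
            picked++;
            windowSize--;
        }
        return chosen;
    }

    public static void main(String[] args) {
        int[] window = {0, 1, 3, 4, 5};
        int[] k3 = pickK(window, 3, new Random(1));
        System.out.println(Arrays.toString(k3)); // 3 distinct attribute indices
    }
}
```

In the real code the loop can also run past k draws (`k-- > 0 || !gainFound`): it keeps going through the remaining window until some attribute shows positive gain, so a node only becomes a leaf when no attribute helps at all.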
Looking over the whole thing, I did not at first see the step that samples a random subset of the instances: by default m_NumFolds = 0, so RandomTree itself trains on all the data it receives. That step actually lives one level up, in Bagging.buildClassifier, which resamples the training data with replacement on every iteration before handing it to a RandomTree. So each tree in Weka's RF does get a bootstrap sample, on top of the random feature subset drawn at each node.
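For reference, bootstrap sampling (what the bagging layer does per iteration) is just n draws with replacement. A minimal sketch of the idea, not Weka's actual Bagging code:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class Bootstrap {
    // Draw a bootstrap sample: n draws with replacement from `data`.
    // On average about 63.2% of the distinct instances land in the bag;
    // the rest are "out of bag" and can be used for the OOB error estimate
    // (which is why setCalcOutOfBag(true) is cheap to provide).
    static <T> List<T> sample(List<T> data, Random random) {
        List<T> bag = new ArrayList<>(data.size());
        for (int i = 0; i < data.size(); i++) {
            bag.add(data.get(random.nextInt(data.size())));
        }
        return bag;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 1000; i++) data.add(i);
        List<Integer> bag = sample(data, new Random(0));
        long distinct = bag.stream().distinct().count();
        // distinct will be somewhere around 632
        System.out.println(bag.size() + " drawn, " + distinct + " distinct");
    }
}
```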
That is basically all there is to Random Forests.
As for the tree code: once you have really understood J48, the others all look alike; only the splitting and pruning strategies change.
The next post should cover DecisionStump and its ensemble wrapper, AdaBoost.