Mahout随机森林算法源码分析(4)

来源:互联网 发布:水果网络销售平台 编辑:程序博客网 时间:2024/04/29 15:02

Mahout版本:0.7,hadoop版本:1.0.4,jdk:1.7.0_25 64bit。

Mahout系列之Decision Forest写了几篇,其中的一些过程并没有详细说明,这里就分析一下,作为Decision Forest算法系列的结束篇。

主要的问题包括:(1)在Build Forest中分析完了Step1Mapper后就没有向下分析了,而是直接进行TestForest的分析了,中间其实还是有很多操作的,比如:把Step1Mapper的Job的输出进行转换写入文件。(2)在BuildForest中没有分析当输入是Categorical的情况,这种情况下面执行的某些步骤是不一样的,主要是在DecisionTreeBuilder中的build方法中的区分。(3)在前一篇中最后的使用forest进行对数据的分类只是简要的说了下,这里详细分析下代码。(4)决策树同样可以做回归分析,在Describe阶段设置为回归问题就可以了,但是这里就不想做分析了。下面来分条进行分析:

(1)在BuildForest中提交任务后实际运行的类是Builder中的build方法中的代码。这里面的代码任务运行后的代码如下:

if (isOutput(conf)) {      log.debug("Parsing the output...");      DecisionForest forest = parseOutput(job);      HadoopUtil.delete(conf, outputPath);      return forest;    }
isOutput():

protected static boolean isOutput(Configuration conf) {    return conf.getBoolean("debug.mahout.rf.output", true);  }
可以看到这个函数去判断是否设置了debug.mahout.rf.output,如果没有设置则返回true,否则,就说明设置过了就按照设置的值来返回。这里一般都没有设置,所以就会运行if里面的代码先把job的输出传入到forest变量,然后删除job的输出。看parseOutput的操作:

protected DecisionForest parseOutput(Job job) throws IOException {    Configuration conf = job.getConfiguration();        int numTrees = Builder.getNbTrees(conf);        Path outputPath = getOutputPath(conf);        TreeID[] keys = new TreeID[numTrees];    Node[] trees = new Node[numTrees];            processOutput(job, outputPath, keys, trees);        return new DecisionForest(Arrays.asList(trees));  }
这里面又有一个processOutput函数,前面就是设置一些变量的size之类的,然后到processOutput函数,看这个函数:

protected static void processOutput(JobContext job,                                      Path outputPath,                                      TreeID[] keys,                                      Node[] trees) throws IOException {    Preconditions.checkArgument(keys == null && trees == null || keys != null && trees != null,        "if keys is null, trees should also be null");    Preconditions.checkArgument(keys == null || keys.length == trees.length, "keys.length != trees.length");    Configuration conf = job.getConfiguration();    FileSystem fs = outputPath.getFileSystem(conf);    Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);    // read all the outputs    int index = 0;    for (Path path : outfiles) {      for (Pair<TreeID,MapredOutput> record : new SequenceFileIterable<TreeID, MapredOutput>(path, conf)) {        TreeID key = record.getFirst();        MapredOutput value = record.getSecond();        if (keys != null) {          keys[index] = key;        }        if (trees != null) {          trees[index] = value.getTree();        }        index++;      }    }    // make sure we got all the keys/values    if (keys != null && index != keys.length) {      throw new IllegalStateException("Some key/values are missing from the output");    }  }
这里看到就是把job的输出按条读出然后写入到Node[] trees数组中,然后把数组转换为list,赋值给DecisionForest变量new DecisionForest(Arrays.asList(trees))。最后返回到BuildForest中DFUtils.storeWritable(getConf(), forestPath, forest);,这个主要是写文件,基本没啥内容了。

(2)当输入数据中存在有Categorical的属性列时,最先的不同就是在dataset的values属性。这个values数组当输入数据属性是Numerical的时候对应的值就是null,如果是Categorical的时候就会存入相应的离散值。其次就是在DecisionTreeBuilder中find the best split这一部分的代码(源文件中192行),这里计算Split的时候分为了Categorical和Numerical,如下:

public Split computeSplit(Data data, int attr) {    if (data.getDataset().isNumerical(attr)) {      return numericalSplit(data, attr);    } else {      return categoricalSplit(data, attr);    }  }
看categoricalSplit函数:

 private static Split categoricalSplit(Data data, int attr) {    double[] values = data.values(attr);    int[][] counts = new int[values.length][data.getDataset().nblabels()];    int[] countAll = new int[data.getDataset().nblabels()];    Dataset dataset = data.getDataset();    // compute frequencies    for (int index = 0; index < data.size(); index++) {      Instance instance = data.get(index);      counts[ArrayUtils.indexOf(values, instance.get(attr))][(int) dataset.getLabel(instance)]++;      countAll[(int) dataset.getLabel(instance)]++;    }    int size = data.size();    double hy = entropy(countAll, size); // H(Y)    double hyx = 0.0; // H(Y|X)    double invDataSize = 1.0 / size;    for (int index = 0; index < values.length; index++) {      size = DataUtils.sum(counts[index]);      hyx += size * invDataSize * entropy(counts[index], size);    }    double ig = hy - hyx;    return new Split(attr, ig);  }

这里返回的Split只有两个属性,其实因为属性值是离散的,所以这里只用确定是这个值或者不是即可,不会还要说比较值的大小(而且也没法比)。
然后就是建立节点的部分了。获得最佳属性后,根据这个属性是否是Numerical而进入不同的代码块,如果是Categorical的话,进入:

else { // CATEGORICAL attribute      double[] values = data.values(best.getAttr());      // tree is complemented      Collection<Double> subsetValues = null;      if (complemented) {        subsetValues = Sets.newHashSet();        for (double value : values) {          subsetValues.add(value);        }        values = fullSet.values(best.getAttr());      }      int cnt = 0;      Data[] subsets = new Data[values.length];      for (int index = 0; index < values.length; index++) {        if (complemented && !subsetValues.contains(values[index])) {          continue;        }        subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index]));        if (subsets[index].size() >= minSplitNum) {          cnt++;        }      }      // size of the subset is less than the minSpitNum      if (cnt < 2) {        // branch is not split        double label;        if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {          label = sum / data.size();        } else {          label = data.majorityLabel(rng);        }        log.debug("branch is not split Leaf({})", label);        return new Leaf(label);      }      selected[best.getAttr()] = true;      Node[] children = new Node[values.length];      for (int index = 0; index < values.length; index++) {        if (complemented && (subsetValues == null || !subsetValues.contains(values[index]))) {          // tree is complemented          double label;          if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {            label = sum / data.size();          } else {            label = data.majorityLabel(rng);          }          log.debug("complemented Leaf({})", label);          children[index] = new Leaf(label);          continue;        }        children[index] = build(rng, subsets[index]);      }      selected[best.getAttr()] = alreadySelected;      childNode = new CategoricalNode(best.getAttr(), values, children);    }
其实上面的代码和Numerical差不多,可以说作为Numerical的一种特殊情况,即对于Numerical把其区分为等于属性值和不等于属性值即可(但是Numerical是分为小于和等于、大于两种)。其他基本就差不多了。

(3)用forest对数据Instance变量进行分类的代码是在DecisionForest的classify函数里面:

public double classify(Dataset dataset, Random rng, Instance instance) {    if (dataset.isNumerical(dataset.getLabelId())) {      double sum = 0;      int cnt = 0;      for (Node tree : trees) {        double prediction = tree.classify(instance);        if (prediction != -1) {          sum += prediction;          cnt++;        }      }      return sum / cnt;    } else {      int[] predictions = new int[dataset.nblabels()];      for (Node tree : trees) {        double prediction = tree.classify(instance);        if (prediction != -1) {          predictions[(int) prediction]++;        }      }            if (DataUtils.sum(predictions) == 0) {        return -1; // no prediction available      }
上面就是前篇讲到的所有树都对这个数据进行分类,然后按最多次数的那个类别即是最后的结果。但是一棵树是如何分类的?这个又分为了两种,好吧,应该不难猜,就是Numerical的树和Categorical的树。分别来看,首先是Numerical:

public double classify(Instance instance) {    if (instance.get(attr) < split) {      return loChild.classify(instance);    } else {      return hiChild.classify(instance);    }  }
看到它是去找它的子树去了,然后最后到哪里?其实是到了Leaf的classify函数了:

@Override  public double classify(Instance instance) {    return label;  }
这个也是一个递归的过程,其实就是建树过程的一个反过程而已,这样其实Categorical也是一样的了,只是要做些转换而已:

public double classify(Instance instance) {    int index = ArrayUtils.indexOf(values, instance.get(attr));    if (index == -1) {      // value not available, we cannot predict      return -1;    }    return childs[index].classify(instance);  }
这样基本就ok了,下次再看这个算法的时候应该是要分析回归问题了?



分享,成长,快乐

转载请注明blog地址:http://blog.csdn.net/fansy1990