数据挖掘之RandomForeast算法

来源:互联网 发布:php技术总监 编辑:程序博客网 时间:2024/05/25 05:37

RandomForest算法,精髓之处在于在建立决策树的时候,在每个节点进行属性选取时,是随机地选取部分属性,从中进行最优属性的选取,而不是在全部的所有属性中进行选择。建立了决策树森林之后,每次都要对这些不同的决策树进行预测,选出其中被预测最多的那个类别来作为最终的预测类别。在有5棵决策树时,我得出的对于离散属性的预测准确度为0.73,对于连续属性的预测准确度为0.96.

下面是我的RandomForest算法的代码,


/* * To change this template, choose Tools | Templates * and open the template in the editor. */package auxiliary;import java.util.ArrayList;import java.util.Arrays;import java.util.HashMap;import java.util.Iterator;import java.util.Map.Entry;/** * * @author daq */public class RandomForest extends Classifier {private int K=5;private ArrayList<DecisionTree> trees=new ArrayList<DecisionTree>();private HashMap<Double,Integer> map=null;private double newFeatures[][]=null;private double newLabels[]=null;    public RandomForest() {    }    @Override    public void train(boolean[] isCategory, double[][] features, double[] labels) {    for(int i=0;i<K;i++){    DecisionTree tree=new DecisionTree();    produceNewFeaturesLabels(features,labels);    tree.train(isCategory,newFeatures, newLabels);    trees.add(tree);    newFeatures=null;    newLabels=null;    }    DecisionTree tree=new DecisionTree();     tree.train(isCategory, features, labels);    trees.add(tree);    }        public void produceNewFeaturesLabels(double[][] features, double[] labels){    int size=features.length;    newFeatures=new double[size][];    newLabels=new double[size];    int length=features[0].length;    for(int i=0;i<size;i++){    int ran=(int) (Math.random()*size);    newFeatures[i]=Arrays.copyOf(features[ran],length);    newLabels[i]=labels[ran];    }    }    @Override    public double predict(double[] features) {    map=new HashMap<Double, Integer>();       for(int i=0;i<K;i++){        DecisionTree tree=trees.get(i);        double label=tree.predict(features);        if(map.get(label)==null)        map.put(label,1);        else        map.put(label, map.get(label)+1);        }        double maxIndex=0;        int max=-1;        Iterator<Entry<Double,Integer>> ite=map.entrySet().iterator();        while(ite.hasNext()){        Entry entry=ite.next();        if((Integer)entry.getValue()>max){        maxIndex=(Double)entry.getKey();        max=(Integer)entry.getValue();        }        }        return maxIndex;    }    }


原创粉丝点击