Random Forest 实战：Java 实现 + 手写数字识别

来源:互联网 发布:sqlserver分页查询 编辑:程序博客网 时间:2024/06/06 21:45

Eclipse 项目在 GitHub 上：https://github.com/zhangfaen/ML/tree/master/random_forest_classifier

package faen;import java.util.ArrayList;import java.util.Collections;import java.util.Comparator;import java.util.HashMap;import java.util.HashSet;import java.util.List;import java.util.Map;import java.util.Random;import java.util.Set;public class RandomForest {    static class Util {        public static void CHECK(boolean condition, String message) {            if (!condition) {                throw new RuntimeException(message);            }        }    }    private double[][] instances_;    private int[] targets_;    private int numOfTrees_;    private int numOfFeatures_;    private int maxDepth_;    private TreeNode[] trees_;    private static Random rand = new Random();    /**     * Train the RF model     *      * @param instances     * @param targets     * @param numOfTrees     * @param numOfFeatures     *            this could be -1, if so, the default value will be     *            len(features)^0.5     * @param maxDepth     *            this could be -1, if so, any leaf node will have only 1     *            instance.     */    public void train(double[][] instances, int[] targets, int numOfTrees, int numOfFeatures,            int maxDepth, int treeSize) {        Util.CHECK(instances.length == targets.length, "");        Util.CHECK(numOfTrees > 0, "");        this.instances_ = instances;        this.targets_ = targets;        this.numOfTrees_ = numOfTrees;        this.numOfFeatures_ = numOfFeatures;        this.maxDepth_ = maxDepth;        this.trees_ = new TreeNode[numOfTrees_];        for (int i = 0; i < trees_.length; i++) {            System.out.println("building the tree:" + i);            trees_[i] = buildTree(getRandomInstances(treeSize), 1);        }    }    // Get sub set of all instances randomly.    
List<Integer> getRandomInstances(int numOfInstances) {        List<Integer> ret = new ArrayList<Integer>(numOfInstances);        while (ret.size() < numOfInstances) {            ret.add(rand.nextInt(instances_.length));        }        return ret;    }    // Get the majority class of all samples having indices.    private int getMajorClass(List<Integer> indices) {        Map<Integer, Integer> mii = new HashMap<Integer, Integer>();        int best = -1;        int ret = -1;        for (int index : indices) {            Integer v = mii.get(targets_[index]);            if (v == null) {                v = 0;            }            mii.put(targets_[index], v + 1);            if (v + 1 > best) {                best = v + 1;                ret = targets_[index];            }        }        return ret;    }    // Are all samples having indices have the same class?    private boolean haveSameClass(List<Integer> indices) {        for (int i = 1; i < indices.size(); i++) {            if (targets_[indices.get(i)] != targets_[indices.get(0)]) {                return false;            }        }        return true;    }    // Get a list of indices of features randomly.    private List<Integer> getRandomFeatures() {        Set<Integer> set = new HashSet<Integer>();        int featureSize = instances_[0].length;        while (set.size() < numOfFeatures_) {            set.add(rand.nextInt(featureSize));        }        List<Integer> ret = new ArrayList<Integer>();        ret.addAll(set);        return ret;    }    // Get the entropy of some samples.    
private double getEntropy(List<Integer> indices, int from, int to) {        Util.CHECK(to <= indices.size(), "");        Map<Integer, Integer> mii = new HashMap<Integer, Integer>();        for (int i = from; i < to; i++) {            Integer v = mii.get(targets_[indices.get(i)]);            if (v == null) {                v = 0;            }            mii.put(targets_[indices.get(i)], v + 1);        }        double ret = 0;        for (Integer key : mii.keySet()) {            int v = mii.get(key);            ret += Math.log((to - from) * 1.0 / v);        }        return ret;    }    private TreeNode buildTree(List<Integer> indices, int curDepth) {        System.out.println("building tree, depth:" + curDepth);        if (maxDepth_ == curDepth) {            return new TreeNode(-1, -1, getMajorClass(indices), null, null, true);        }        if (haveSameClass(indices)) {            return new TreeNode(-1, -1, targets_[indices.get(0)], null, null, true);        }        List<Integer> featureInices = getRandomFeatures();        double bestEntropy = Double.MAX_VALUE;        int bestFeatureIndex = -1;        double splitValue = -1;        List<Integer> leftIndices = null;        List<Integer> rightIndices = null;        for (final int featureIndex : featureInices) {            Collections.sort(indices, new Comparator<Integer>() {                @Override                public int compare(Integer o1, Integer o2) {                    if (instances_[o1][featureIndex] < instances_[o2][featureIndex]) {                        return -1;                    } else if (instances_[o1][featureIndex] == instances_[o2][featureIndex]) {                        return o1 - o2;                    } else {                        return 1;                    }                }            });            int bestIndex = -1;            for (int i = 0; i < indices.size() - 1; i++) {                if (instances_[indices.get(i)][featureIndex] == instances_[indices.get(i + 1)][featureIndex]) { 
                   continue;                }                double entropy = 1.0 * (i + 1 - 0) / indices.size() * getEntropy(indices, 0, i + 1)                        + 1.0 * (indices.size() - (i + 1)) / indices.size()                        * getEntropy(indices, i + 1, indices.size());                if (entropy < bestEntropy) {                    bestEntropy = entropy;                    bestFeatureIndex = featureIndex;                    bestIndex = i;                    splitValue = instances_[indices.get(i)][featureIndex];                }            }            if (bestIndex >= 0) {                leftIndices = new ArrayList<Integer>();                rightIndices = new ArrayList<Integer>();                leftIndices.addAll(indices.subList(0, bestIndex + 1));                rightIndices.addAll(indices.subList(bestIndex + 1, indices.size()));            }        }        if (bestFeatureIndex >= 0) {            return new TreeNode(bestFeatureIndex, splitValue, -1, buildTree(leftIndices,                    curDepth + 1), buildTree(rightIndices, curDepth + 1), false);        } else {            // All instances have the same features.            return new TreeNode(-1, -1, getMajorClass(indices), null, null, true);        }    }    private int predicateByOneTree(TreeNode node, double[] instance) {        if (node.isLeafNode_) {            return node.target_;        }        if (instance[node.featureIndex_] <= node.value_) {            return predicateByOneTree(node.left_, instance);        } else {            return predicateByOneTree(node.right_, instance);        }    }    // Predicate one instance.    
public int predicate(double[] instance) {        Map<Integer, Integer> mii = new HashMap<Integer, Integer>();        int bestTarget = -1;        int bestCount = -1;        for (TreeNode root : trees_) {            int target = predicateByOneTree(root, instance);            Integer v = mii.get(target);            if (v == null) {                v = 0;            }            mii.put(target, v + 1);            if (v + 1 > bestCount) {                bestCount = v + 1;                bestTarget = target;            }        }        return bestTarget;    }    // TreeNode of the decision tree.    private static class TreeNode {        public int featureIndex_;        public double value_;        public int target_;        public TreeNode left_;        public TreeNode right_;        public boolean isLeafNode_;        public TreeNode(int featureIndex, double value_, int target_, TreeNode left_,                TreeNode right_, boolean isLeafNode) {            this.featureIndex_ = featureIndex;            this.value_ = value_;            this.target_ = target_;            this.left_ = left_;            this.right_ = right_;            this.isLeafNode_ = isLeafNode;        }    }    public static void printTree(TreeNode node, String indedent) {        if (node.isLeafNode_) {            System.out.println(indedent + "target:" + node.target_);        } else {            System.out.println(indedent + "feature index:" + node.featureIndex_ + ", split value:"                    + node.value_);            printTree(node.left_, indedent + "    ");            printTree(node.right_, indedent + "    ");        }    }    public static void main(String[] args) {        //        double[][] instances = new double[][] { { 1, 1 }, { 1, -1 }, { -1, 1 }, { -1, -1 } };        int[] targets = new int[] { 1, 2, 3, 4 };        RandomForest rf = new RandomForest();        rf.train(instances, targets, 1, 2, -1, 10);        printTree(rf.trees_[0], "");    }}


0 0