Random Forest实战:Java实现 + 手写数字识别
来源:互联网 发布:sqlserver分页查询 编辑:程序博客网 时间:2024/06/06 21:45
Eclipse 项目托管在 GitHub 上: https://github.com/zhangfaen/ML/tree/master/random_forest_classifier
package faen;import java.util.ArrayList;import java.util.Collections;import java.util.Comparator;import java.util.HashMap;import java.util.HashSet;import java.util.List;import java.util.Map;import java.util.Random;import java.util.Set;public class RandomForest { static class Util { public static void CHECK(boolean condition, String message) { if (!condition) { throw new RuntimeException(message); } } } private double[][] instances_; private int[] targets_; private int numOfTrees_; private int numOfFeatures_; private int maxDepth_; private TreeNode[] trees_; private static Random rand = new Random(); /** * Train the RF model * * @param instances * @param targets * @param numOfTrees * @param numOfFeatures * this could be -1, if so, the default value will be * len(features)^0.5 * @param maxDepth * this could be -1, if so, any leaf node will have only 1 * instance. */ public void train(double[][] instances, int[] targets, int numOfTrees, int numOfFeatures, int maxDepth, int treeSize) { Util.CHECK(instances.length == targets.length, ""); Util.CHECK(numOfTrees > 0, ""); this.instances_ = instances; this.targets_ = targets; this.numOfTrees_ = numOfTrees; this.numOfFeatures_ = numOfFeatures; this.maxDepth_ = maxDepth; this.trees_ = new TreeNode[numOfTrees_]; for (int i = 0; i < trees_.length; i++) { System.out.println("building the tree:" + i); trees_[i] = buildTree(getRandomInstances(treeSize), 1); } } // Get sub set of all instances randomly. List<Integer> getRandomInstances(int numOfInstances) { List<Integer> ret = new ArrayList<Integer>(numOfInstances); while (ret.size() < numOfInstances) { ret.add(rand.nextInt(instances_.length)); } return ret; } // Get the majority class of all samples having indices. 
private int getMajorClass(List<Integer> indices) { Map<Integer, Integer> mii = new HashMap<Integer, Integer>(); int best = -1; int ret = -1; for (int index : indices) { Integer v = mii.get(targets_[index]); if (v == null) { v = 0; } mii.put(targets_[index], v + 1); if (v + 1 > best) { best = v + 1; ret = targets_[index]; } } return ret; } // Are all samples having indices have the same class? private boolean haveSameClass(List<Integer> indices) { for (int i = 1; i < indices.size(); i++) { if (targets_[indices.get(i)] != targets_[indices.get(0)]) { return false; } } return true; } // Get a list of indices of features randomly. private List<Integer> getRandomFeatures() { Set<Integer> set = new HashSet<Integer>(); int featureSize = instances_[0].length; while (set.size() < numOfFeatures_) { set.add(rand.nextInt(featureSize)); } List<Integer> ret = new ArrayList<Integer>(); ret.addAll(set); return ret; } // Get the entropy of some samples. private double getEntropy(List<Integer> indices, int from, int to) { Util.CHECK(to <= indices.size(), ""); Map<Integer, Integer> mii = new HashMap<Integer, Integer>(); for (int i = from; i < to; i++) { Integer v = mii.get(targets_[indices.get(i)]); if (v == null) { v = 0; } mii.put(targets_[indices.get(i)], v + 1); } double ret = 0; for (Integer key : mii.keySet()) { int v = mii.get(key); ret += Math.log((to - from) * 1.0 / v); } return ret; } private TreeNode buildTree(List<Integer> indices, int curDepth) { System.out.println("building tree, depth:" + curDepth); if (maxDepth_ == curDepth) { return new TreeNode(-1, -1, getMajorClass(indices), null, null, true); } if (haveSameClass(indices)) { return new TreeNode(-1, -1, targets_[indices.get(0)], null, null, true); } List<Integer> featureInices = getRandomFeatures(); double bestEntropy = Double.MAX_VALUE; int bestFeatureIndex = -1; double splitValue = -1; List<Integer> leftIndices = null; List<Integer> rightIndices = null; for (final int featureIndex : featureInices) { 
Collections.sort(indices, new Comparator<Integer>() { @Override public int compare(Integer o1, Integer o2) { if (instances_[o1][featureIndex] < instances_[o2][featureIndex]) { return -1; } else if (instances_[o1][featureIndex] == instances_[o2][featureIndex]) { return o1 - o2; } else { return 1; } } }); int bestIndex = -1; for (int i = 0; i < indices.size() - 1; i++) { if (instances_[indices.get(i)][featureIndex] == instances_[indices.get(i + 1)][featureIndex]) { continue; } double entropy = 1.0 * (i + 1 - 0) / indices.size() * getEntropy(indices, 0, i + 1) + 1.0 * (indices.size() - (i + 1)) / indices.size() * getEntropy(indices, i + 1, indices.size()); if (entropy < bestEntropy) { bestEntropy = entropy; bestFeatureIndex = featureIndex; bestIndex = i; splitValue = instances_[indices.get(i)][featureIndex]; } } if (bestIndex >= 0) { leftIndices = new ArrayList<Integer>(); rightIndices = new ArrayList<Integer>(); leftIndices.addAll(indices.subList(0, bestIndex + 1)); rightIndices.addAll(indices.subList(bestIndex + 1, indices.size())); } } if (bestFeatureIndex >= 0) { return new TreeNode(bestFeatureIndex, splitValue, -1, buildTree(leftIndices, curDepth + 1), buildTree(rightIndices, curDepth + 1), false); } else { // All instances have the same features. return new TreeNode(-1, -1, getMajorClass(indices), null, null, true); } } private int predicateByOneTree(TreeNode node, double[] instance) { if (node.isLeafNode_) { return node.target_; } if (instance[node.featureIndex_] <= node.value_) { return predicateByOneTree(node.left_, instance); } else { return predicateByOneTree(node.right_, instance); } } // Predicate one instance. 
public int predicate(double[] instance) { Map<Integer, Integer> mii = new HashMap<Integer, Integer>(); int bestTarget = -1; int bestCount = -1; for (TreeNode root : trees_) { int target = predicateByOneTree(root, instance); Integer v = mii.get(target); if (v == null) { v = 0; } mii.put(target, v + 1); if (v + 1 > bestCount) { bestCount = v + 1; bestTarget = target; } } return bestTarget; } // TreeNode of the decision tree. private static class TreeNode { public int featureIndex_; public double value_; public int target_; public TreeNode left_; public TreeNode right_; public boolean isLeafNode_; public TreeNode(int featureIndex, double value_, int target_, TreeNode left_, TreeNode right_, boolean isLeafNode) { this.featureIndex_ = featureIndex; this.value_ = value_; this.target_ = target_; this.left_ = left_; this.right_ = right_; this.isLeafNode_ = isLeafNode; } } public static void printTree(TreeNode node, String indedent) { if (node.isLeafNode_) { System.out.println(indedent + "target:" + node.target_); } else { System.out.println(indedent + "feature index:" + node.featureIndex_ + ", split value:" + node.value_); printTree(node.left_, indedent + " "); printTree(node.right_, indedent + " "); } } public static void main(String[] args) { // double[][] instances = new double[][] { { 1, 1 }, { 1, -1 }, { -1, 1 }, { -1, -1 } }; int[] targets = new int[] { 1, 2, 3, 4 }; RandomForest rf = new RandomForest(); rf.train(instances, targets, 1, 2, -1, 10); printTree(rf.trees_[0], ""); }}
0 0
- Random Forest实战:Java实现 + 手写数字识别
- 手写数字识别实现
- Neural Network实战:Java实现Back Propagation算法 + 手写数字识别
- [TensorFlow实战] 构建LeNet实现手写数字识别
- TensorFlow实现识别手写数字
- cnn实现手写数字识别
- JAVA简单手写数字识别
- opencv实战之手写数字识别
- keras入门实战:手写数字识别
- TensorFlow实战—mnist手写数字识别
- TensorFlow实战(一)手写数字识别
- K近邻算法(一) python实现,手写数字识别(from机器学习实战)
- 《TensorfFlow实战》读书笔记(三) —— Tensorflow 实现 Softmax Regression 识别手写数字
- [Keras实战] 构建LeNet实现手写数字识别(mnist数据集)
- Tensorflow实战学习(二十四)【实现Softmax Regression(回归)识别手写数字】
- random forest python 实现
- random forest python 实现
- knn算法实现的数字手写识别
- android 将个人应用改为系统应用
- uva 11027(数论)
- LINGO 01-基础教程
- C语言位运算详解
- rails 总结 一种简单的验证登入的方法总结 OK
- Random Forest实战:Java实现 + 手写数字识别
- DevExpress Gridcontrol 表格头复选框 全选全不选
- ZPL指令中文参考地址
- java--练习day02
- IOS开发之在服务器端获取数据,保存网页的Demo学习
- 项目篇----为残障人群设计的体感控制系统
- ExtJs之表单(form)
- 解决sip来电时后台播放器暂时静音的效果
- GoogleMap离线开发小结