《机器学习实战》决策树的 Java 实现

来源:互联网 发布:w7网络无internet访问 编辑:程序博客网 时间:2024/06/05 22:21
package com.haolidong.Decisiontree;import java.util.Comparator;import java.util.HashMap;import java.util.Map.Entry;/** * @author haolidong * @Description:  [该类主要用于HashMap进行自定义的排序(从大到小)]   */public class ComparatorImpl implements Comparator<HashMap<String,Integer>>{@SuppressWarnings("unchecked")@Overridepublic int compare(HashMap<String, Integer> o1, HashMap<String, Integer> o2) {// TODO Auto-generated method stubEntry<String, Integer> obj1 = (Entry<String, Integer>) o1;  Entry<String, Integer> obj2 = (Entry<String, Integer>) o2;          return ((Integer) (obj2.getValue()) - (Integer) (obj1.getValue()));      }  }

package com.haolidong.Decisiontree;import java.util.ArrayList;/** *  * @author haolidong * @Description:  [该类主要用于保存特征信息] * @parameter data:  [主要保存特征矩阵] */public class Matrix {public  ArrayList<ArrayList<String>> data;public Matrix() {// TODO Auto-generated constructor stubdata = new ArrayList<ArrayList<String>>();}}

package com.haolidong.Decisiontree;import java.util.ArrayList;/** *  * @author haolidong * @Description:  [该类主要用于保存特征信息以及标签值] * @parameter labels:  [主要保存标签值] */public class CreateDataSet extends Matrix{public  ArrayList<String> labels;public CreateDataSet() {// TODO Auto-generated constructor stubsuper();labels = new ArrayList<String>();}/** * @author haolidong * @Description:  [机器学习实战决策树第一个案例的数据]  */public  void  initTest(){ArrayList<String> ab1 = new ArrayList<String>();ArrayList<String> ab2 = new ArrayList<String>();ArrayList<String> ab3 = new ArrayList<String>();ArrayList<String> ab4 = new ArrayList<String>();ArrayList<String> ab5 = new ArrayList<String>();ab1.add("1");ab1.add("1");ab1.add("yes");ab2.add("1");ab2.add("1");ab2.add("yes");ab3.add("1");ab3.add("0");ab3.add("no");ab4.add("0");ab4.add("1");ab4.add("no");ab5.add("0");ab5.add("1");ab5.add("no");data.add(ab1);data.add(ab2);data.add(ab3);data.add(ab4);data.add(ab5);labels.add("no surfacing");labels.add("flippers");}}

package com.haolidong.Decisiontree;import java.util.ArrayList;/** *  * @author haolidong * @Description:  [该类主要用于模拟Python的字典,最终保存生成树的信息] * @parameter  arrow:  [主要保存父节点指向自己的标签名字] * @parameter  name:  [主要保存当前节点的名字] * @parameter  arrDic:  [主要保存子节点的信息] */public class Dictionary {public String arrow;public String name;public ArrayList<Dictionary> arrDic;/** * @author haolidong * @Description:  [类的构造函数,分配空间,根节点只要arrow什么也不填] */public Dictionary() {// TODO Auto-generated constructor stubarrow = new String("");name = new String("");arrDic = new ArrayList<Dictionary>();}}

package com.haolidong.Decisiontree;import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.IOException;import java.util.ArrayList;import java.util.Collections;import java.util.Comparator;import java.util.HashMap;import java.util.HashSet;import java.util.List;import java.util.Map.Entry;public class Decisiontree {/** * @param args * @author haolidong * @Description:  [主函数主要对于各个实例进行测试]   */public static void main(String[] args) {testCreateTree();testGlass();}/** * @param inputTree 决策树 * @param testVec测试向量【输入各个特征值进行测试】 * @return返回最后的标签值 * @author         haolidong * @Description:    [主函数主要对于各个实例进行测试] */public static String classify(Dictionary inputTree,ArrayList<String> testVec){String result = new String();if(testVec.size()==0){result=inputTree.name;}else{for (int i = 0; i < inputTree.arrDic.size(); i++) {/*未来防止迭代没有结束,然后已经有返回值,这个时候后面的就不用继续进行了,testVec=0表示的是已经到达了叶子节点*/if(testVec.size()!=0){if(testVec.get(0).equals(inputTree.arrDic.get(i).arrow)){testVec.remove(testVec.get(0));result=classify(inputTree.arrDic.get(i),testVec);}}}}return result;}/** * @param dataSet   数据集 * @param labels    分类的标签值 * @return          返回最终的决策树 * @author         haolidong * @Description:    [生成决策树,当遇到标签值全部使用完,但是还是不能够把类完全分开,返回出现最多的标签值; *                  当到达子节点的时候,也要跳出函数,这个分别是前两个if判断,每一次都选择信息增益最大的, *                  然后递归进行划分,每一次递归都要去掉一个标签,一遍递归的终结  。 ] */public static Dictionary createTree(Matrix dataSet,ArrayList<String> labels){ArrayList<String> classList = new ArrayList<String>();HashSet<String> setList = new HashSet<String>();String temps=new String("");for (int i = 0; i < dataSet.data.size(); i++) {temps = dataSet.data.get(i).get(dataSet.data.get(i).size()-1);classList.add(temps);setList.add(temps);}if(setList.size()==1){Dictionary dtemp = new Dictionary();dtemp.name = classList.get(0);return dtemp;}if(dataSet.data.get(0).size()==1){Dictionary stemp = new Dictionary();stemp.arrow = classList.get(0);return stemp;}int bestFeat = 
chooseBestFeatureToSplit(dataSet);String bestFeatLabel = labels.get(bestFeat);Dictionary myTree = new Dictionary();myTree.name=bestFeatLabel;labels.remove(bestFeat);ArrayList<String> featValues = new ArrayList<String>();HashSet<String> uniqueVals = new HashSet<String>();for (int i = 0; i < dataSet.data.size(); i++) {featValues.add(dataSet.data.get(i).get(bestFeat));uniqueVals.add(dataSet.data.get(i).get(bestFeat));}for (String value : uniqueVals) {ArrayList<String> subLabels = new ArrayList<String>();for (int j = 0; j < labels.size(); j++) {subLabels.add(labels.get(j));}Dictionary tempTree = new Dictionary();tempTree = createTree(splitDataSet(dataSet, bestFeat, value),subLabels);tempTree.arrow = value;myTree.arrDic.add(tempTree);}return myTree;}/** * @param d * @author         haolidong * @Description:    [对于非叶子节输出他们自己的信息,然后判断字节点,子节点则直接输出] */                  public static void displayDic(Dictionary d){if(d.arrDic.size()!=0){System.out.print("{"+d.name);if(d.arrDic.size()==0){System.out.print("}");}else{System.out.print(":");for (int i = 0; i < d.arrDic.size(); i++) {if(i==0)System.out.print("{");System.out.print(d.arrDic.get(i).arrow+":");displayDic(d.arrDic.get(i));if(i!=d.arrDic.size()-1){System.out.print(",");}}System.out.print("}");System.out.print("}");}}else {System.out.print(d.name);}}/** * @param classList * @return 返回当前出现次数最多的标签值 * @author         haolidong * @Description:    [当且仅当标签全部用完时还没有把类别完全分离才使用的] */public static Dictionary majorityCnt(ArrayList<String> classList){HashMap<String,Integer> classCount = new HashMap<String,Integer>();String vote;for (int i = 0; i < classList.size(); i++) {vote = classList.get(i);if(classCount.containsKey(vote)==true){classCount.put(vote, classCount.get(vote)+1);}else{classCount.put(vote, 1);}}ArrayList<HashMap.Entry<String,Integer>> entries= sortMap(classCount);Dictionary dtemp = new Dictionary();dtemp.name = entries.get(0).getKey();;return dtemp;}/** * @param map       输入值是hashmap * @return          返回排好序的map * @author 
        haolidong * @Description:    [对map的排序,这里是从大到小] */public static ArrayList<HashMap.Entry<String,Integer>> sortMap(HashMap<String,Integer> map){       List<HashMap.Entry<String, Integer>> entries = new ArrayList<HashMap.Entry<String, Integer>>(map.entrySet());       Collections.sort(entries, new Comparator<HashMap.Entry<String, Integer>>() {       public int compare(HashMap.Entry<String, Integer> obj1 , HashMap.Entry<String, Integer> obj2) {               return obj2.getValue() - obj1.getValue();           }       });        return (ArrayList<Entry<String, Integer>>) entries;      }    /** * @param DataSet   特征矩阵 * @return          返回需要切分的特征向量的下标 * @author         haolidong * @Description:    [根据信息增益,选择最好的切分] */public static int chooseBestFeatureToSplit(Matrix DataSet){int numFeatures = DataSet.data.get(0).size()-1;double baseEntropy = calcShannonEnt(DataSet);double bestInfoGain = 0.0;int bestFeature=-1;HashSet<String> uniqueVals = new HashSet<String>();for (int i = 0; i < numFeatures; i++) {uniqueVals.clear();for (int j = 0; j < DataSet.data.size(); j++) {uniqueVals.add(DataSet.data.get(j).get(i));}double newEntropy = 0.0;double prob = 0.0;for(String value:uniqueVals){Matrix subDataSet = new Matrix();subDataSet = splitDataSet(DataSet, i, value);prob = 1.0*subDataSet.data.size()/DataSet.data.size();newEntropy = newEntropy + prob * calcShannonEnt(subDataSet);}double infoGain = baseEntropy - newEntropy;if(infoGain > bestInfoGain){bestInfoGain = infoGain;bestFeature = i;}}return bestFeature;}/** * @param DataSet  数据集 * @author haolidong * @Description:  [求香农熵:H=[求和]-p(x)log2 p(x)]  * @return 最后的香农熵 */public static double calcShannonEnt(Matrix DataSet){int numEntries = DataSet.data.size();HashMap<String,Integer> classCount = new HashMap<String,Integer>();String currentLabel;for (int i = 0; i < numEntries; i++) {currentLabel = DataSet.data.get(i).get(DataSet.data.get(i).size()-1);if(classCount.containsKey(currentLabel)==true){classCount.put(currentLabel, 
classCount.get(currentLabel)+1);}else{classCount.put(currentLabel, 1);}}double shannonEnt = 0.0;double prob = 0.0;for(HashMap.Entry<String,Integer> entry:classCount.entrySet()){prob = 1.0*entry.getValue()/numEntries;shannonEnt =shannonEnt -prob *Math.log(prob)/Math.log(2);}return shannonEnt;}/** * @param dataSet   输入数据集 * @param axis      输入删除的列下标 * @param value     把低axis列下标为value的值删除以后,把这一行放入ArrayList * @return          返回符合第axis列的特征向量为value的矩阵【删除了axis列】 * @author         haolidong * @Description:    [返回符合第axis列的特征向量为value的矩阵【删除了axis列] */public static Matrix splitDataSet(Matrix dataSet, int axis, String value){Matrix retDataSet = new Matrix();for (int i = 0; i < dataSet.data.size(); i++) {if(dataSet.data.get(i).get(axis).equals(value)){ArrayList<String> as = new ArrayList<String>();for (int j = 0; j < dataSet.data.get(i).size(); j++) {if(j!=axis){as.add(dataSet.data.get(i).get(j));}}retDataSet.data.add(as);}}return retDataSet;}/** * @return          返回数据集 * @author         haolidong * @Description:    [对香农熵的测试] */public static CreateDataSet testShannon(){CreateDataSet DataSet = new CreateDataSet();DataSet.initTest();System.out.println(calcShannonEnt(DataSet));return DataSet;}/** * @author         haolidong * @Description:    [对分割数据集的测试] */public static void testSplitDataSet() {CreateDataSet DataSet = new CreateDataSet();Matrix m =new Matrix();DataSet.initTest();m=splitDataSet(DataSet,0,"1");System.out.println(m);}/** * @author         haolidong * @Description:    [对最佳分割数据集的测试] */public static void testChooseBestFeatureToSplit() {CreateDataSet DataSet = new CreateDataSet();DataSet.initTest();System.out.println(chooseBestFeatureToSplit(DataSet));}/** * @author         haolidong * @Description:    [对于当标签全部用完时还没有把类别完全分离的函数进行测试] */public static void testmajortityCnt() {CreateDataSet DataSet = new CreateDataSet();DataSet.initTest();ArrayList<String> as = new ArrayList<String>();for (int i = 0; i < DataSet.data.size(); i++) {as.add(new 
String(DataSet.data.get(i).get(DataSet.data.get(i).size()-1)));}majorityCnt(as);}/** * @author         haolidong * @Description:    [对决策树显示结果的测试] */public static void testDisplayDir() {Dictionary d1 = new Dictionary();Dictionary d2 = new Dictionary();Dictionary d3 = new Dictionary();Dictionary d4 = new Dictionary();Dictionary d5 = new Dictionary();//Dictionary d6 = new Dictionary();//d6.name="hld";//d6.arrow="2";d1.arrow="0";d1.name="no";d2.arrow="1";d2.name="yes";d3.arrow="1";d3.name="flippers";d3.arrDic.add(d1);d3.arrDic.add(d2);//d4.arrDic.add(d6);d4.name="no";d4.arrow="0";//rootd5.name="no surfacing";d5.arrDic.add(d4);d5.arrDic.add(d3);displayDic(d5);}/** * @author         haolidong * @Description:    [验证决策树的分类效果] */public static void testClassify() {CreateDataSet DataSet = new CreateDataSet();ArrayList<String> testVec = new ArrayList<String>();DataSet.initTest();Dictionary myTree = new Dictionary();myTree=createTree(DataSet,DataSet.labels);testVec.add("1");testVec.add("0");//displayDic(myTree);System.out.println(classify(myTree,testVec));}/** * @author         haolidong * @Description:    [对书上最后一个例子的测试【对于隐形眼镜的测试】] */public static void testGlass(){String fileName = "I:\\machinelearninginaction\\Ch03\\lenses.txt";File file = new File(fileName);CreateDataSet DataSet = new CreateDataSet();        BufferedReader reader = null;        try {            reader = new BufferedReader(new FileReader(file));            String tempString = null;            // 一次读入一行,直到读入null为文件结束            while ((tempString = reader.readLine()) != null) {                // 显示行号                String[] strArr = tempString.split("\t");                ArrayList<String> as = new ArrayList<String>();                for (int i = 0; i < strArr.length; i++) {as.add(strArr[i]);}                DataSet.data.add(as);            }            reader.close();        } catch (IOException e) {            e.printStackTrace();        } finally {            if (reader != null) {                try {         
           reader.close();                } catch (IOException e1) {                }            }        }        DataSet.labels.add(new String("age"));        DataSet.labels.add(new String("prescript"));        DataSet.labels.add(new String("astigmatic"));        DataSet.labels.add(new String("tearRate"));        Dictionary myTree = new Dictionary();        myTree=createTree(DataSet,DataSet.labels);displayDic(myTree);        }/** * @author         haolidong * @Description:    [对建树的测试] */public static void testCreateTree() {CreateDataSet DataSet = new CreateDataSet();DataSet.initTest();Dictionary myTree = new Dictionary();myTree=createTree(DataSet,DataSet.labels);displayDic(myTree);}}

0 0