决策树算法原理及JAVA实现(ID3)

来源:互联网 发布:daxulu域名更换 编辑:程序博客网 时间:2024/05/02 01:30

0 引言 

决策树的目的在于构造一颗树像下面这样的树。

图1



图2


1. 如何构造呢?

1.1   参考资料。

      本例以图2为例,并参考了以下资料。
(1) http://www.cnblogs.com/zhangchaoyang/articles/2196631.html 

    写的东西非常经典。

(2) http://blog.sina.com.cn/s/blog_67bc5aa60100qays.html
(3)机器学习(Tom.Mitchell著) 第三章 决策树,里面详细介绍了信息增益的计算,和熵的计算。建议大家参考

1.2 数据集(训练数据集)



outlooktemperaturehumiditywindyplaysunnyhothighFALSEnosunnyhothighTRUEnoovercasthothighFALSEyesrainymildhighFALSEyesrainycoolnormalFALSEyesrainycoolnormalTRUEnoovercastcoolnormalTRUEyessunnymildhighFALSEnosunnycoolnormalFALSEyesrainymildnormalFALSEyessunnymildnormalTRUEyesovercastmildhighTRUEyesovercasthotnormalFALSEyesrainymildhighTRUEno

1.3 构造原则—选信息增益最大的

从图中知,一共有四个属性,outlook     temperature    humidity  windy,首先选哪一个作为树的第一个节点呢。答案是选信息增益越大的作为开始的节点。信息增益的计算公式如下:
Entropy(s)是熵,S样本集,Sv是子集。熵的计算公式如下:

举例:
根据以上的数据,我们只知道新的一天打球的概率是9/14,不打的概率是5/14。此时的熵为

对每项指标分别统计:在不同的取值下打球和不打球的次数。

table 2

outlooktemperaturehumiditywindyplay yesno yesno yesno yesnoyesnosunny23hot22high34FALSE6295overcast40mild42normal61TRUR33  rainy32cool31        

下面我们计算当已知变量outlook的值时,信息熵为多少。

outlook=sunny时,2/5的概率打球,3/5的概率不打球。entropy=0.971

outlook=overcast时,entropy=0

outlook=rainy时,entropy=0.971

而根据历史统计数据,outlook取值为sunny、overcast、rainy的概率分别是5/14、4/14、5/14,所以当已知变量outlook的值时,信息熵为:5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693

这样的话系统熵就从0.940下降到了0.693,信息增溢gain(outlook)为0.940-0.693=0.247

同样可以计算出gain(temperature)=0.029,gain(humidity)=0.152,gain(windy)=0.048。

gain(outlook)最大(即outlook在第一步使系统的信息熵下降得最快),所以决策树的根节点就取outlook。

1.4 为什么选信息增益最大的?

根据参考资料(2)的结论是:信息增益量越大,这个属性作为一棵树的根节点就能使这棵树更简洁(2)

1.5 递归:

接下来要确定N1取temperature、humidity还是windy?在已知outlook=sunny的情况,根据历史数据,我们作出类似table 2的一张表,分别计算gain(temperature)、gain(humidity)和gain(windy),选最大者为N1。

依此类推,构造决策树。当系统的信息熵降为0时,就没有必要再往下构造决策树了,此时叶子节点都是纯的--这是理想情况。最坏的情况下,决策树的高度为属性(决策变量)的个数,叶子节点不纯(这意味着我们要以一定的概率来作出决策)。

1.6 递归结束的条件:

如果Examples都为正,那么返回label =+ 的单结点树Root ,熵为0
 如果Examples都为反,那么返回label =- 的单结点树Root ,熵为0
 如果Attributes为空,那么返回单结点树Root,label=Examples中最普遍的

2. 伪代码


3. java 实现

此仅贴主要的代码,源码请到我的github下载:
https://github.com/Bellonor/myHadoopProject/tree/master/com.homework/src/sequence/machinelearning/decisiontree/sequence/machinelearning/decisiontree/myid3
package sequence.machinelearning.decisiontree.myid3;import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.FileWriter;import java.io.IOException;import java.util.ArrayList;import java.util.HashMap;import java.util.Iterator;import java.util.LinkedList;import java.util.List;import java.util.Map;import java.util.regex.Matcher;import java.util.regex.Pattern;import java.util.LinkedList;public class MyID3 {    private static LinkedList<String> attribute = new LinkedList<String>(); // 存储属性的名称    private static LinkedList<ArrayList<String>> attributevalue = new LinkedList<ArrayList<String>>(); // 存储每个属性的取值    private static LinkedList<String[]> data = new LinkedList<String[]>();; // 原始数据       public static final String patternString = "@attribute(.*)[{](.*?)[}]";public static String[] yesNo;public static TreeNode root;    /**     *      * @param lines 传入要分析的数据集     * @param index 哪个属性?attribute的index     */    public Double getGain(LinkedList<String[]> lines,int index){    Double gain=-1.0;    List<Double> li=new ArrayList<Double>();    //统计Yes No的次数    for(int i=0;i<yesNo.length;i++){    Double sum=0.0;    for(int j=0;j<lines.size();j++){    String[] line=lines.get(j);    //data为结构化数据,如果数据最后一列==yes,sum+1    if(line[line.length-1].equals(yesNo[i])){    sum=sum+1;    }    }    li.add(sum);    }    //计算Entropy(S)计算Entropy(S) 见参考书《机器学习 》Tom.Mitchell著  第3.4.1.2节    Double entropyS=TheMath.getEntropy(lines.size(), li);    //下面计算gain        List<String> la=attributevalue.get(index);    List<Point> lasv=new ArrayList<Point>();    for(int n=0;n<la.size();n++){    String attvalue=la.get(n);        //统计Yes No的次数    List<Double> lisub=new ArrayList<Double>();//如:sunny 是yes时发生的次数,是no发生的次数    Double Sv=0.0;//公式3.4中的Sv 见参考书《机器学习(Tom.Mitchell著)》        for(int i=0;i<yesNo.length;i++){        Double sum=0.0;        for(int j=0;j<lines.size();j++){        String[] line=lines.get(j);        //data为结构化数据,如果数据最后一列==yes,sum+1        if(line[index].equals(attvalue)&&line[line.length-1].equals(yesNo[i])){        sum=sum+1;        }        }        Sv=Sv+sum;//计算总数        lisub.add(sum);        }        //计算Entropy(S) 见参考书《机器学习(Tom.Mitchell著)》        Double entropySv=TheMath.getEntropy(Sv.intValue(), lisub);        //        Point p=new Point();        p.setSv(Sv);        p.setEntropySv(entropySv);        lasv.add(p);    }    gain=TheMath.getGain(entropyS,lines.size(),lasv);    return gain;    }    //寻找最大的信息增益,将最大的属性定为当前节点,并返回该属性所在list的位置和gain值    public Maxgain getMaxGain(LinkedList<String[]> lines){    if(lines==null||lines.size()<=0){    return null;    }    Maxgain maxgain = new Maxgain();    Double maxvalue=0.0;    int maxindex=-1;    for(int i=0;i<attribute.size();i++){    Double tmp=getGain(lines,i);    if(maxvalue< tmp){    maxvalue=tmp;    maxindex=i;    }    }    maxgain.setMaxgain(maxvalue);    maxgain.setMaxindex(maxindex);    return maxgain;    }    //剪取数组    public LinkedList<String[]>  filterLines(LinkedList<String[]> lines, String attvalue, int index){    LinkedList<String[]> newlines=new LinkedList<String[]>();    for(int i=0;i<lines.size();i++){    String[] line=lines.get(i);    if(line[index].equals(attvalue)){    newlines.add(line);    }    }        return newlines;    }    public void createDTree(){    root=new TreeNode();    Maxgain maxgain=getMaxGain(data);    if(maxgain==null){    System.out.println("没有数据集,请检查!");    }    int maxKey=maxgain.getMaxindex();    String nodename=attribute.get(maxKey);    root.setName(nodename);    root.setLiatts(attributevalue.get(maxKey));    insertNode(data,root,maxKey);    }    /**     *      * @param lines 传入的数据集,作为新的递归数据集     * @param node 深入此节点     * @param index 属性位置     */    public void insertNode(LinkedList<String[]> lines,TreeNode node,int index){    List<String> liatts=node.getLiatts();    for(int i=0;i<liatts.size();i++){    String attname=liatts.get(i);    LinkedList<String[]> newlines=filterLines(lines,attname,index);    if(newlines.size()<=0){        System.out.println("出现异常,循环结束");        return;        }    Maxgain maxgain=getMaxGain(newlines);    double gain=maxgain.getMaxgain();    Integer maxKey=maxgain.getMaxindex();    //不等于0继续递归,等于0说明是叶子节点,结束递归。    if(gain!=0){    TreeNode subnode=new TreeNode();    subnode.setParent(node);    subnode.setFatherAttribute(attname);    String nodename=attribute.get(maxKey);    subnode.setName(nodename);    subnode.setLiatts(attributevalue.get(maxKey));    node.addChild(subnode);    //不等于0,继续递归    insertNode(newlines,subnode,maxKey);    }else{    TreeNode subnode=new TreeNode();    subnode.setParent(node);    subnode.setFatherAttribute(attname);    //叶子节点是yes还是no?取新行中最后一个必是其名称,因为只有完全是yes,或完全是no的情况下才会是叶子节点    String[] line=newlines.get(0);    String nodename=line[line.length-1];    subnode.setName(nodename);    node.addChild(subnode);    }    }    }//输出决策树public void printDTree(TreeNode node){if(node.getChildren()==null){System.out.println("--"+node.getName());return;}System.out.println(node.getName());List<TreeNode> childs = node.getChildren();for (int i = 0; i < childs.size(); i++){System.out.println(childs.get(i).getFatherAttribute());printDTree(childs.get(i));}}    public static void main(String[] args) {// TODO Auto-generated method stub    MyID3 myid3 = new MyID3();    myid3.readARFF(new File("datafile/decisiontree/test/in/weather.nominal.arff"));    myid3.createDTree();    myid3.printDTree(root);}    //读取arff文件,给attribute、attributevalue、data赋值    public void readARFF(File file) {        try {            FileReader fr = new FileReader(file);            BufferedReader br = new BufferedReader(fr);            String line;            Pattern pattern = Pattern.compile(patternString);            while ((line = br.readLine()) != null) {            if (line.startsWith("@decision")) {                   line = br.readLine();                        if(line=="")                            continue;                        yesNo = line.split(",");                }            Matcher matcher = pattern.matcher(line);                if (matcher.find()) {                    attribute.add(matcher.group(1).trim());                    String[] values = matcher.group(2).split(",");                    ArrayList<String> al = new ArrayList<String>(values.length);                    for (String value : values) {                        al.add(value.trim());                    }                    attributevalue.add(al);                } else if (line.startsWith("@data")) {                    while ((line = br.readLine()) != null) {                        if(line=="")                            continue;                        String[] row = line.split(",");                        data.add(row);                    }                } else {                    continue;                }            }            br.close();        } catch (IOException e1) {            e1.printStackTrace();        }    }}


0 0
原创粉丝点击