归纳决策树ID3(Java实现)

来源:互联网 发布:改图软件 编辑:程序博客网 时间:2024/05/20 17:10

先上问题吧,我们统计了14天的气象数据(指标包括outlook,temperature,humidity,windy),并已知这些天气是否打球(play)。如果给出新一天的气象指标数据:sunny,cool,high,TRUE,判断一下会不会去打球。

table 1

outlooktemperaturehumiditywindyplaysunnyhothighFALSEnosunnyhothighTRUEnoovercasthothighFALSEyesrainymildhighFALSEyesrainycoolnormalFALSEyesrainycoolnormalTRUEnoovercastcoolnormalTRUEyessunnymildhighFALSEnosunnycoolnormalFALSEyesrainymildnormalFALSEyessunnymildnormalTRUEyesovercastmildhighTRUEyesovercasthotnormalFALSEyesrainymildhighTRUEno

这个问题当然可以用朴素贝叶斯法求解,分别计算在给定天气条件下打球和不打球的概率,选概率大者作为推测结果。

现在我们使用ID3归纳决策树的方法来求解该问题。

预备知识:信息熵

熵是无序性(或不确定性)的度量指标。假如事件A的全概率划分是(A1,A2,...,An),每部分发生的概率是(p1,p2,...,pn),那信息熵定义为:

通常以2为底数,所以信息熵的单位是bit。

补充两个对数去处公式:

ID3算法

构造树的基本想法是随着树深度的增加,节点的熵迅速地降低。熵降低的速度越快越好,这样我们有望得到一棵高度最矮的决策树。

在没有给定任何天气信息时,根据历史数据,我们只知道新的一天打球的概率是9/14,不打的概率是5/14。此时的熵为:

属性有4个:outlook,temperature,humidity,windy。我们首先要决定哪个属性作树的根节点。

对每项指标分别统计:在不同的取值下打球和不打球的次数。

table 2

outlooktemperaturehumiditywindyplay yesno yesno yesno yesnoyesnosunny23hot22high34FALSE6295overcast40mild42normal61TRUR33  rainy32cool31        

下面我们计算当已知变量outlook的值时,信息熵为多少。

outlook=sunny时,2/5的概率打球,3/5的概率不打球。entropy=0.971

outlook=overcast时,entropy=0

outlook=rainy时,entropy=0.971

而根据历史统计数据,outlook取值为sunny、overcast、rainy的概率分别是5/14、4/14、5/14,所以当已知变量outlook的值时,信息熵为:5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693

这样的话系统熵就从0.940下降到了0.693,信息增溢gain(outlook)为0.940-0.693=0.247

同样可以计算出gain(temperature)=0.029,gain(humidity)=0.152,gain(windy)=0.048。

gain(outlook)最大(即outlook在第一步使系统的信息熵下降得最快),所以决策树的根节点就取outlook。

接下来要确定N1取temperature、humidity还是windy?在已知outlook=sunny的情况,根据历史数据,我们作出类似table 2的一张表,分别计算gain(temperature)、gain(humidity)和gain(windy),选最大者为N1。

依此类推,构造决策树。当系统的信息熵降为0时,就没有必要再往下构造决策树了,此时叶子节点都是纯的--这是理想情况。最坏的情况下,决策树的高度为属性(决策变量)的个数,叶子节点不纯(这意味着我们要以一定的概率来作出决策)。

Java实现

最终的决策树保存在了XML中,使用了Dom4J,注意如果要让Dom4J支持按XPath选择节点,还得引入包jaxen.jar。程序代码要求输入文件满足ARFF格式,并且属性都是标称变量。

实验用的数据文件:

@relation weather.symbolic@attribute outlook {sunny, overcast, rainy}@attribute temperature {hot, mild, cool}@attribute humidity {high, normal}@attribute windy {TRUE, FALSE}@attribute play {yes, no}@datasunny,hot,high,FALSE,nosunny,hot,high,TRUE,noovercast,hot,high,FALSE,yesrainy,mild,high,FALSE,yesrainy,cool,normal,FALSE,yesrainy,cool,normal,TRUE,noovercast,cool,normal,TRUE,yessunny,mild,high,FALSE,nosunny,cool,normal,FALSE,yesrainy,mild,normal,FALSE,yessunny,mild,normal,TRUE,yesovercast,mild,high,TRUE,yesovercast,hot,normal,FALSE,yesrainy,mild,high,TRUE,no

程序代码:

package schoolarship;import java.io.BufferedReader;import java.io.File;import java.io.FileInputStream;import java.io.FileWriter;import java.io.IOException;import java.io.InputStreamReader;import java.io.UnsupportedEncodingException;import java.util.ArrayList;import java.util.HashMap;import java.util.LinkedList;import java.util.Map;import java.util.Map.Entry;import java.util.Set;import java.util.regex.Matcher;import java.util.regex.Pattern;import org.dom4j.Document;import org.dom4j.DocumentHelper;import org.dom4j.Element;import org.dom4j.io.OutputFormat;import org.dom4j.io.XMLWriter;public class ID3 {//存储属性的名称,这里判别变量和决策变量一律称为“属性”private ArrayList<String> attribute = new ArrayList<String>();//存储每个属性(都是离散变量)的取值集合private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>();//存储所有的训练数据,这是一个二维数组private ArrayList<String[]> data = new ArrayList<String[]>();//决策变量的属性列表中的索引号int decatt;//用于匹配ARFF文件中的@attribute行public static final String patternString = "@attribute(.*)[{](.*?)[}]";//使用Dom4j读写XML文件Document xmldoc;Element root;//构造函数中初始化Dom元素public ID3() {xmldoc = DocumentHelper.createDocument();root = xmldoc.addElement("root");root.addElement("DecisionTree").addAttribute("value", "null");}public static void main(String[] args) {ID3 inst = new ID3();//读入训练文件inst.readARFF(new File("d:\\weather.arff"));//设置决策变量的名称inst.setDec("play");//将所有属性(决策变量除外)的索引号存入llLinkedList<Integer> ll = new LinkedList<Integer>();for (int i = 0; i < inst.attribute.size(); i++) {if (i != inst.decatt)ll.add(i);}//将全部训练数据的序号存入alArrayList<Integer> al = new ArrayList<Integer>();for (int i = 0; i < inst.data.size(); i++) {al.add(i);}//递归构建决策树inst.buildDT(inst.root, al, ll);//将决策树写入XML文件inst.writeXML("d:\\dt.xml");}//读取输入文件,为全局变量attribute、attributevalue和data赋值public void readARFF(File file) {try {FileInputStream fis = new FileInputStream(file);InputStreamReader isr = new InputStreamReader(fis,initBookEncode(fis));BufferedReader br = new BufferedReader(isr);String line;Pattern pattern = Pattern.compile(patternString);while ((line = br.readLine()) != null) {Matcher matcher = pattern.matcher(line);if (matcher.find()) {attribute.add(matcher.group(1).trim());String[] values = matcher.group(2).split(",");ArrayList<String> al = new ArrayList<String>(values.length);for (String value : values) {al.add(value.trim());}attributevalue.add(al);} else if (line.startsWith("@data")) {while ((line = br.readLine()) != null) {if (line.equals(""))continue;String[] row = line.split(",");data.add(row);}} else {continue;}}br.close();} catch (IOException e1) {e1.printStackTrace();}}//将参数n赋给全局变量decattpublic void setDec(int n) {if (n < 0 || n >= attribute.size()) {System.err.println("给定的决策变量名称有误");System.exit(2);}decatt = n;}//根据属性的名称设置全局变量decattpublic void setDec(String name) {int n = attribute.indexOf(name);setDec(n);}//计算信息熵。arr中存储各种情况的频数public double getEntropy(int[] arr) {int sum = 0;for (int i = 0; i < arr.length; i++) {sum += arr[i];}return getEntropy(arr, sum);}//计算信息熵。arr中存储各种情况的频数,sum给出频数的总和public double getEntropy(int[] arr, int sum) {if (sum == 0)return 0;double entropy = 0.0;for (int i = 0; i < arr.length; i++) {//加上Double.MIN_VALUE是为了防止出现log(0)的情况entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE)/ Math.log(2);}entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);entropy /= sum;//由于上面加了Double.MIN_VALUE,所以算出来的熵可能会略大于1if (entropy > 1 && entropy - 1 < 0.00001)entropy = 1;return entropy;}//subset给写训练数据的一个子集(subset中存储的是每条数据的索引号),判断这些子集的决策变量值是否都相同public boolean infoPure(ArrayList<Integer> subset) {String value = data.get(subset.get(0))[decatt];for (int i = 1; i < subset.size(); i++) {String next = data.get(subset.get(i))[decatt];if (!value.equals(next))return false;}return true;}/** * 计算节点的信息熵 * @param subset 节点上所包含的数据子集 * @param index 节点以第index个属性作为判断的依据 * @return 节点的信息熵 */public double calNodeEntropy(ArrayList<Integer> subset, int index) {int sum = subset.size();double entropy = 0.0;int[][] info = new int[attributevalue.get(index).size()][];for (int i = 0; i < info.length; i++)info[i] = new int[attributevalue.get(decatt).size()];int[] count = new int[attributevalue.get(index).size()];for (int i = 0; i < sum; i++) {int n = subset.get(i);String nodevalue = data.get(n)[index];int nodeind = attributevalue.get(index).indexOf(nodevalue);count[nodeind]++;String decvalue = data.get(n)[decatt];int decind = attributevalue.get(decatt).indexOf(decvalue);info[nodeind][decind]++;}for (int i = 0; i < info.length; i++) {entropy += getEntropy(info[i]) * count[i] / sum;}return entropy;}// 递归构建决策树public void buildDT(Element ele, ArrayList<Integer> subset,LinkedList<Integer> selatt) {//指定name和value的节点不包含数据子集时,递归可以终止。同时要删除该节点if (subset.size() == 0){ele.getParent().remove(ele);return;}//selatt.size() == 0说明树已经达到最大的深度,即所有判别属性都已经用完了。//这个时候递归还没有终止说明训练数据中存在判别属性值完全相同,决策属性值却不相同的情况,取决策属性值最多的情况为最终结果if(selatt.size() == 0){Map<String,Integer> map=new HashMap<String,Integer>();for(int i:subset){String key=data.get(i)[decatt];Integer v=map.get(key);if(v!=null)map.put(key, v+1);elsemap.put(key, 1);}String decision="should not appear";int maxCount=-1;Set<Entry<String,Integer>> set=map.entrySet();for(Entry<String,Integer> entry:set){if(entry.getValue()>maxCount){maxCount=entry.getValue();decision=entry.getKey();}}ele.setText(decision);return; }//如果节点是纯的,那么就到达叶子节点了,给出决策,不需要继续递归了if (infoPure(subset)) {ele.setText(data.get(subset.get(0))[decatt]);return;}//选择下一个用于判别的属性。应该选熵最小的,因为这样信息增益最大int minIndex = -1;double minEntropy = Double.MAX_VALUE;for (int i = 0; i < selatt.size(); i++) {if (i == decatt)continue;double entropy = calNodeEntropy(subset, selatt.get(i));if (entropy < minEntropy) {minIndex = selatt.get(i);minEntropy = entropy;}}String nodeName = attribute.get(minIndex);//每次递归时selatt都会少一个元素,即去除刚刚选择的判别属性selatt.remove(new Integer(minIndex));//刚刚选择的属性有多少种取值,该节点就有多少个分枝。遍历这些分枝,递归完善子树。ArrayList<String> attvalues = attributevalue.get(minIndex);for (String val : attvalues) {Element child=ele.addElement(nodeName).addAttribute("value", val);ArrayList<Integer> al = new ArrayList<Integer>();for (int i = 0; i < subset.size(); i++) {if (data.get(subset.get(i))[minIndex].equals(val)) {al.add(subset.get(i));}}//注意bBuildDT()里面selatt会被改变,所以每次传递这个参数的时候要进行深复制buildDT(child, al, new LinkedList<Integer>(selatt));}}//将Dom写入XML文件public void writeXML(String filename) {try {File file = new File(filename);if (!file.exists())file.createNewFile();FileWriter fw = new FileWriter(file);OutputFormat format = OutputFormat.createPrettyPrint();XMLWriter output = new XMLWriter(fw, format);output.write(xmldoc);output.close();} catch (IOException e) {System.out.println(e.getMessage());}}/*正面这两个函数用于正确读取中文文件*/String changeToGBK(String ss, String code) {String temp = null;try {temp = new String(ss.getBytes(), code);} catch (UnsupportedEncodingException e) {e.printStackTrace();}return temp;}public String initBookEncode(FileInputStream fileInputStream) {String encode = "gb2312";try {byte[] head = new byte[3];fileInputStream.read(head);if (head[0] == -17 && head[1] == -69 && head[2] == -65)encode = "UTF-8";else if (head[0] == -1 && head[1] == -2)encode = "UTF-16";else if (head[0] == -2 && head[1] == -1)encode = "Unicode";} catch (IOException e) {System.out.println(e.getMessage());}return encode;}}


最终生成的文件如下:

<?xml version="1.0" encoding="UTF-8"?><root>  <DecisionTree value="null">    <outlook value="sunny">      <humidity value="high">no</humidity>      <humidity value="normal">yes</humidity>    </outlook>    <outlook value="overcast">yes</outlook>    <outlook value="rainy">      <windy value="TRUE">no</windy>      <windy value="FALSE">yes</windy>    </outlook>  </DecisionTree></root>

用图形象地表示就是: