决策树——ID3算法的Java实现

来源:互联网 发布:许冠杰光荣引退 知乎 编辑:程序博客网 时间:2024/06/05 03:19

所谓决策树,就是用树来帮助我们做决策:从树的根节点开始一级一级地访问节点,直到到达叶子节点,也就完成了决策的过程。

决策树算法描述的是如何用已知样本构建决策树。这里用比较经典的“天气—是否出去玩”的例子来说明:

描述天气的指标有多个(天气状况、温度、湿度、是否有风),想得到的决策结论是能否出去玩(yes 或 no)。

下面是一组已知样本,存放在 weather.nominal.arff 文件中:

@relation weather.symbolic

@attribute outlook {sunny, overcast, rainy}
@attribute temperature {hot, mild, cool}
@attribute humidity {high, normal}
@attribute windy {TRUE, FALSE}
@attribute play {yes, no}

@data
sunny,hot,high,FALSE,no
sunny,hot,high,TRUE,no
overcast,hot,high,FALSE,yes
rainy,mild,high,FALSE,yes
rainy,cool,normal,FALSE,yes
rainy,cool,normal,TRUE,no
overcast,cool,normal,TRUE,yes
sunny,mild,high,FALSE,no
sunny,cool,normal,FALSE,yes
rainy,mild,normal,FALSE,yes
sunny,mild,normal,TRUE,yes
overcast,mild,high,TRUE,yes
overcast,hot,normal,FALSE,yes
rainy,mild,high,TRUE,no
ARFF 文件格式应用较广,相关介绍很多。它主要由属性声明和数据两部分组成:

以 @attribute 开头的行声明一个属性,包括属性名和该属性的可选取值;

@data 之后的每一行是一条样本数据,各属性值之间用逗号分隔。


下面是ID3算法的一个Java实现(使用 dom4j 生成 XML):

import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.FileWriter;import java.io.IOException;import java.util.ArrayList;import java.util.Iterator;import java.util.LinkedList;import java.util.List;import java.util.regex.Matcher;import java.util.regex.Pattern;import org.dom4j.Document;import org.dom4j.DocumentHelper;import org.dom4j.Element;import org.dom4j.io.OutputFormat;import org.dom4j.io.XMLWriter;public class ID3 {private ArrayList<String> attribute = new ArrayList<String>(); // 存储属性的名称private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); // 存储每个属性的取值private ArrayList<String[]> data = new ArrayList<String[]>();; // 原始数据int decatt; // 决策变量在属性集中的索引public static final String patternString = "@attribute(.*)[{](.*?)[}]";Document xmldoc;Element root;public ID3() {xmldoc = DocumentHelper.createDocument();root = xmldoc.addElement("root");root.addElement("DecisionTree").addAttribute("value", "null");}// 读取arff文件,给attribute、attributevalue、data赋值public void readARFF(File file) {try {FileReader fr = new FileReader(file);BufferedReader br = new BufferedReader(fr);String line;Pattern pattern = Pattern.compile(patternString);while ((line = br.readLine()) != null) {Matcher matcher = pattern.matcher(line);if (matcher.find()) { // 读@attributeattribute.add(matcher.group(1).trim());String[] values = matcher.group(2).split(",");ArrayList<String> al = new ArrayList<String>(values.length);for (String value : values) {al.add(value.trim());}attributevalue.add(al);} else if (line.startsWith("@data")) { // 读@datawhile ((line = br.readLine()) != null) {if (line == "")continue;String[] row = line.split(",");data.add(row);}} else {continue;}}br.close();} catch (IOException e1) {e1.printStackTrace();}}// 设置决策变量public void setDec(int n) {if (n < 0 || n >= attribute.size()) {System.err.println("决策变量指定错误。");System.exit(2);}decatt = n;}public void setDec(String name) {int n = attribute.indexOf(name);setDec(n);}// 
给一个样本(数组中是各种情况的计数),计算它的熵public double getEntropy(int[] arr) {double entropy = 0.0;int sum = 0;for (int i = 0; i < arr.length; i++) {entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE)/ Math.log(2);sum += arr[i];}entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);entropy /= sum;return entropy;}// 给一个样本数组及样本的算术和,计算它的熵public double getEntropy(int[] arr, int sum) {double entropy = 0.0;for (int i = 0; i < arr.length; i++) {entropy -= arr[i] * Math.log(arr[i] + Double.MIN_VALUE)/ Math.log(2);}entropy += sum * Math.log(sum + Double.MIN_VALUE) / Math.log(2);entropy /= sum;return entropy;}public boolean infoPure(ArrayList<Integer> subset) {String value = data.get(subset.get(0))[decatt];for (int i = 1; i < subset.size(); i++) {String next = data.get(subset.get(i))[decatt];// equals表示对象内容相同,==表示两个对象指向的是同一片内存if (!value.equals(next))return false;}return true;}// 给定原始数据的子集(subset中存储行号),当以第index个属性为节点时计算它的信息熵public double calNodeEntropy(ArrayList<Integer> subset, int index) {int sum = subset.size();double entropy = 0.0;int[][] info = new int[attributevalue.get(index).size()][];for (int i = 0; i < info.length; i++)info[i] = new int[attributevalue.get(decatt).size()];int[] count = new int[attributevalue.get(index).size()];for (int i = 0; i < sum; i++) {int n = subset.get(i);String nodevalue = data.get(n)[index];int nodeind = attributevalue.get(index).indexOf(nodevalue);count[nodeind]++;String decvalue = data.get(n)[decatt];int decind = attributevalue.get(decatt).indexOf(decvalue);info[nodeind][decind]++;}for (int i = 0; i < info.length; i++) {entropy += getEntropy(info[i]) * count[i] / sum;}return entropy;}// 构建决策树public void buildDT(String name, String value, ArrayList<Integer> subset,LinkedList<Integer> selatt) {Element ele = null;@SuppressWarnings("unchecked")List<Element> list = root.selectNodes("//" + name);Iterator<Element> iter = list.iterator();while (iter.hasNext()) {ele = iter.next();if (ele.attributeValue("value").equals(value))break;}if 
(infoPure(subset)) {ele.setText(data.get(subset.get(0))[decatt]);return;}int minIndex = -1;double minEntropy = Double.MAX_VALUE;for (int i = 0; i < selatt.size(); i++) {if (i == decatt)continue;double entropy = calNodeEntropy(subset, selatt.get(i));if (entropy < minEntropy) {minIndex = selatt.get(i);minEntropy = entropy;}}String nodeName = attribute.get(minIndex);selatt.remove(new Integer(minIndex));ArrayList<String> attvalues = attributevalue.get(minIndex);for (String val : attvalues) {ele.addElement(nodeName).addAttribute("value", val);ArrayList<Integer> al = new ArrayList<Integer>();for (int i = 0; i < subset.size(); i++) {if (data.get(subset.get(i))[minIndex].equals(val)) {al.add(subset.get(i));}}buildDT(nodeName, val, al, selatt);}}// 把xml写入文件public void writeXML(String filename) {try {File file = new File(filename);if (!file.exists())file.createNewFile();FileWriter fw = new FileWriter(file);OutputFormat format = OutputFormat.createPrettyPrint(); // 美化格式XMLWriter output = new XMLWriter(fw, format);output.write(xmldoc);output.close();} catch (IOException e) {System.out.println(e.getMessage());}}public static void main(String[] args) {ID3 inst = new ID3();inst.readARFF(new File(System.getProperty("user.dir") + "\\resource\\weather.nominal.arff"));inst.setDec("play");LinkedList<Integer> ll = new LinkedList<Integer>();for (int i = 0; i < inst.attribute.size(); i++) {if (i != inst.decatt)ll.add(i);}ArrayList<Integer> al = new ArrayList<Integer>();for (int i = 0; i < inst.data.size(); i++) {al.add(i);}inst.buildDT("DecisionTree", "null", al, ll);inst.writeXML(System.getProperty("user.dir") + "\\resource\\dt.xml");return;}}

程序读取样本文件 weather.nominal.arff,然后调用 buildDT 构建决策树,并把决策树以 XML 形式输出到 dt.xml 文件中。

结果如下:

<?xml version="1.0" encoding="UTF-8"?>
<root>
  <DecisionTree value="null">
    <outlook value="sunny">
      <humidity value="high">no</humidity>
      <humidity value="normal">yes</humidity>
    </outlook>
    <outlook value="overcast">yes</outlook>
    <outlook value="rainy">
      <windy value="TRUE">no</windy>
      <windy value="FALSE">yes</windy>
    </outlook>
  </DecisionTree>
</root>


0 0
原创粉丝点击