ID3决策树(Java实现)
来源:互联网 发布:阿里妈妈采集软件 编辑:程序博客网 时间:2024/05/09 21:19
说明
参考文章-归纳决策树ID3(Java实现),完成代码编写。
在原代码的基础上补充了预测函数,实现利用模型对新数据进行分类预测。
作者对ID3决策树的介绍-ID3决策树
决策树采用xml文件保存,使用Dom4J类库,点击下载
让Dom4J支持按XPath选择节点,还得引入包jaxen.jar,点击下载
源代码汇总,点击下载
思路
代码
输入文件采用ARFF格式,使用的训练数据文件如下:
train.arff
@relation weather.symbolic @attribute outlook {sunny,overcast,rainy} @attribute temperature {hot,mild,cool} @attribute humidity {high,normal} @attribute windy {TRUE,FALSE} @attribute play {yes,no} @data sunny,hot,high,FALSE,no sunny,hot,high,TRUE,no overcast,hot,high,FALSE,yes rainy,mild,high,FALSE,yes rainy,cool,normal,FALSE,yes rainy,cool,normal,TRUE,no overcast,cool,normal,TRUE,yes sunny,mild,high,FALSE,no sunny,cool,normal,FALSE,yes rainy,mild,normal,FALSE,yes sunny,mild,normal,TRUE,yes overcast,mild,high,TRUE,yes overcast,hot,normal,FALSE,yes rainy,mild,high,TRUE,no
ARFF(Attribute-Relation File Format):格式简单明了,分为两部分,第一部分交代属性及取值范围,第二部分则是数据部分(data)。
由于只是测试代码效果,测试集(predict.arff)也是上述数据,只是将类标相关的数据移除了。
ID3类
package ID3;import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.FileWriter;import java.io.IOException;import java.lang.Character.Subset;import java.util.ArrayList;import java.util.Iterator;import java.util.LinkedList;import java.util.List;import java.util.regex.Matcher;import java.util.regex.Pattern;import org.dom4j.Document;import org.dom4j.DocumentHelper;import org.dom4j.Element;import org.dom4j.io.OutputFormat;import org.dom4j.io.XMLWriter;import org.w3c.dom.NodeList;public class ID3 { // 同时保留训练集和测试集的数据在模型中,防止训练集和测试集的列顺序不同 private ArrayList<String> trainAttribute = new ArrayList<String>(); // 存储训练集属性的名称 private ArrayList<ArrayList<String>> train_attributeValue = new ArrayList<ArrayList<String>>(); // 存储训练集每个属性的取值 private ArrayList<String> predictAttribute = new ArrayList<String>(); // 存储测试集属性的名称 private ArrayList<ArrayList<String>> predict_attributeValue = new ArrayList<ArrayList<String>>(); // 存储测试集每个属性的取值 private ArrayList<String[]> train_data = new ArrayList<String[]>(); // 训练集数据 ,即arff文件中的data字符串 private ArrayList<String[]> predict_data = new ArrayList<String[]>(); // 测试集数据 private String[] preLable; int decatt; // 决策变量在属性集中的索引(即类标所在列) public static final String patternString = "@attribute(.*)[{](.*?)[}]"; //正则表达,其中*? 表示重复任意次,但尽可能少重复,防止匹配到更后面的"}"符号 Document xmldoc; Element root; public ID3() { //创建并初始化xml文件,以用于储存决策树结构 xmldoc = DocumentHelper.createDocument(); root = xmldoc.addElement("root"); root.addElement("DecisionTree").addAttribute("value", "null"); } /** * 模型训练函数 * @param class_name 类标变量 * @param data_pathname 训练集 * @return xml决策树文件 */ public Document train(String class_name,String data_pathname){ read_trainARFF(new File(data_pathname)); setDec(class_name); LinkedList<Integer> ll=new LinkedList<Integer>(); //LinkList用于增删比ArrayList有优势 for(int i=0;i<trainAttribute.size();i++){ if(i!=decatt) ll.add(i); //防止类别变量不在最后一列发生错误 } ArrayList<Integer> al=new ArrayList<Integer>(); for(int i=0;i<train_data.size();i++){ al.add(i); } buildDT("DecisionTree", "null", al, ll); return xmldoc; } /** * 预测/分类函数(利用保留在类里的xml决策时模型进行预测) * @param data_pathname 测试集 * @return 预测结果集 */ public String[] predict(String data_pathname){ read_predictARFF(new File(data_pathname)); preLable=new String[predict_data.size()]; ArrayList<Integer> subset=new ArrayList<Integer>(); for(int i=0;i<predict_data.size();i++){ subset.add(i); } Element root=xmldoc.getRootElement(); Element DecisionTree=root.element("DecisionTree"); giveLable(DecisionTree, subset); return preLable; } /** * 用于计算分类结果的递归函数 * @param node 节点 * @param subset 子集(存储序号) */ public void giveLable(Element node, ArrayList<Integer> subset) { List<Element> list=node.elements(); if (list.size()==0) { //叶子节点 System.out.println("节点:"+node.getName()+"是叶子节点"); String lable=node.getTextTrim(); for(int index:subset ){ preLable[index]=lable; } }else{ //非叶子节点 for(Element e:list){ String name=e.getName(); String value=e.attribute("value").getValue(); int index=predictAttribute.indexOf(name); ArrayList<Integer> temp=new ArrayList<Integer>(); for(int i=0;i<subset.size();i++){ //筛选subset if(predict_data.get(subset.get(i))[index].equals(value)){ temp.add(subset.get(i)); } } giveLable(e, temp); } } } //读取arff文件,给attribute、attributevalue、data赋值 public void read_trainARFF(File file) { try { FileReader fr = new FileReader(file); BufferedReader br = new BufferedReader(fr); String line; Pattern pattern = Pattern.compile(patternString); while ((line = br.readLine()) != null) { Matcher matcher = pattern.matcher(line); if (matcher.find()) { trainAttribute.add(matcher.group(1).trim()); //获取第一个括号里的内容 //涉及取值,尽量加.trim(),后面也可以看到,即使是换行符也可能会造成字符串不相等 String[] values = matcher.group(2).split(","); ArrayList<String> al = new ArrayList<String>(values.length); for (String value : values) { al.add(value.trim()); } train_attributeValue.add(al); } else if (line.startsWith("@data")) { while ((line = br.readLine()) != null) { if(line=="") continue; String[] row = line.split(","); train_data.add(row); } } else { continue; } } br.close(); } catch (IOException e1) { e1.printStackTrace(); } } //读取arff文件,给attribute、attributevalue、data赋值 public void read_predictARFF(File file) { try { FileReader fr = new FileReader(file); BufferedReader br = new BufferedReader(fr); String line; Pattern pattern = Pattern.compile(patternString); while ((line = br.readLine()) != null) { Matcher matcher = pattern.matcher(line); if (matcher.find()) { predictAttribute.add(matcher.group(1).trim()); //获取第一个括号里的内容 //涉及取值,尽量加.trim(),后面也可以看到,即使是换行符也可能会造成字符串不相等 String[] values = matcher.group(2).split(","); ArrayList<String> al = new ArrayList<String>(values.length); for (String value : values) { al.add(value.trim()); } predict_attributeValue.add(al); } else if (line.startsWith("@data")) { while ((line = br.readLine()) != null) { if(line=="") continue; String[] row = line.split(","); predict_data.add(row); } } else { continue; } } br.close(); } catch (IOException e1) { e1.printStackTrace(); } } //设置决策变量 public void setDec(int n) { if (n < 0 || n >= trainAttribute.size()) { System.err.println("决策变量指定错误。"); System.exit(2); } decatt = n; } public void setDec(String name) { int n = trainAttribute.indexOf(name); setDec(n); } //给一个样本(数组中是各种情况的计数),计算它的熵 public double getEntropy(int[] arr) { double entropy = 0.0; int sum = 0; for (int i = 0; i < arr.length; i++) { //关于Double.MIN_VALUE好像和浮点精度有关,不是很懂 entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2); sum += arr[i]; } entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2); entropy /= sum; return entropy; } //给一个样本数组及样本的算术和,计算它的熵 public double getEntropy(int[] arr, int sum) { double entropy = 0.0; for (int i = 0; i < arr.length; i++) { entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2); } entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2); entropy /= sum; return entropy; } //判断类标是否统一,统一则之后即为叶节点(也可以设置为类别比例达到某一程度等其他指标) public boolean infoPure(ArrayList<Integer> subset) { String value = train_data.get(subset.get(0))[decatt]; for (int i = 1; i < subset.size(); i++) { String next=train_data.get(subset.get(i))[decatt]; if (!value.trim().equals(next.trim())) return false; } return true; } // 给定原始数据的子集(subset中存储行号),当以第index个属性为节点时计算它的信息熵 public double calNodeEntropy(ArrayList<Integer> subset, int index) { int sum = subset.size(); //System.out.println("sum="+sum); //System.out.println("index="+index); double entropy = 0.0; int[][] info = new int[train_attributeValue.get(index).size()][]; for (int i = 0; i < info.length; i++) info[i] = new int[train_attributeValue.get(decatt).size()]; int[] count = new int[train_attributeValue.get(index).size()]; for (int i = 0; i < sum; i++) { int n = subset.get(i); String nodevalue = train_data.get(n)[index]; int nodeind = train_attributeValue.get(index).indexOf(nodevalue); count[nodeind]++; String decvalue = train_data.get(n)[decatt]; //System.out.println(attributevalue.get(decatt).indexOf("no")); int decind = train_attributeValue.get(decatt).indexOf(decvalue.trim()); info[nodeind][decind]++; } for (int i = 0; i < info.length; i++) { entropy += getEntropy(info[i]) * count[i] / sum; } return entropy; } /** * 构建决策树 (核心函数) * @param node 节点名称 * @param value 节点值 * @param subset 数据子集 * @param selatt 属性子集 */ public void buildDT(String node, String value, ArrayList<Integer> subset, LinkedList<Integer> selatt) { Element ele = null; @SuppressWarnings("unchecked") List<Element> list = root.selectNodes("//"+node); Iterator<Element> iter=list.iterator(); while(iter.hasNext()){ ele=iter.next(); if(ele.attributeValue("value").equals(value)) break; } if (infoPure(subset)) { ele.setText(train_data.get(subset.get(0))[decatt]); //类标单一,直接写分类 return; } int minIndex = -1; double minEntropy = Double.MAX_VALUE; for (int i = 0; i < selatt.size(); i++) { if (i == decatt) continue; double entropy = calNodeEntropy(subset, selatt.get(i)); if (entropy < minEntropy) { minIndex = selatt.get(i); minEntropy = entropy; } } String nodeName= trainAttribute.get(minIndex); selatt.remove(new Integer(minIndex)); ArrayList<String> attvalues = train_attributeValue.get(minIndex); for (String val : attvalues) { //System.out.println(nodeName+"="+val); ele.addElement(nodeName).addAttribute("value", val); ArrayList<Integer> al = new ArrayList<Integer>(); for (int i = 0; i < subset.size(); i++) { if (train_data.get(subset.get(i))[minIndex].equals(val)) { al.add(subset.get(i)); } } buildDT(nodeName, val, al, selatt); } } /** * 把xml写入文件 * @param filename */ public void writeXML(String filename) { try { File file = new File(filename); if (!file.exists()) file.createNewFile(); FileWriter fw = new FileWriter(file); OutputFormat format = OutputFormat.createPrettyPrint(); // 美化格式 XMLWriter output = new XMLWriter(fw, format); output.write(xmldoc); output.close(); } catch (IOException e) { System.out.println(e.getMessage()); } } }
主函数
package ID3;public class Main { public static void main(String[] args) { ID3 inst=new ID3(); inst.train("play", "files/ID3/train.arff"); inst.writeXML("files/ID3/ID3_Tree.xml"); String[] preLable=inst.predict("files/ID3/predict.arff"); for(int i=0;i<preLable.length;i++){ System.out.println(i+preLable[i]); } }}
决策树xml文件
<?xml version="1.0" encoding="UTF-8"?><root> <DecisionTree value="null"> <outlook value="sunny"> <humidity value="high">no</humidity> <humidity value="normal">yes</humidity> </outlook> <outlook value="overcast">yes</outlook> <outlook value="rainy"> <windy value="TRUE">no</windy> <windy value="FALSE">yes</windy> </outlook> </DecisionTree></root>
阅读全文
0 0
- 决策树ID3(Java实现)
- ID3决策树(Java实现)
- 归纳决策树ID3(Java实现)
- 归纳决策树ID3(Java实现)
- 归纳决策树ID3(Java实现)
- 归纳决策树ID3(Java实现)
- 归纳决策树ID3(Java实现)
- 归纳决策树ID3(Java实现)
- 归纳决策树ID3(Java实现)
- 归纳决策树ID3(Java实现)
- ID3决策树(R实现)
- java实现决策树ID3算法(文件读取)
- 决策树归纳(ID3属性选择度量)Java实现
- java实现决策树ID3算法(文件读取)
- 决策树分类器(ID3、C4.5 Java实现)
- 决策树算法原理及JAVA实现(ID3)
- 【JAVA实现】基于决策树的ID3算法
- 决策树之ID3算法java实现
- Swift 自动引用计数(ARC)
- STM32之RTC例程
- 微软操作系统 Windows Server 2012 R2 官方原版镜像 微软操作系统 Windows Server 2012 R2 官方原版镜像 Windows Server 2012 R2
- nRF52832 — 外部中断BSP(Board Support Package)
- ItemTouchHelper的使用
- ID3决策树(Java实现)
- CentOS自建Git服务端,Android Studio 添加自建远程库
- SpringMVC文件上传与下载
- 在windows下安装PyPdf2,将文件夹中的pdf文件合成为一个pdf文件
- python pass语句的作用
- IntelliJ IDEA导入Maven构建的Web工程
- Mybatis一级、二级缓存
- 为数据库添加索引
- Neighbor