决策分类树算法之ID3,C4.5算法系列
来源:互联网 发布:手机淘宝还用装旺信吗 编辑:程序博客网 时间:2024/04/28 22:27
一、引言
在最开始的时候,我本来准备学习的是C4.5算法,后来发现C4.5算法的核心还是ID3算法,所以又辗转回到学习ID3算法了,因为C4.5是他的一个改进。至于是什么改进,在后面的描述中我会提到。
二、ID3算法
ID3算法是一种分类决策树算法。他通过一系列的规则,将数据最后分类成决策树的形式。分类的根据是用到了熵这个概念。熵在物理这门学科中就已经出现过,表示是一个物质的稳定度,在这里就是分类的纯度的一个概念。公式为:
在ID3算法中,是采用Gain信息增益来作为一个分类的判定标准的。他的定义为:
每次选择属性中信息增益最大作为划分属性,在这里本人实现了一个java版本的ID3算法,为了模拟数据的可操作性,就把数据写到一个input.txt文件中,作为数据源,格式如下:
- Day OutLook Temperature Humidity Wind PlayTennis
- 1 Sunny Hot High Weak No
- 2 Sunny Hot High Strong No
- 3 Overcast Hot High Weak Yes
- 4 Rainy Mild High Weak Yes
- 5 Rainy Cool Normal Weak Yes
- 6 Rainy Cool Normal Strong No
- 7 Overcast Cool Normal Strong Yes
- 8 Sunny Mild High Weak No
- 9 Sunny Cool Normal Weak Yes
- 10 Rainy Mild Normal Weak Yes
- 11 Sunny Mild Normal Strong Yes
- 12 Overcast Mild High Strong Yes
- 13 Overcast Hot Normal Weak Yes
- 14 Rainy Mild High Strong No
- package DataMing_ID3;
- import java.io.BufferedReader;
- import java.io.File;
- import java.io.FileReader;
- import java.io.IOException;
- import java.util.ArrayList;
- import java.util.HashMap;
- import java.util.Iterator;
- import java.util.Map;
- import java.util.Map.Entry;
- import java.util.Set;
- /**
- * ID3算法实现类
- *
- * @author lyq
- *
- */
- public class ID3Tool {
- // 类标号的值类型
- private final String YES = "Yes";
- private final String NO = "No";
- // 所有属性的类型总数,在这里就是data源数据的列数
- private int attrNum;
- private String filePath;
- // 初始源数据,用一个二维字符数组存放模仿表格数据
- private String[][] data;
- // 数据的属性行的名字
- private String[] attrNames;
- // 每个属性的值所有类型
- private HashMap<String, ArrayList<String>> attrValue;
- public ID3Tool(String filePath) {
- this.filePath = filePath;
- attrValue = new HashMap<>();
- }
- /**
- * 从文件中读取数据
- */
- private void readDataFile() {
- File file = new File(filePath);
- ArrayList<String[]> dataArray = new ArrayList<String[]>();
- try {
- BufferedReader in = new BufferedReader(new FileReader(file));
- String str;
- String[] tempArray;
- while ((str = in.readLine()) != null) {
- tempArray = str.split(" ");
- dataArray.add(tempArray);
- }
- in.close();
- } catch (IOException e) {
- e.getStackTrace();
- }
- data = new String[dataArray.size()][];
- dataArray.toArray(data);
- attrNum = data[0].length;
- attrNames = data[0];
- /*
- * for(int i=0; i<data.length;i++){ for(int j=0; j<data[0].length; j++){
- * System.out.print(" " + data[i][j]); }
- *
- * System.out.print("\n"); }
- */
- }
- /**
- * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
- */
- private void initAttrValue() {
- ArrayList<String> tempValues;
- // 按照列的方式,从左往右找
- for (int j = 1; j < attrNum; j++) {
- // 从一列中的上往下开始寻找值
- tempValues = new ArrayList<>();
- for (int i = 1; i < data.length; i++) {
- if (!tempValues.contains(data[i][j])) {
- // 如果这个属性的值没有添加过,则添加
- tempValues.add(data[i][j]);
- }
- }
- // 一列属性的值已经遍历完毕,复制到map属性表中
- attrValue.put(data[0][j], tempValues);
- }
- /*
- * for(Map.Entry entry : attrValue.entrySet()){
- * System.out.println("key:value " + entry.getKey() + ":" +
- * entry.getValue()); }
- */
- }
- /**
- * 计算数据按照不同方式划分的熵
- *
- * @param remainData
- * 剩余的数据
- * @param attrName
- * 待划分的属性,在算信息增益的时候会使用到
- * @param attrValue
- * 划分的子属性值
- * @param isParent
- * 是否分子属性划分还是原来不变的划分
- */
- private double computeEntropy(String[][] remainData, String attrName,
- String value, boolean isParent) {
- // 实例总数
- int total = 0;
- // 正实例数
- int posNum = 0;
- // 负实例数
- int negNum = 0;
- // 还是按列从左往右遍历属性
- for (int j = 1; j < attrNames.length; j++) {
- // 找到了指定的属性
- if (attrName.equals(attrNames[j])) {
- for (int i = 1; i < remainData.length; i++) {
- // 如果是父结点直接计算熵或者是通过子属性划分计算熵,这时要进行属性值的过滤
- if (isParent
- || (!isParent && remainData[i][j].equals(value))) {
- if (remainData[i][attrNames.length - 1].equals(YES)) {
- // 判断此行数据是否为正实例
- posNum++;
- } else {
- negNum++;
- }
- }
- }
- }
- }
- total = posNum + negNum;
- double posProbobly = (double) posNum / total;
- double negProbobly = (double) negNum / total;
- if (posProbobly == 1 || posProbobly == 0) {
- // 如果数据全为同种类型,则熵为0,否则带入下面的公式会报错
- return 0;
- }
- double entropyValue = -posProbobly * Math.log(posProbobly)
- / Math.log(2.0) - negProbobly * Math.log(negProbobly)
- / Math.log(2.0);
- // 返回计算所得熵
- return entropyValue;
- }
- /**
- * 为某个属性计算信息增益
- *
- * @param remainData
- * 剩余的数据
- * @param value
- * 待划分的属性名称
- * @return
- */
- private double computeGain(String[][] remainData, String value) {
- double gainValue = 0;
- // 源熵的大小将会与属性划分后进行比较
- double entropyOri = 0;
- // 子划分熵和
- double childEntropySum = 0;
- // 属性子类型的个数
- int childValueNum = 0;
- // 属性值的种数
- ArrayList<String> attrTypes = attrValue.get(value);
- // 子属性对应的权重比
- HashMap<String, Integer> ratioValues = new HashMap<>();
- for (int i = 0; i < attrTypes.size(); i++) {
- // 首先都统一计数为0
- ratioValues.put(attrTypes.get(i), 0);
- }
- // 还是按照一列,从左往右遍历
- for (int j = 1; j < attrNames.length; j++) {
- // 判断是否到了划分的属性列
- if (value.equals(attrNames[j])) {
- for (int i = 1; i <= remainData.length - 1; i++) {
- childValueNum = ratioValues.get(remainData[i][j]);
- // 增加个数并且重新存入
- childValueNum++;
- ratioValues.put(remainData[i][j], childValueNum);
- }
- }
- }
- // 计算原熵的大小
- entropyOri = computeEntropy(remainData, value, null, true);
- for (int i = 0; i < attrTypes.size(); i++) {
- double ratio = (double) ratioValues.get(attrTypes.get(i))
- / (remainData.length - 1);
- childEntropySum += ratio
- * computeEntropy(remainData, value, attrTypes.get(i), false);
- // System.out.println("ratio:value: " + ratio + " " +
- // computeEntropy(remainData, value,
- // attrTypes.get(i), false));
- }
- // 二者熵相减就是信息增益
- gainValue = entropyOri - childEntropySum;
- return gainValue;
- }
- /**
- * 计算信息增益比
- *
- * @param remainData
- * 剩余数据
- * @param value
- * 待划分属性
- * @return
- */
- private double computeGainRatio(String[][] remainData, String value) {
- double gain = 0;
- double spiltInfo = 0;
- int childValueNum = 0;
- // 属性值的种数
- ArrayList<String> attrTypes = attrValue.get(value);
- // 子属性对应的权重比
- HashMap<String, Integer> ratioValues = new HashMap<>();
- for (int i = 0; i < attrTypes.size(); i++) {
- // 首先都统一计数为0
- ratioValues.put(attrTypes.get(i), 0);
- }
- // 还是按照一列,从左往右遍历
- for (int j = 1; j < attrNames.length; j++) {
- // 判断是否到了划分的属性列
- if (value.equals(attrNames[j])) {
- for (int i = 1; i <= remainData.length - 1; i++) {
- childValueNum = ratioValues.get(remainData[i][j]);
- // 增加个数并且重新存入
- childValueNum++;
- ratioValues.put(remainData[i][j], childValueNum);
- }
- }
- }
- // 计算信息增益
- gain = computeGain(remainData, value);
- // 计算分裂信息,分裂信息度量被定义为(分裂信息用来衡量属性分裂数据的广度和均匀):
- for (int i = 0; i < attrTypes.size(); i++) {
- double ratio = (double) ratioValues.get(attrTypes.get(i))
- / (remainData.length - 1);
- spiltInfo += -ratio * Math.log(ratio) / Math.log(2.0);
- }
- // 计算机信息增益率
- return gain / spiltInfo;
- }
- /**
- * 利用源数据构造决策树
- */
- private void buildDecisionTree(AttrNode node, String parentAttrValue,
- String[][] remainData, ArrayList<String> remainAttr, boolean isID3) {
- node.setParentAttrValue(parentAttrValue);
- String attrName = "";
- double gainValue = 0;
- double tempValue = 0;
- // 如果只有1个属性则直接返回
- if (remainAttr.size() == 1) {
- System.out.println("attr null");
- return;
- }
- // 选择剩余属性中信息增益最大的作为下一个分类的属性
- for (int i = 0; i < remainAttr.size(); i++) {
- // 判断是否用ID3算法还是C4.5算法
- if (isID3) {
- // ID3算法采用的是按照信息增益的值来比
- tempValue = computeGain(remainData, remainAttr.get(i));
- } else {
- // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足
- tempValue = computeGainRatio(remainData, remainAttr.get(i));
- }
- if (tempValue > gainValue) {
- gainValue = tempValue;
- attrName = remainAttr.get(i);
- }
- }
- node.setAttrName(attrName);
- ArrayList<String> valueTypes = attrValue.get(attrName);
- remainAttr.remove(attrName);
- AttrNode[] childNode = new AttrNode[valueTypes.size()];
- String[][] rData;
- for (int i = 0; i < valueTypes.size(); i++) {
- // 移除非此值类型的数据
- rData = removeData(remainData, attrName, valueTypes.get(i));
- childNode[i] = new AttrNode();
- boolean sameClass = true;
- ArrayList<String> indexArray = new ArrayList<>();
- for (int k = 1; k < rData.length; k++) {
- indexArray.add(rData[k][0]);
- // 判断是否为同一类的
- if (!rData[k][attrNames.length - 1]
- .equals(rData[1][attrNames.length - 1])) {
- // 只要有1个不相等,就不是同类型的
- sameClass = false;
- break;
- }
- }
- if (!sameClass) {
- // 创建新的对象属性,对象的同个引用会出错
- ArrayList<String> rAttr = new ArrayList<>();
- for (String str : remainAttr) {
- rAttr.add(str);
- }
- buildDecisionTree(childNode[i], valueTypes.get(i), rData,
- rAttr, isID3);
- } else {
- // 如果是同种类型,则直接为数据节点
- childNode[i].setParentAttrValue(valueTypes.get(i));
- childNode[i].setChildDataIndex(indexArray);
- }
- }
- node.setChildAttrNode(childNode);
- }
- /**
- * 属性划分完毕,进行数据的移除
- *
- * @param srcData
- * 源数据
- * @param attrName
- * 划分的属性名称
- * @param valueType
- * 属性的值类型
- */
- private String[][] removeData(String[][] srcData, String attrName,
- String valueType) {
- String[][] desDataArray;
- ArrayList<String[]> desData = new ArrayList<>();
- // 待删除数据
- ArrayList<String[]> selectData = new ArrayList<>();
- selectData.add(attrNames);
- // 数组数据转化到列表中,方便移除
- for (int i = 0; i < srcData.length; i++) {
- desData.add(srcData[i]);
- }
- // 还是从左往右一列列的查找
- for (int j = 1; j < attrNames.length; j++) {
- if (attrNames[j].equals(attrName)) {
- for (int i = 1; i < desData.size(); i++) {
- if (desData.get(i)[j].equals(valueType)) {
- // 如果匹配这个数据,则移除其他的数据
- selectData.add(desData.get(i));
- }
- }
- }
- }
- desDataArray = new String[selectData.size()][];
- selectData.toArray(desDataArray);
- return desDataArray;
- }
- /**
- * 开始构建决策树
- *
- * @param isID3
- * 是否采用ID3算法构架决策树
- */
- public void startBuildingTree(boolean isID3) {
- readDataFile();
- initAttrValue();
- ArrayList<String> remainAttr = new ArrayList<>();
- // 添加属性,除了最后一个类标号属性
- for (int i = 1; i < attrNames.length - 1; i++) {
- remainAttr.add(attrNames[i]);
- }
- AttrNode rootNode = new AttrNode();
- buildDecisionTree(rootNode, "", data, remainAttr, isID3);
- showDecisionTree(rootNode, 1);
- }
- /**
- * 显示决策树
- *
- * @param node
- * 待显示的节点
- * @param blankNum
- * 行空格符,用于显示树型结构
- */
- private void showDecisionTree(AttrNode node, int blankNum) {
- System.out.println();
- for (int i = 0; i < blankNum; i++) {
- System.out.print("\t");
- }
- System.out.print("--");
- // 显示分类的属性值
- if (node.getParentAttrValue() != null
- && node.getParentAttrValue().length() > 0) {
- System.out.print(node.getParentAttrValue());
- } else {
- System.out.print("--");
- }
- System.out.print("--");
- if (node.getChildDataIndex() != null
- && node.getChildDataIndex().size() > 0) {
- String i = node.getChildDataIndex().get(0);
- System.out.print("类别:"
- + data[Integer.parseInt(i)][attrNames.length - 1]);
- System.out.print("[");
- for (String index : node.getChildDataIndex()) {
- System.out.print(index + ", ");
- }
- System.out.print("]");
- } else {
- // 递归显示子节点
- System.out.print("【" + node.getAttrName() + "】");
- for (AttrNode childNode : node.getChildAttrNode()) {
- showDecisionTree(childNode, 2 * blankNum);
- }
- }
- }
- }
- /**
- * ID3决策树分类算法测试场景类
- * @author lyq
- *
- */
- public class Client {
- public static void main(String[] args){
- String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
- ID3Tool tool = new ID3Tool(filePath);
- tool.startBuildingTree(true);
- }
- }
- ------【OutLook】
- --Sunny--【Humidity】
- --High--类别:No[1, 2, 8, ]
- --Normal--类别:Yes[9, 11, ]
- --Overcast--类别:Yes[3, 7, 12, 13, ]
- --Rainy--【Wind】
- --Weak--类别:Yes[4, 5, 10, ]
- --Strong--类别:No[6, 14, ]
请从左往右观察这棵决策树,【】里面的是分类属性,---XXX----,XXX为属性的值,在叶子节点处为类标记。
对应的分类结果图:
这里的构造决策树和显示决策树采用的DFS的方法,所以可能会比较难懂,希望读者能细细体会,可以调试一下代码,一步步的跟踪会更加容易理解的。
三、C4.5算法
如果你已经理解了上面ID3算法的实现,那么理解C4.5也很容易了,C4.5与ID3在核心的算法是一样的,但是有一点所采用的办法是不同的,C4.5采用了信息增益率作为划分的根据,克服了ID3算法中采用信息增益划分导致属性选择偏向取值多的属性。信息增益率的公式为:
分母的位置是分裂因子,他的计算公式为:
和熵的计算公式比较像,具体的信息增益率的算法也在上面的代码中了,请关注着2个方法:
- // 选择剩余属性中信息增益最大的作为下一个分类的属性
- for (int i = 0; i < remainAttr.size(); i++) {
- // 判断是否用ID3算法还是C4.5算法
- if (isID3) {
- // ID3算法采用的是按照信息增益的值来比
- tempValue = computeGain(remainData, remainAttr.get(i));
- } else {
- // C4.5算法进行了改进,用的是信息增益率来比,克服了用信息增益选择属性时偏向选择取值多的属性的不足
- tempValue = computeGainRatio(remainData, remainAttr.get(i));
- }
- if (tempValue > gainValue) {
- gainValue = tempValue;
- attrName = remainAttr.get(i);
- }
- }
1、在构造决策树的过程中能对树进行剪枝。
2、能对连续性的值进行离散化的操作。
四、编码时遇到的一些问题
为了实现ID3算法,从理解阅读他的原理就已经用掉了比较多的时间,然后再尝试阅读别人写的C++版本的代码,又是看了几天,好不容易实现了2个算法,最后在构造树的过程中遇到了最大了麻烦,因为用到了递归构造树,对于其中节点的设计就显得至关重要了,也许我自己目前的设计也不是最优秀的。下面盘点一下我的程序的遇到的一些问题和存在的潜在的问题:
1、在构建决策树的时候,出现了remainAttr值缺少的情况,就是递归的时候remainAttr的属性划分移除掉之后,对于上次的递归操作的属性时受到影响了,后来发现是因为我remainAttr采用的是ArrayList,他是一个引用对象,通过引用传入的方式,对象用的还是同一个,所以果断重新建了一个ArrayList对象,问题就OK了。
- // 创建新的对象属性,对象的同个引用会出错
- ArrayList<String> rAttr = new ArrayList<>();
- for (String str : remainAttr) {
- rAttr.add(str);
- }
- buildDecisionTree(childNode[i], valueTypes.get(i), rData,
- rAttr, isID3);
- private void buildDecisionTree(AttrNode node, String parentAttrValue,
- String[][] remainData, ArrayList<String> remainAttr, boolean isID3) {
- node.setParentAttrValue(parentAttrValue);
- String attrName = "";
- double gainValue = 0;
- double tempValue = 0;
- // 如果只有1个属性则直接返回
- if (remainAttr.size() == 1) {
- System.out.println("attr null");
- return;
- }
- .....
- 决策分类树算法之ID3,C4.5算法系列
- 决策分类树算法之ID3,C4.5算法系列
- 决策分类算法-C4.5算法原理
- 分类算法-----决策树(包括ID3,C4.5)
- 决策树分类算法:ID3 & C4.5 & CART
- 数据挖掘算法----分类算法(ID3和C4.5)
- 决策数之C4.5算法
- 分类算法(5) ---- 决策树(ID3,C4.5,CTAR)
- 【机器学习】分类算法:决策树(ID3、C4.5、CART)
- 分类算法:ID3与C4.5及CART
- ID3决策树与C4.5决策树分类算法简述
- ID3和C4.5算法
- 机器学习之决策树分类算法(ID3 and C4.5)
- 分类算法之决策树C4.5算法
- 分类:ID3,C4.5
- 基于决策树系列算法(ID3, C4.5, CART, Random Forest, GBDT)的分类和回归探讨
- 决策树之ID3、C4.5、C5.0算法
- 决策树之ID3、C4.5、C5.0算法
- GSP序列模式分析算法
- Java ftp实现文件的上传和下载ftp,sftp sun.net.ftp.FtpProtocolException:Welcome message: SSH-2.0-OpenSSH_5.1
- iOS xib文件引入的两种方式
- Tips for DDIC and Search Help.docx
- Hiredis-redis cplusplus--redis3M
- 决策分类树算法之ID3,C4.5算法系列
- C 语言中的左值和右值。以及对比数组名和指针取数组元素的区别
- 深入理解Android之Gradle
- mybatis之map.xml文件的解读
- MongoDB is web scale
- SPI协议及工作原理分析
- 01【iOS总结】UIView、UILabel、UITextField、UIButton 、目标动作机制(+UIAlertView、UIAlertController)
- 并发编程网-线程池-说明
- 12.缺陷跟踪系统Mantis的问题生命周期和工作流