决策树算法的研究与应用

来源:互联网 发布:crossover软件 编辑:程序博客网 时间:2024/06/10 01:03

决策树算法的研究与应用

 

摘 要  决策树是归纳学习和数据挖掘的重要方法, 通常用来形成分类器和预测模型。概述了决策树分类算法, 指出了决策树算法的核心技术:测试属性的选择和树枝修剪技术。最后, 通过一个实例说明决策树分类在实际中的应用。

关键词  决策树,算法

 

Research and Application of Decision Tree Algorithm

 

ABSTRACT:  Decision tree is an import ant method in induction learning as well as in data mining , which can be used to form classification and predictive model. Introduces decision tree and points out i ts key techniques :the choice of testing feature and tree pruning .Finally , through an instance, this paper shows the application of decision tree in production .

KEYWORDS:  decision tree, Algorithm

1 引 言

    大数据时代已经到来,对数据的处理越来越受到人们的关注,人们迫切需要海量数据背后的重要信息和知识,发现数据中存在的关系和规则,获取有用的知识,并且根据现有数据对未来的发展做出预测。决策树分类算法C4.5算法是数据挖掘中最常用、最经典的分类算法,能够以图形化的形式表现挖掘的结果,从而方便于使用者快速做出决定或预测。决策树算法是以实例为基础的归纳学习算法, 以其易于提取显示规则、计算量相对较小、可以显示重要决策属性和较高的分类准确率等优点而得到广泛的应用。

2 决策树的基本思想及常用的决策树算法

    顾名思义, 决策树的结构, 就像是一棵树。它利用树的结构将数据记录进行分类, 树的一个叶节点就代表某个条件下的一个记录集, 根据记录字段的不同取值建立树的分支;在每个分支子集中重复建立下层节点和分支, 便可生成一棵决策树。对生成的决策树进行修剪, 很容易得到具有商业价值的信息, 供决策者

参考。ID3 是引用率较高的决策树算法之一, 是Quinlan提出的一个著名决策树生成方法。要构造尽可能小的决策树, 关键在于选择合适的产生分支的属性。而

ID3 算法的核心正是通过采用信息增益的方式来选择能够最好地将样本分类的属性。决策树分类算法从提出以来, 出现了很多算法, 比较常用的有:1986 年Quinlan 提出了著名的ID3算法。ID3 算法体现了决策树分类的优点:算法的理论清晰, 方法简单, 学习能力较强。其缺点是:只对比较小的数据集有效,且对噪声比较敏感,当训练数据集加大时, 决策树可能会随之改变, 并且在测试属性选择时, 它倾向于选择取值较多的属性。在ID3 算法的基础上, 1993 年Quinlan 又自己提出了改进算法———C4 .5算法。为了适应处理大规模数据集的需要, 后来又提出了若干改进的算法, 其中SLIQ(supervised learning in quest)和SPRINT (scalable parallelizable induction of decision trees)是比较有代表性的两个算法, PUBLIC(Pruning and Building Integrated in Classification)算法是一种很典型的在建树的同时进行剪枝的算法。此外,还有很多决策树分类算法。

3 决策树算法的核心技术

    建立决策树的目标是通过训练样本集, 建立目标变量关于各输入变量的分类预测模型, 全面实现输入变量和目标变量不同取值下的数据分组, 进而用于对新数据对象的分类和预测。当利用所建决策树对一个新数据对象进行分析时,决策树能够依据该数据输入变量的取值, 推断出相应目标变量的分类或取值。决策树技术中有各种各样的算法,这些算法都存在各自的优势和不足。目前, 从事

机器学习的专家学者们仍在潜心对现有算法的改进, 或研究更有效的新算法。总结起来, 决策树算法主要围绕两大核心问题展开:第一, 决策树的生长问题, 即利用训练样本集, 完成决策树的建立过程。第二, 决策树的剪枝问题,即利用检验样本集, 对形成的决策树进行优化处理。

4实验与结果分析

我们统计了14天的气象数据(指标包括outlook,temperature,humidity,windy),并已知这些天气是否打球(play)。如果给出新一天的气象指标数据:sunny,cool,high,TRUE,判断一下会不会去打球。以下使用ID3算法解决该问题:

                                  表1

outlook

temperature

humidity

windy

play

sunny

hot

high

FALSE

no

sunny

hot

high

TRUE

no

overcast

hot

high

FALSE

yes

rainy

mild

high

FALSE

yes

rainy

cool

normal

FALSE

yes

rainy

cool

normal

TRUE

no

overcast

cool

normal

TRUE

yes

sunny

mild

high

FALSE

no

sunny

cool

normal

FALSE

yes

rainy

mild

normal

FALSE

yes

sunny

mild

normal

TRUE

yes

overcast

mild

high

TRUE

yes

overcast

hot

normal

FALSE

yes

rainy

mild

high

TRUE

no

构造树的基本想法是随着树深度的增加,节点的熵迅速地降低。熵降低的速度越快越好,这样我们有望得到一棵高度最矮的决策树。

在没有给定任何天气信息时,根据历史数据,我们只知道新的一天打球的概率是9/14,不打的概率是5/14。此时的熵为:

属性有4个:outlooktemperaturehumiditywindy。我们首先要决定哪个属性作树的根节点。

对每项指标分别统计:在不同的取值下打球和不打球的次数。

                              表 2

outlook

temperature

humidity

windy

play

 

yes

no

 

yes

no

 

yes

no

 

yes

no

yes

no

sunny

2

3

hot

2

2

high

3

4

FALSE

6

2

9

5

overcast

4

0

mild

4

2

normal

6

1

TRUR

3

3

 

 

rainy

3

2

cool

3

1

 

 

 

 

 

 

 

 

下面我们计算当已知变量outlook的值时,信息熵为多少。

outlook=sunny时,2/5的概率打球,3/5的概率不打球。entropy=0.971

outlook=overcast时,entropy=0

outlook=rainy时,entropy=0.971

而根据历史统计数据,outlook取值为sunnyovercastrainy的概率分别是5/144/145/14,所以当已知变量outlook的值时,信息熵为:5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693

这样的话系统熵就从0.940下降到了0.693,信息增溢gain(outlook)0.940-0.693=0.247

同样可以计算出gain(temperature)=0.029gain(humidity)=0.152gain(windy)=0.048

gain(outlook)最大(即outlook在第一步使系统的信息熵下降得最快),所以决策树的根节点就取outlook

接下来要确定N1temperaturehumidity还是windy?在已知outlook=sunny的情况,根据历史数据,我们作出类似table 2的一张表,分别计算gain(temperature)gain(humidity)gain(windy),选最大者为N1

依此类推,构造决策树。当系统的信息熵降为0时,就没有必要再往下构造决策树了,此时叶子节点都是纯的--这是理想情况。最坏的情况下,决策树的高度为属性(决策变量)的个数,叶子节点不纯(这意味着我们要以一定的概率来作出决策)。

Java实现

最终的决策树保存在了XML中,使用了Dom4J,注意如果要让Dom4J支持按XPath选择节点,还得引入包jaxen.jar。程序代码要求输入文件满足ARFF格式,并且属性都是标称变量。

实验用的数据文件:

@relation weather.symbolic @attribute outlook {sunny, overcast, rainy}@attribute temperature {hot, mild, cool}@attribute humidity {high, normal}@attribute windy {TRUE, FALSE}@attribute play {yes, no} @datasunny,hot,high,FALSE,nosunny,hot,high,TRUE,noovercast,hot,high,FALSE,yesrainy,mild,high,FALSE,yesrainy,cool,normal,FALSE,yesrainy,cool,normal,TRUE,noovercast,cool,normal,TRUE,yessunny,mild,high,FALSE,nosunny,cool,normal,FALSE,yesrainy,mild,normal,FALSE,yessunny,mild,normal,TRUE,yesovercast,mild,high,TRUE,yesovercast,hot,normal,FALSE,yesrainy,mild,high,TRUE,no

程序代码:

+ View Code?

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

package dt;

 

import java.io.BufferedReader;

import java.io.File;

import java.io.FileReader;

import java.io.FileWriter;

import java.io.IOException;

import java.util.ArrayList;

import java.util.Iterator;

import java.util.LinkedList;

import java.util.List;

import java.util.regex.Matcher;

import java.util.regex.Pattern;

 

import org.dom4j.Document;

import org.dom4j.DocumentHelper;

import org.dom4j.Element;

import org.dom4j.io.OutputFormat;

import org.dom4j.io.XMLWriter;

 

public class ID3 {

    private ArrayList<String> attribute = new ArrayList<String>(); // 存储属性的名称

    private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); // 存储每个属性的取值

    private ArrayList<String[]> data = new ArrayList<String[]>();; // 原始数据

    int decatt; // 决策变量在属性集中的索引

    public static final String patternString = "@attribute(.*)[{](.*?)[}]";

 

    Document xmldoc;

    Element root;

 

    public ID3() {

        xmldoc = DocumentHelper.createDocument();

        root = xmldoc.addElement("root");

        root.addElement("DecisionTree").addAttribute("value", "null");

    }

 

    public static void main(String[] args) {

        ID3 inst = new ID3();

        inst.readARFF(new File("/home/orisun/test/weather.nominal.arff"));

        inst.setDec("play");

        LinkedList<Integer> ll=new LinkedList<Integer>();

        for(int i=0;i<inst.attribute.size();i++){

            if(i!=inst.decatt)

                ll.add(i);

        }

        ArrayList<Integer> al=new ArrayList<Integer>();

        for(int i=0;i<inst.data.size();i++){

            al.add(i);

        }

        inst.buildDT("DecisionTree", "null", al, ll);

        inst.writeXML("/home/orisun/test/dt.xml");

        return;

    }

 

    //读取arff文件,给attributeattributevaluedata赋值

    public void readARFF(File file) {

        try {

            FileReader fr = new FileReader(file);

            BufferedReader br = new BufferedReader(fr);

            String line;

            Pattern pattern = Pattern.compile(patternString);

            while ((line = br.readLine()) != null) {

                Matcher matcher = pattern.matcher(line);

                if (matcher.find()) {

                    attribute.add(matcher.group(1).trim());

                    String[] values = matcher.group(2).split(",");

                    ArrayList<String> al = new ArrayList<String>(values.length);

                    for (String value : values) {

                        al.add(value.trim());

                    }

                    attributevalue.add(al);

                } else if (line.startsWith("@data")) {

                    while ((line = br.readLine()) != null) {

                        if(line=="")

                            continue;

                        String[] row = line.split(",");

                        data.add(row);

                    }

                } else {

                    continue;

                }

            }

            br.close();

        } catch (IOException e1) {

            e1.printStackTrace();

        }

    }

 

    //设置决策变量

    public void setDec(int n) {

        if (n < 0 || n >= attribute.size()) {

            System.err.println("决策变量指定错误。");

            System.exit(2);

        }

        decatt = n;

    }

    public void setDec(String name) {

        int n = attribute.indexOf(name);

        setDec(n);

    }

 

    //给一个样本(数组中是各种情况的计数),计算它的熵

    public double getEntropy(int[] arr) {

        double entropy = 0.0;

        int sum = 0;

        for (int i = 0; i < arr.length; i++) {

            entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2);

            sum += arr[i];

        }

        entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2);

        entropy /= sum;

        return entropy;

    }

 

    //给一个样本数组及样本的算术和,计算它的熵

    public double getEntropy(int[] arr, int sum) {

        double entropy = 0.0;

        for (int i = 0; i < arr.length; i++) {

            entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2);

        }

        entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2);

        entropy /= sum;

        return entropy;

    }

 

    public boolean infoPure(ArrayList<Integer> subset) {

        String value = data.get(subset.get(0))[decatt];

        for (int i = 1; i < subset.size(); i++) {

            String next=data.get(subset.get(i))[decatt];

            //equals表示对象内容相同,==表示两个对象指向的是同一片内存

            if (!value.equals(next))

                return false;

        }

        return true;

    }

 

    // 给定原始数据的子集(subset中存储行号),当以第index个属性为节点时计算它的信息熵

    public double calNodeEntropy(ArrayList<Integer> subset, int index) {

        int sum = subset.size();

        double entropy = 0.0;

        int[][] info = new int[attributevalue.get(index).size()][];

        for (int i = 0; i < info.length; i++)

            info[i] = new int[attributevalue.get(decatt).size()];

        int[] count = new int[attributevalue.get(index).size()];

        for (int i = 0; i < sum; i++) {

            int n = subset.get(i);

            String nodevalue = data.get(n)[index];

            int nodeind = attributevalue.get(index).indexOf(nodevalue);

            count[nodeind]++;

            String decvalue = data.get(n)[decatt];

            int decind = attributevalue.get(decatt).indexOf(decvalue);

            info[nodeind][decind]++;

        }

        for (int i = 0; i < info.length; i++) {

            entropy += getEntropy(info[i]) * count[i] / sum;

        }

        return entropy;

    }

 

    // 构建决策树

    public void buildDT(String name, String value, ArrayList<Integer> subset,

            LinkedList<Integer> selatt) {

        Element ele = null;

        @SuppressWarnings("unchecked")

        List<Element> list = root.selectNodes("//"+name);

        Iterator<Element> iter=list.iterator();

        while(iter.hasNext()){

            ele=iter.next();

            if(ele.attributeValue("value").equals(value))

                break;

        }

        if (infoPure(subset)) {

            ele.setText(data.get(subset.get(0))[decatt]);

            return;

        }

        int minIndex = -1;

        double minEntropy = Double.MAX_VALUE;

        for (int i = 0; i < selatt.size(); i++) {

            if (i == decatt)

                continue;

            double entropy = calNodeEntropy(subset, selatt.get(i));

            if (entropy < minEntropy) {

                minIndex = selatt.get(i);

                minEntropy = entropy;

            }

        }

        String nodeName = attribute.get(minIndex);

        selatt.remove(new Integer(minIndex));

        ArrayList<String> attvalues = attributevalue.get(minIndex);

        for (String val : attvalues) {

            ele.addElement(nodeName).addAttribute("value", val);

            ArrayList<Integer> al = new ArrayList<Integer>();

            for (int i = 0; i < subset.size(); i++) {

                if (data.get(subset.get(i))[minIndex].equals(val)) {

                    al.add(subset.get(i));

                }

            }

            buildDT(nodeName, val, al, selatt);

        }

    }

 

    // xml写入文件

    public void writeXML(String filename) {

        try {

            File file = new File(filename);

            if (!file.exists())

                file.createNewFile();

            FileWriter fw = new FileWriter(file);

            OutputFormat format = OutputFormat.createPrettyPrint(); // 美化格式

            XMLWriter output = new XMLWriter(fw, format);

            output.write(xmldoc);

            output.close();

        } catch (IOException e) {

            System.out.println(e.getMessage());

        }

    }

}

最终生成的文件如下:

<?xml version="1.0" encoding="UTF-8"?> <root>  <DecisionTree value="null">    <outlook value="sunny">      <humidity value="high">no</humidity>      <humidity value="normal">yes</humidity>    </outlook>    <outlook value="overcast">yes</outlook>    <outlook value="rainy">      <windy value="TRUE">no</windy>      <windy value="FALSE">yes</windy>    </outlook>  </DecisionTree></root>

用图形象地表示就是:

5结束语

本文论述了决策树算法的概念,原理,以及常用的决策树算法,并在一个实例中应用的该算法。

 

参考文献

[ 1] 刘同明.数据挖掘技术及其应用 .北京:国防工业出版社, 2001 .

[ 2] 刘慧魏, 张 雷, 翟军昌.数据挖掘中决策树算法的研究及其改进 .辽宁师专学报, 2005, 7(4):23-26 .

[3] 冯少荣. 决策树算法的研究与改进. 厦门大学学报(自然科学版),200717(5):16-18

[4] 李慧慧,万武族. 决策树分类算法C4.5中连续属性过程处理的改进TP301. 1006-2475(2010)08-0008-03

 

 

 

 

0 0
原创粉丝点击