决策树代码
来源:互联网 发布:游戏脚本高级编程 编辑:程序博客网 时间:2024/05/04 15:36
import numpy as np  # NOTE(review): unused in this chunk; kept in case other parts of the file use it
from math import log
import operator
import pickle


class TreeNode:
    """A node in the decision tree.

    Leaf nodes carry a class ``label``; internal nodes carry the name of the
    ``feat`` they split on plus a ``subTrees`` dict mapping each observed
    feature value to the child node for that branch.
    """

    def __init__(self):
        self.label = None     # class label; set only on leaf nodes
        self.feat = None      # split-feature name; set only on branch nodes
        self.subTrees = {}    # {feature value: child TreeNode}

    def addSubTree(self, feat, subTree):
        """Attach *subTree* as the branch taken when the split feature equals *feat*.

        Raises
        ------
        Exception : if a branch for this feature value already exists.
        """
        # Fixed stray '}' that was in the original debug message.
        print('subTree to add is %s, subTrees is %s' % (feat, self.subTrees))
        if feat in self.subTrees:
            raise Exception('duplicate feat')
        self.subTrees[feat] = subTree


class DecisionTree:
    """ID3-style decision tree classifier for data with discrete feature values."""

    def __init__(self):
        # BUG FIX: the original bound plain locals ('tree', 'feats') instead of
        # instance attributes, so a fresh instance raised AttributeError in
        # classify()/storeTree() before train() assigned them.
        self.tree = None     # root TreeNode, set by train()/loadTree()
        self.feats = None    # feature-name list parallel to the data columns

    def shannonEnt(self, labels):
        """Return the base-2 Shannon entropy of the label sequence *labels*."""
        counts = {}
        for label in labels:
            counts[label] = counts.get(label, 0) + 1
        # Accumulator renamed from 'sum' so the builtin is not shadowed.
        entropy = 0.0
        length = len(labels)
        for count in counts.values():
            prob = float(count) / length
            entropy -= prob * log(prob, 2)
        return entropy

    def majorityCount(self, labels):
        """Return the most frequent label in *labels*."""
        labelCount = {}
        for label in labels:
            labelCount[label] = labelCount.get(label, 0) + 1
        print('majority count is %s' % (labelCount,))
        sortedLabelCount = sorted(labelCount.items(),
                                  key=operator.itemgetter(1), reverse=True)
        return sortedLabelCount[0][0]

    def splitDataSet(self, dataSet, labels, featIndex, value):
        """Return the rows (and their labels) whose feature *featIndex* equals
        *value*, with that feature column removed from each row.

        Returns
        -------
        (dataSetSlice, labelSlice) : filtered rows and their parallel labels
        """
        dataSetSlice = []
        labelSlice = []
        for i in range(len(labels)):
            if dataSet[i][featIndex] == value:
                reduced = dataSet[i][:featIndex]
                reduced.extend(dataSet[i][featIndex + 1:])
                dataSetSlice.append(reduced)
                labelSlice.append(labels[i])
        return dataSetSlice, labelSlice

    def chooseBestFeatToSplit(self, dataSet, labels):
        """Return the index of the feature that minimizes the conditional
        entropy of the labels (equivalently, maximizes information gain)."""
        bestFeat = -1
        minEnt = float("inf")
        numFeat = len(dataSet[0])
        for i in range(numFeat):
            uniqueFeats = set(data[i] for data in dataSet)
            entropy = 0.0
            for featValue in uniqueFeats:
                # BUG FIX: the original wrote
                #   [labels[i] for data in dataSet if data[i] == featValue]
                # reusing the *feature* index i as a row index; every slice was
                # one repeated label, its entropy was always 0, and feature 0
                # was always selected regardless of information gain.
                labelSlice = [labels[row] for row, data in enumerate(dataSet)
                              if data[i] == featValue]
                prob = len(labelSlice) / float(len(labels))
                entropy += prob * self.shannonEnt(labelSlice)
            if entropy < minEnt:
                minEnt = entropy
                bestFeat = i
        return bestFeat

    def createTree(self, dataSet, labels, feats):
        """Recursively build the tree for *dataSet* / *labels*.

        Parameters
        ----------
        dataSet : list of rows of discrete feature values
        labels  : class label for each row
        feats   : feature name for each remaining column

        Returns
        -------
        TreeNode : root of the (sub)tree
        """
        print('create tree')
        print(dataSet)
        print(labels)
        print(feats)
        treeNode = TreeNode()
        print('new tree node, subTrees is %s' % (treeNode.subTrees,))
        # No features left to split on: leaf with the majority label.
        if len(dataSet[0]) == 0:
            treeNode.label = self.majorityCount(labels)
            print('get leaf node, label is %s' % (treeNode.label,))
            return treeNode
        # All labels identical: pure leaf.
        if labels.count(labels[0]) == len(labels):
            print('get leaf node, label is %s' % (labels[0],))
            treeNode.label = labels[0]
            return treeNode
        bestFeat = self.chooseBestFeatToSplit(dataSet, labels)
        treeNode.feat = feats[bestFeat]
        featValues = set(data[bestFeat] for data in dataSet)
        for featValue in featValues:
            print('featValue is %s : %s' % (feats[bestFeat], featValue))
            remainingFeats = feats[:]
            del remainingFeats[bestFeat]
            dataSetSlice, labelSlice = self.splitDataSet(
                dataSet, labels, bestFeat, featValue)
            subTree = self.createTree(dataSetSlice, labelSlice, remainingFeats)
            treeNode.addSubTree(featValue, subTree)
        return treeNode

    def storeTree(self, fileName):
        """Pickle the trained tree to *fileName*."""
        # BUG FIX: pickle requires binary mode ('wb', not 'w'); the context
        # manager also closes the handle on error.
        with open(fileName, 'wb') as fw:
            pickle.dump(self.tree, fw)

    def loadTree(self, fileName):
        """Load a previously stored tree from *fileName*.

        NOTE(review): self.feats is not restored here; callers must set it
        before classify() — this mirrors the original behavior.
        WARNING: pickle.load is unsafe on untrusted files.
        """
        with open(fileName, 'rb') as fr:
            self.tree = pickle.load(fr)

    def train(self, dataSet, labels, feats):
        """Build the decision tree from *dataSet*/*labels* and remember *feats*.

        Parameters
        ----------
        dataSet : list of rows of discrete feature values
        labels  : class label for each row
        feats   : feature name for each column
        """
        self.feats = feats
        self.tree = self.createTree(dataSet, labels, feats)

    def classify(self, data):
        """Return the predicted label for the feature-value list *data*.

        Walks from the root, following the branch matching data's value for
        each node's split feature, until a leaf is reached.
        """
        treeNode = self.tree
        while True:
            if treeNode.label is not None:
                return treeNode.label
            featIndex = self.feats.index(treeNode.feat)
            treeNode = treeNode.subTrees[data[featIndex]]
0 0
- 决策树代码
- 决策树代码
- 决策树代码实现
- 决策树代码实现
- 决策树算法代码
- 决策树代码python
- 决策树算法及代码
- 决策树(包括代码)
- 决策树算法与代码
- 决策树算法伪代码
- 代码笔记--决策树
- 决策树的python代码
- 决策树可视化python代码
- AI 决策树ID3 代码(c++)
- C4.5决策树+代码实践
- 决策树分类器+C代码
- R语言完整决策树代码
- 决策树ID3代码(Python)
- 【矩阵快速幂】ZOJ 2974 Just Pour the Water
- 杭电acm--2033
- HTML5的一些新特性学习
- hdu1506Largest Rectangle in a Histogram
- java学习之Set集合、HashSet
- 决策树代码
- Android Parcelable和Serializable的区别
- libsvm-2.84在MATLAB中使用遇到的刻骨铭心的error
- "Class not found: javac1.8"问题总结
- [LeetCode] Gray Code
- Javascript 时间相关随记
- 第三周实践项目1-顺序表的基本运算总结
- Android之自定义View以及画一个时钟
- uva 12096