Zhou Zhihua's Machine Learning, Exercise 4.3
Generate a decision tree for the data in Table 4.3.
The code is adapted from the version in Machine Learning in Action; bringing in numpy and pandas simplifies it considerably. The data in Table 4.3 contains both discrete and continuous attributes, which is exactly where the difficulty lies. Without further ado, the code.
First, define a few helper functions. The natural approach is to think through the overall algorithm first, then define whatever functions it turns out to need.
import math
import pandas as pd
import numpy as np
from treePlotter import createPlot

def entropy(data):
    label_values = data[data.columns[-1]]
    # value_counts() returns an object containing counts of unique values.
    counts = label_values.value_counts()
    s = 0
    for c in label_values.unique():
        freq = float(counts[c]) / len(label_values)
        s -= freq * math.log(freq, 2)
    return s

def is_continuous(data, attr):
    """Check if attr is a continuous attribute."""
    return data[attr].dtype == 'float64'

def split_points(data, attr):
    """Returns Ta, Equation (4.7), p.84."""
    values = np.sort(data[attr].values)
    return [(x + y) / 2 for x, y in zip(values[:-1], values[1:])]
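As a quick sanity check: with 8 positive and 9 negative examples, as in Table 4.3, the entropy should come out to about 0.998, the same value the book computes for this class distribution:

labels = pd.DataFrame({'good melon': ['yes'] * 8 + ['no'] * 9})
print(entropy(labels))   # ~0.998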
treePlotter is the module from Machine Learning in Action used to draw the decision tree. Here the tree is represented as a dictionary: a key stands for either a node or a branch of the tree; for a node it is an attribute, for a branch it is an attribute value. A value is either another dictionary (a subtree) or a string, in which case it is a leaf, i.e. a class label. data is a pandas DataFrame, which is shaped like a table and handles all the common table operations conveniently. The naming follows the book.
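For illustration (a hand-written example, not program output), a tree that splits on texture and then on touch would look like this:

example = {'texture': {'fuzzy': 'no',   # branch value mapping straight to a leaf label
                       'slightly paste': {'touch': {'soft sticky': 'yes',
                                                    'hard slip': 'no'}}}}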
Next, computing information gain:
def discrete_gain(data, attr):
    V = data[attr].unique()
    s = 0
    for v in V:
        data_v = data[data[attr] == v]
        s += float(len(data_v)) / len(data) * entropy(data_v)
    return (entropy(data) - s, None)

def continuous_gain(data, attr, points):
    """Equation (4.8), p.84: returns the max gain along with its splitting point."""
    entD = entropy(data)
    # gains is a list of pairs of the form (gain, t)
    gains = []
    for t in points:
        d_plus = data[data[attr] > t]
        d_minus = data[data[attr] <= t]
        gain = entD - (float(len(d_plus)) / len(data) * entropy(d_plus)
                       + float(len(d_minus)) / len(data) * entropy(d_minus))
        gains.append((gain, t))
    return max(gains)
The information gain for a discrete attribute is self-evident; the None in the returned pair is there for later code to test against, so that seeing None means the attribute is discrete. For a continuous attribute, the gain is computed by evaluating the binary split at every candidate point and keeping the point that yields the largest gain.
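For reference, here are the two equations from the book (p.84) that split_points and continuous_gain implement, written in LaTeX:

T_a = \left\{ \frac{a^i + a^{i+1}}{2} \;\middle|\; 1 \le i \le n-1 \right\} \tag{4.7}

\operatorname{Gain}(D,a) = \max_{t \in T_a} \operatorname{Gain}(D,a,t)
  = \max_{t \in T_a} \operatorname{Ent}(D) - \sum_{\lambda \in \{-,+\}} \frac{|D_t^{\lambda}|}{|D|} \operatorname{Ent}(D_t^{\lambda}) \tag{4.8}

where a^1, ..., a^n are the sorted values of attribute a in D, and D_t^- and D_t^+ are the examples with value <= t and > t respectively.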
Then the gain function that dispatches between the two cases:
def gain(data, attr):
    if is_continuous(data, attr):
        points = split_points(data, attr)
        return continuous_gain(data, attr, points)
    else:
        return discrete_gain(data, attr)
We also need a majority (mode) function:
def majority(label_values):
    # value_counts() sorts by count in descending order, so index[0] is the mode.
    counts = label_values.value_counts()
    return counts.index[0]
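A one-line check of its behavior (note that ties between equally frequent labels are broken arbitrarily by value_counts):

print(majority(pd.Series(['yes', 'no', 'yes'])))   # -> yes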
At last, our id3 makes its appearance:
def id3(data):
    attrs = data.columns[:-1]
    # attrWithGain is of the form [(attr, (gain, t))]; t is None if attr is discrete.
    attrWithGain = [(a, gain(data, a)) for a in attrs]
    attrWithGain.sort(key=lambda tup: tup[1][0], reverse=True)
    return attrWithGain[0]
It computes the information gain of every attribute and returns the attribute with the largest gain, along with two extra values, in the form (attr, (gain, t)).
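To check an intermediate result against the book: on the full data set (loaded as in the test block at the end, with the English column names from the CSV below), the winner should be texture, whose gain the book computes as 0.381 on p.85:

print(id3(data))   # -> ('texture', (0.3806..., None)), approximately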
Finally, building the tree:
def createTree(data, split_function):
    label = data.columns[-1]
    label_values = data[label]
    # Stop when all examples belong to the same class.
    if len(label_values.unique()) == 1:
        return label_values.values[0]
    # When no features remain, or only one feature with a single value, return the majority class.
    if data.shape[1] == 1 or (data.shape[1] == 2 and len(data.iloc[:, 0].unique()) == 1):
        return majority(label_values)
    bestAttr, (g, t) = split_function(data)
    # If bestAttr is discrete:
    if t is None:
        # In this tree, a key is a node; the value is a dictionary of subtrees.
        myTree = {bestAttr: {}}
        values = data[bestAttr].unique()
        for v in values:
            data_v = data[data[bestAttr] == v]
            attrsAndLabel = data.columns.tolist()
            attrsAndLabel.remove(bestAttr)
            data_v = data_v[attrsAndLabel]
            myTree[bestAttr][v] = createTree(data_v, split_function)
        return myTree
    # If bestAttr is continuous:
    else:
        t = round(t, 3)
        node = bestAttr + '<=' + str(t)
        myTree = {node: {}}
        values = ['yes', 'no']
        for v in values:
            data_v = data[data[bestAttr] <= t] if v == 'yes' else data[data[bestAttr] > t]
            myTree[node][v] = createTree(data_v, split_function)
        return myTree
I won't walk through this line by line; read it yourself. One point worth mentioning: when splitting on a discrete attribute, the next recursion removes that attribute (attrsAndLabel.remove(bestAttr)), because it is not allowed to appear again in later branches. A continuous attribute, however, is not removed and may appear again. This is easy to understand: a continuous attribute is split by bisection, and it may take several cuts to sort things out, as the toy example below shows.
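A toy check of the reuse of continuous attributes (my own example, assuming the functions above): the label below is positive only inside an interval, so a single cut on x cannot separate the classes and x has to be split twice along one path.

toy = pd.DataFrame({'x': [0.1, 0.2, 0.4, 0.5, 0.8, 0.9],
                    'label': ['no', 'no', 'yes', 'yes', 'no', 'no']})
print(createTree(toy, id3))
# -> {'x<=0.65': {'yes': {'x<=0.3': {'yes': 'no', 'no': 'yes'}}, 'no': 'no'}}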
A quick test:
if __name__ == "__main__":
    f = pd.read_csv(filepath_or_buffer='dataset/watermelon3.0en.csv', sep=',')
    data = f[f.columns[1:]]   # drop the id column
    tree = createTree(data, id3)
    print(tree)
    createPlot(tree)
I translated the original table into English, because printing a dictionary containing Chinese characters doesn't display them, and plotting fails outright.
id,color,root,knock,texture,umbilical,touch,density,sugar content,good melon
1,green,curled up,cloudy,clear,concave,hard slip,0.697,0.46,yes
2,black,curled up,dull,clear,concave,hard slip,0.774,0.376,yes
3,black,curled up,cloudy,clear,concave,hard slip,0.634,0.264,yes
4,green,curled up,dull,clear,concave,hard slip,0.608,0.318,yes
5,pale,curled up,cloudy,clear,concave,hard slip,0.556,0.215,yes
6,green,slightly curled,cloudy,clear,slightly concave,soft sticky,0.403,0.237,yes
7,black,slightly curled,cloudy,slightly paste,slightly concave,soft sticky,0.481,0.149,yes
8,black,slightly curled,cloudy,clear,slightly concave,hard slip,0.437,0.211,yes
9,black,slightly curled,dull,slightly paste,slightly concave,hard slip,0.666,0.091,no
10,green,stiff,crisp,clear,flat,soft sticky,0.243,0.267,no
11,pale,stiff,crisp,fuzzy,flat,hard slip,0.245,0.057,no
12,pale,curled up,cloudy,fuzzy,flat,soft sticky,0.343,0.099,no
13,green,slightly curled,cloudy,slightly paste,concave,hard slip,0.639,0.161,no
14,pale,slightly curled,dull,slightly paste,concave,hard slip,0.657,0.198,no
15,black,slightly curled,cloudy,clear,slightly concave,soft sticky,0.36,0.37,no
16,pale,curled up,cloudy,fuzzy,flat,hard slip,0.593,0.042,no
17,green,curled up,dull,slightly paste,slightly concave,hard slip,0.719,0.103,no
I won't post treePlotter here; you'll have to make do with reading the printed dictionary.
The tree it draws is the same as Figure 4.8 in the book:
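For reference, the printed dictionary should look roughly like this (the exact threshold string may be 0.381 or 0.382, depending on floating-point rounding):

{'texture': {'clear': {'density<=0.381': {'yes': 'no', 'no': 'yes'}},
             'slightly paste': {'touch': {'soft sticky': 'yes', 'hard slip': 'no'}},
             'fuzzy': 'no'}}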
Copy the code blocks into an editor in order, save, and it will run; remember to comment out the treePlotter import (and the createPlot call) if you don't have that module.