周志华《机器学习》习题4.3

来源:互联网 发布:gta5online捏脸数据 编辑:程序博客网 时间:2024/05/23 02:01

为表4.3中数据生成一棵决策树。

代码是在《机器学习实战》的代码基础上改良的,借用了numpy, pandas之后明显简化了代码。表4.3的数据特征是离散属性和连续属性都有,问题就复杂在这里。话不多说,看代码。
先定义几个辅助函数,正常的思路是先想宏观算法,然后需要什么函数就定义什么函数。

import mathimport pandas as pdimport numpy as npfrom treePlotter import createPlotdef entropy(data):    label_values = data[data.columns[-1]]    #Returns object containing counts of unique values.    counts =  label_values.value_counts()    s = 0    for c in label_values.unique():        freq = float(counts[c])/len(label_values)         s -= freq*math.log(freq,2)    return sdef is_continuous(data,attr):    """Check if attr is a continuous attribute"""    return data[attr].dtype == 'float64'def split_points(data,attr):    """Returns Ta,Equation(4.7),p.84"""    values = np.sort(data[attr].values)    return [(x+y)/2 for x,y in zip(values[:-1],values[1:])] 

treePlotter是《实战》里的模块,用来把决策树画出来。这里决策树是用字典表示的,key可以表示树的节点或分枝,表示节点的时候是属性,表示分枝的时候是属性值。value又是一个字典或字符串,是字符串的时候表示叶,也就是标记。这里的data是pandas里的DataFrame,形式上像一个表,对表的常见操作它都可以方便的解决。命名习惯跟书上一致。

再继续看怎么计算信息增益:

def discrete_gain(data,attr):    V = data[attr].unique()    s = 0    for v in V:        data_v = data[data[attr]== v]        s += float(len(data_v))/len(data)*entropy(data_v)    return (entropy(data) - s,None)def continuous_gain(data,attr,points):    """Equation(4.8),p.84,returns the max gain along with its splitting point"""    entD = entropy(data)    #gains is a list of pairs of the form (gain,t)    gains = []    for t in points:        d_plus = data[data[attr] > t]        d_minus = data[data[attr] <= t]        gain = entD - (float(len(d_plus))/len(data)*entropy(d_plus)+float(len(d_minus))/len(data)*entropy(d_minus))        gains.append((gain,t))    return max(gains)

离散属性的信息增益一目了然,最后返回的pair中的None是为了给后面的函数判断之用,看到None就知道是离散属性了。连续属性的信息增益的计算方法是对每个划分点t都计算一下增益,然后连同t一起存到一个链表里,最后取最大的那个。
然后就是统管的信息增益函数:

def gain(data,attr):    if is_continuous(data,attr):        points = split_points(data,attr)        return continuous_gain(data,attr,points)    else:        return discrete_gain(data,attr)

还要用到一个众数函数:

def majority(label_values):    counts = label_values.value_counts()    return counts.index[0]

我们的id3终于登场了:

def id3(data):    attrs = data.columns[:-1]    #attrWithGain is of the form [(attr,(gain,t))], t is None if attr is discrete    attrWithGain = [(a,gain(data,a)) for a in attrs]     attrWithGain.sort(key = lambda tup:tup[1][0],reverse = True)    return attrWithGain[0]

它对每个属性都计算了信息增益,最后返回信息增益最大的那个属性,连带两个附加值,形式是(attr,(gain,t))。

最后造树:

def createTree(data,split_function):    label = data.columns[-1]    label_values = data[label]    #Stop when all classes are equal    if len(label_values.unique()) == 1:        return label_values.values[0]    #When no more features, or only one feature with same values, return majority    if data.shape[1] == 1 or (data.shape[1]==2 and len(data.T.ix[0].unique())==1):        return majority(label_values)    bestAttr,(g,t) = split_function(data)    #If bestAttr is discrete    if t is None:        #In this tree,a key is a node, the value is a list of trees,also a dictionary        myTree = {bestAttr:{}}        values = data[bestAttr].unique()         for v in values:            data_v = data[data[bestAttr]== v]            attrsAndLabel = data.columns.tolist()            attrsAndLabel.remove(bestAttr)            data_v = data_v[attrsAndLabel]            myTree[bestAttr][v] = createTree(data_v,split_function)        return myTree    #If bestAttr is continuous    else:        t = round(t,3)        node = bestAttr+'<='+str(t)        myTree = {node:{}}        values = ['yes','no']        for v in values:            data_v = data[data[bestAttr] <= t] if v == 'yes' else data[data[bestAttr] > t]            myTree[node][v] = createTree(data_v,split_function)        return myTree

这个我就不细说了,还得自己看。值得一提的是离散属性的下一次递归把当前的离散值删掉了,attrsAndLabel.remove(bestAttr),因为不允许这个属性出现在后续的分枝中。然而连续属性的时候,不删,允许继续出现。这个好理解,毕竟对连续属性用的是二分法,可能需要多个二分才能把情况搞清。

测试一下:

if __name__ == "__main__":    f = pd.read_csv(filepath_or_buffer = 'dataset/watermelon3.0en.csv', sep = ',')    data = f[f.columns[1:]]    tree = createTree(data,id3)    print tree    createPlot(tree)

我把原表翻译成英文了,因为中文的打印字典不显示汉字,画图的时候甚至直接不能画。

id,color,root,knock,texture,umbilical,touch,density,sugar content,good melon1,green,curled up,cloudy,clear,concave,hard slip,0.697,0.46,yes2,black,curled up,dull,clear,concave,hard slip,0.774,0.376,yes3,black,curled up,cloudy,clear,concave,hard slip,0.634,0.264,yes4,green,curled up,dull,clear,concave,hard slip,0.608,0.318,yes5,pale,curled up,cloudy,clear,concave,hard slip,0.556,0.215,yes6,green,slightly curled,cloudy,clear,slightly concave,soft sticky,0.403,0.237,yes7,black,slightly curled,cloudy,slightly paste,slightly concave,soft sticky,0.481,0.149,yes8,black,slightly curled,cloudy,clear,slightly concave,hard slip,0.437,0.211,yes9,black,slightly curled,dull,slightly paste,slightly concave,hard slip,0.666,0.091,no10,green,stiff,crisp,clear,flat,soft sticky,0.243,0.267,no11,pale,stiff,crisp,fuzzy,flat,hard slip,0.245,0.057,no12,pale,curled up,cloudy,fuzzy,flat,soft sticky,0.343,0.099,no13,green,slightly curled,cloudy,slightly paste,concave,hard slip,0.639,0.161,no14,pale,slightly curled,dull,slightly paste,concave,hard slip,0.657,0.198,no15,black,slightly curled,cloudy,clear,slightly concave,soft sticky,0.36,0.37,no16,pale,curled up,cloudy,fuzzy,flat,hard slip,0.593,0.042,no17,green,curled up,dull,slightly paste,slightly concave,hard slip,0.719,0.103,no

treePlotter我就不放上来了,委屈大家看一下字典凑合下吧。
画出来的树跟书上图4.8一样:

在西瓜数据集3.0上基于信息增益生成的决策树

把代码按照顺序复制到编辑器,保存下就可以运行了,记得吧treePlotter注释掉。

0 0
原创粉丝点击