决策树ID3基本代码,周志华《机器学习》练习

来源:互联网 发布:系统内存优化级别默认 编辑:程序博客网 时间:2024/05/16 07:30
# -*- coding: utf-8 -*-
"""ID3-style decision tree with bi-partition handling of continuous attributes.

Exercise based on Zhou Zhihua's "Machine Learning" (watermelon data set 3.0).

Created on Wed Dec 28 09:33:11 2016
@author: ZQ

Data layout used throughout this module: a 2-D table (``numpy`` array or list
of rows) whose first row is the header (attribute names followed by the class
column name) and whose remaining rows are samples.  Every cell is a string;
continuous attributes are converted with ``float`` on demand.
"""
from collections import Counter

import numpy as np

# Names of the attributes that hold continuous (numeric) values.
CONTINUOUS_FEATURES = ('密度', '含糖量')

# Value reported as "threshold" when a split is not a numeric bi-partition.
NO_THRESHOLD = -1


def Infor_Ent(data):
    """Return the information entropy (base 2) of the class column.

    Args:
        data: sequence of sample rows (no header); the class label is the
            last element of each row.

    Returns:
        The entropy of the label distribution, ``0.0`` for an empty sequence.
    """
    data_count = len(data)
    if data_count == 0:
        return 0.0
    label_counts = Counter(row[-1] for row in data)
    ent = 0.0
    for count in label_counts.values():
        prob = count / data_count
        ent -= prob * np.log2(prob)
    return ent


def Infor_Gain_Label(data, axis):
    """Return the conditional entropy after splitting on a discrete attribute.

    This is the subtrahend of the information gain: the size-weighted sum of
    the entropies of the subsets obtained by grouping on column ``axis``.

    Args:
        data: sequence of sample rows (no header).
        axis: index of the discrete attribute column.
    """
    data_count = len(data)
    if data_count == 0:
        return 0.0
    groups = {}
    for row in data:
        groups.setdefault(row[axis], []).append(row)
    v_ent = 0.0
    for rows in groups.values():
        v_ent += len(rows) / data_count * Infor_Ent(rows)
    return v_ent


def _best_threshold(data, axis):
    """Find the best bi-partition point of a continuous attribute.

    Candidate thresholds are the midpoints between consecutive *sorted,
    distinct* values of the column.  Samples with value ``<= t`` go left,
    the others go right (the same rule ``creatTree`` uses when splitting).

    Returns:
        ``(threshold, conditional_entropy)`` for the candidate with the lowest
        conditional entropy (the smallest threshold wins ties), or ``None``
        when the column has fewer than two distinct values.
    """
    data_count = len(data)
    values = [float(row[axis]) for row in data]
    distinct = sorted(set(values))
    best = None
    for low, high in zip(distinct, distinct[1:]):
        t = (low + high) / 2
        left = [row for row, v in zip(data, values) if v <= t]
        right = [row for row, v in zip(data, values) if v > t]
        if not left or not right:
            # Only reachable through float rounding of the midpoint.
            continue
        v_ent = (len(left) / data_count * Infor_Ent(left)
                 + len(right) / data_count * Infor_Ent(right))
        if best is None or v_ent < best[1]:
            best = (t, v_ent)
    return best


def Infor_Gain_Num(data, axis):
    """Return the best threshold and conditional entropy of a continuous attribute.

    The values are stored as strings in the data and are converted to float.

    Args:
        data: sequence of sample rows (no header).
        axis: index of the continuous attribute column.

    Returns:
        ``(threshold, conditional_entropy)``.  When no split is possible
        (fewer than two distinct values) the result is
        ``(NO_THRESHOLD, Infor_Ent(data))``, i.e. a split with zero gain.
    """
    best = _best_threshold(data, axis)
    if best is None:
        return NO_THRESHOLD, Infor_Ent(data)
    return best


def bestFeattosplit(data):
    """Select the attribute with the highest information gain.

    Args:
        data: table whose first row is the header and whose remaining rows
            are samples.

    Returns:
        ``(gain, attribute_name, threshold, column_index)``.  ``threshold`` is
        ``NO_THRESHOLD`` (-1) for a discrete attribute.  When no attribute can
        be used the result is ``(-1, '', -1, -1)``.
    """
    Featlabel = data[0][:-1]
    samples = data[1:]
    ent = Infor_Ent(samples)

    # Decide the attribute type from its name instead of assuming that all
    # discrete columns precede the continuous ones.
    discrete = [i for i, name in enumerate(Featlabel)
                if name not in CONTINUOUS_FEATURES]
    continuous = [i for i, name in enumerate(Featlabel)
                  if name in CONTINUOUS_FEATURES]

    bestLabel = ''
    bestLabel_i = -1
    bestInfoGain = -1
    best_T = NO_THRESHOLD

    # Discrete attributes are evaluated first so that they win ties.
    for i in discrete:
        Gain_infor = ent - Infor_Gain_Label(samples, i)
        if Gain_infor > bestInfoGain:
            bestInfoGain = Gain_infor
            bestLabel = Featlabel[i]
            bestLabel_i = i
            best_T = NO_THRESHOLD

    for i in continuous:
        candidate = _best_threshold(samples, i)
        if candidate is None:
            # A constant column cannot split the samples.
            continue
        T, v_ent = candidate
        Gain_infor = ent - v_ent
        if Gain_infor > bestInfoGain:
            bestInfoGain = Gain_infor
            bestLabel = Featlabel[i]
            bestLabel_i = i
            best_T = T

    return bestInfoGain, bestLabel, best_T, bestLabel_i


def _majority_class(classList):
    """Return the most frequent label (first seen wins ties)."""
    return Counter(classList).most_common(1)[0][0]


def _same_attribute_values(samples):
    """Return True when every sample has identical attribute values.

    Also True when there are no attribute columns left at all.
    """
    first = list(samples[0][:-1])
    return all(list(row[:-1]) == first for row in samples)


def creatTree(data):
    """Recursively build the decision tree as nested dictionaries.

    Args:
        data: table whose first row is the header and whose remaining rows
            are samples.

    Returns:
        A class label for a leaf, otherwise ``{attribute: {branch: subtree}}``.
        Branch keys are the attribute values for discrete attributes and
        ``'<T'`` (value <= T) / ``'>T'`` (value > T) for continuous ones.

    Raises:
        ValueError: if ``data`` contains no sample rows.
    """
    header = data[0]
    samples = data[1:]
    if len(samples) == 0:
        raise ValueError('creatTree needs at least one sample row')

    classList = [row[-1] for row in samples]
    # All samples belong to the same class: return it as a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # No attribute left, or the samples cannot be told apart any more:
    # fall back to the majority class instead of splitting on nothing.
    if _same_attribute_values(samples):
        return _majority_class(classList)

    InfoGain, best_Label, best_T, best_i = bestFeattosplit(data)
    if best_i == -1:
        return _majority_class(classList)

    tree = {best_Label: {}}
    print(best_Label)
    if best_Label in CONTINUOUS_FEATURES:
        # Bi-partition on the threshold; the used column is removed from
        # both subsets before recursing.
        subdata_gl = [header] + [row for row in samples
                                 if float(row[best_i]) <= best_T]
        subdata_gt = [header] + [row for row in samples
                                 if float(row[best_i]) > best_T]
        subdata_gl = np.delete(subdata_gl, best_i, axis=1)
        tree[best_Label]['<' + str(best_T)] = creatTree(subdata_gl)
        subdata_gt = np.delete(subdata_gt, best_i, axis=1)
        tree[best_Label]['>' + str(best_T)] = creatTree(subdata_gt)
    else:
        # One branch per observed value; the used column is removed.
        groups = {}
        for row in samples:
            groups.setdefault(row[best_i], []).append(row)
        for value, rows in groups.items():
            subdata = np.delete([header] + rows, best_i, axis=1)
            tree[best_Label][value] = creatTree(subdata)
    return tree


def loadData(path='watermelon3.0.txt', encoding='utf-8'):
    """Load the tab-separated data file into a 2-D string array.

    The first column of every line (the sample number) is dropped and blank
    lines are skipped.

    Args:
        path: location of the data file.
        encoding: text encoding of the data file.
    """
    data = []
    with open(path, encoding=encoding) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            data.append(line.split('\t')[1:])
    return np.array(data)


if __name__ == '__main__':
    data = loadData()
    tree = creatTree(data)
中间一些部分参考了《机器学习实战》中决策树这一章节的相关代码。
该代码个人觉得有些问题,希望大家多多指正
数据如下:
编号	色泽	根蒂	敲声	纹理	脐部	触感	密度	含糖量	好瓜
1	青绿	蜷缩	浊响	清晰	凹陷	硬滑	0.697	0.46	是
2	乌黑	蜷缩	沉闷	清晰	凹陷	硬滑	0.774	0.376	是
3	乌黑	蜷缩	浊响	清晰	凹陷	硬滑	0.634	0.264	是
4	青绿	蜷缩	沉闷	清晰	凹陷	硬滑	0.608	0.318	是
5	浅白	蜷缩	浊响	清晰	凹陷	硬滑	0.556	0.215	是
6	青绿	稍蜷	浊响	清晰	稍凹	软粘	0.403	0.237	是
7	乌黑	稍蜷	浊响	稍糊	稍凹	软粘	0.481	0.149	是
8	乌黑	稍蜷	浊响	清晰	稍凹	硬滑	0.437	0.211	是
9	乌黑	稍蜷	沉闷	稍糊	稍凹	硬滑	0.666	0.091	否
10	青绿	硬挺	清脆	清晰	平坦	软粘	0.243	0.267	否
11	浅白	硬挺	清脆	模糊	平坦	硬滑	0.245	0.057	否
12	浅白	蜷缩	浊响	模糊	平坦	软粘	0.343	0.099	否
13	青绿	稍蜷	浊响	稍糊	凹陷	硬滑	0.639	0.161	否
14	浅白	稍蜷	沉闷	稍糊	凹陷	硬滑	0.657	0.198	否
15	乌黑	稍蜷	浊响	清晰	稍凹	软粘	0.36	0.37	否
16	浅白	蜷缩	浊响	模糊	平坦	硬滑	0.593	0.042	否
17	青绿	蜷缩	沉闷	稍糊	稍凹	硬滑	0.719	0.103	否



1 0
原创粉丝点击