# -*- coding: utf-8 -*-
"""
Created on Wed Dec 28 09:33:11 2016

@author: ZQ

Build a decision tree by information gain.

Data layout used throughout this module: a 2-D table whose first row is the
header (feature names followed by the class-column name) and whose remaining
rows are samples.  The last column of every sample is its class label.
Columns whose header name appears in ``CONTINUOUS_FEATURES`` hold numbers
stored as strings and are split by a threshold; every other column is treated
as a discrete attribute and split per distinct value.
"""
from collections import Counter, defaultdict

import numpy as np

# Header names of the columns that are handled as continuous values.
CONTINUOUS_FEATURES = ('密度', '含糖量')


def Infor_Ent(data):
    """Return the base-2 information entropy of the class labels in *data*.

    Args:
        data: sequence of sample rows (no header); the last element of each
            row is the class label.

    Returns:
        The entropy as a float; ``0.0`` for an empty sequence.
    """
    data_count = len(data)
    if data_count == 0:
        return 0.0
    labelcounts = Counter(featvec[-1] for featvec in data)
    ent = 0.0
    for count in labelcounts.values():
        prob = count / data_count
        ent -= prob * np.log2(prob)
    return ent


def Infor_Gain_Label(data, axis):
    """Return the weighted entropy left after splitting on a discrete column.

    This is the term subtracted from the parent entropy when computing the
    information gain of the attribute in column *axis*.

    Args:
        data: sequence of sample rows (no header).
        axis: index of the discrete column to split on.

    Returns:
        Sum over the distinct values of column *axis* of
        ``(subset size / total size) * Infor_Ent(subset)``; ``0.0`` when
        *data* is empty.
    """
    data_count = len(data)
    if data_count == 0:
        return 0.0
    # Group the rows by attribute value in a single pass.
    groups = defaultdict(list)
    for featvec in data:
        groups[featvec[axis]].append(featvec)
    return sum(len(rows) / data_count * Infor_Ent(rows)
               for rows in groups.values())


def Infor_Gain_Num(data, axis):
    """Find the best binary threshold for a continuous column.

    The column values are stored as strings, so they are converted with
    ``float`` first.  Candidate thresholds are the midpoints between
    neighbouring *sorted distinct* values; rows with ``value <= t`` go to the
    lower side and rows with ``value > t`` to the upper side (the same rule
    ``creatTree`` uses when it actually splits).

    Args:
        data: sequence of sample rows (no header).
        axis: index of the continuous column.

    Returns:
        ``(threshold, weighted_entropy)`` for the candidate with the lowest
        weighted entropy (the first such candidate in ascending order wins
        ties).  When no threshold separates the rows into two non-empty
        parts (fewer than two distinct values), returns
        ``(None, Infor_Ent(data))``.
    """
    data_count = len(data)
    values = [float(featvec[axis]) for featvec in data]
    # Sorting is essential: midpoints of values that are merely adjacent in
    # row order do not enumerate all possible cut points.
    distinct = sorted(set(values))

    best_T = None
    best_ent = float('inf')
    for low, high in zip(distinct, distinct[1:]):
        t = (low + high) / 2
        f_le = [row for row, v in zip(data, values) if v <= t]
        f_gt = [row for row, v in zip(data, values) if v > t]
        if not f_le or not f_gt:
            # Only possible if the midpoint rounds onto an endpoint.
            continue
        v_ent = (len(f_le) * Infor_Ent(f_le)
                 + len(f_gt) * Infor_Ent(f_gt)) / data_count
        if v_ent < best_ent:
            best_T = t
            best_ent = v_ent

    if best_T is None:
        return None, Infor_Ent(data)
    return best_T, best_ent


def bestFeattosplit(data):
    """Pick the attribute with the highest information gain.

    Args:
        data: table whose first row is the header and whose other rows are
            samples (class label in the last column).

    Returns:
        ``(best_gain, best_label, best_T, best_index)``.  ``best_T`` is the
        chosen threshold for a continuous attribute and ``-1`` for a discrete
        one.  When no attribute can be used (no feature columns, or only
        continuous columns without a valid threshold) the result is
        ``(-1, '', -1, -1)``.
    """
    Featlabel = data[0][:-1]
    rows = data[1:]
    ent = Infor_Ent(rows)

    bestLabel = ''
    bestLabel_i = -1
    bestInfoGain = -1
    best_T = -1
    # Decide per column whether it is continuous, so the result does not
    # depend on discrete columns preceding continuous ones.
    for i, label in enumerate(Featlabel):
        if label in CONTINUOUS_FEATURES:
            T, v_ent = Infor_Gain_Num(rows, i)
            if T is None:
                continue  # this column cannot split the rows
        else:
            T = -1
            v_ent = Infor_Gain_Label(rows, i)
        Gain_infor = ent - v_ent
        if Gain_infor > bestInfoGain:
            bestInfoGain = Gain_infor
            bestLabel = label
            best_T = T
            bestLabel_i = i
    return bestInfoGain, bestLabel, best_T, bestLabel_i


def _majority_class(classList):
    """Return the most frequent label (first seen wins ties)."""
    return Counter(classList).most_common(1)[0][0]


def creatTree(data):
    """Recursively build the decision tree for *data*.

    Args:
        data: table whose first row is the header and whose other rows are
            samples (class label in the last column).

    Returns:
        A class label when the node is a leaf, otherwise a nested dict
        ``{attribute: {branch: subtree}}``.  Branch keys are the attribute
        values for a discrete attribute, and ``'<' + str(T)`` /
        ``'>' + str(T)`` for a continuous attribute with threshold ``T``
        (the ``'<'`` branch holds rows with ``value <= T``).

    Raises:
        ValueError: if *data* contains no sample rows.

    Side effects:
        Prints the name of the attribute chosen at every internal node.
    """
    header = data[0]
    rows = data[1:]
    classList = [f[-1] for f in rows]
    if not classList:
        raise ValueError('creatTree needs at least one sample row '
                         'below the header')
    # All samples share one class: this node is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    InfoGain, best_Label, best_T, best_i = bestFeattosplit(data)
    # No attribute left to split on: fall back to the majority class
    # instead of deleting the class column and crashing.
    if best_i == -1:
        return _majority_class(classList)

    tree = {best_Label: {}}
    print(best_Label)
    if best_Label in CONTINUOUS_FEATURES:
        # Same partition rule as Infor_Gain_Num: <= threshold vs > threshold.
        subdata_gl = [header] + [f for f in rows
                                 if float(f[best_i]) <= best_T]
        subdata_gt = [header] + [f for f in rows
                                 if float(f[best_i]) > best_T]
        # The used attribute is removed before recursing.
        subdata_gl = np.delete(subdata_gl, best_i, axis=1)
        tree[best_Label]['<' + str(best_T)] = creatTree(subdata_gl)
        subdata_gt = np.delete(subdata_gt, best_i, axis=1)
        tree[best_Label]['>' + str(best_T)] = creatTree(subdata_gt)
    else:
        # One branch per distinct value; the used attribute is removed.
        groups = defaultdict(list)
        for f in rows:
            groups[f[best_i]].append(f)
        for value, members in groups.items():
            subdata = np.delete([header] + members, best_i, axis=1)
            tree[best_Label][value] = creatTree(subdata)
    return tree


def loadData(path='watermelon3.0.txt', encoding=None):
    """Load a tab-separated table and drop the first field of every line.

    Args:
        path: file to read; defaults to ``'watermelon3.0.txt'``.
        encoding: text encoding passed to ``open``; ``None`` keeps the
            platform default.

    Returns:
        ``numpy.ndarray`` of strings, one row per non-blank line, containing
        every tab-separated field except the first.
    """
    data = []
    with open(path, encoding=encoding) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # blank lines would make the array ragged
            data.append(line.split('\t')[1:])
    return np.array(data)


if __name__ == '__main__':
    data = loadData()
    tree = creatTree(data)
中间一些部分参考了《机器学习实战》中决策树这一章节的相关代码。
个人觉得该代码还存在一些问题,希望大家多多指正。
数据如下:
编号	色泽	根蒂	敲声	纹理	脐部	触感	密度	含糖量	好瓜
1	青绿	蜷缩	浊响	清晰	凹陷	硬滑	0.697	0.46	是
2	乌黑	蜷缩	沉闷	清晰	凹陷	硬滑	0.774	0.376	是
3	乌黑	蜷缩	浊响	清晰	凹陷	硬滑	0.634	0.264	是
4	青绿	蜷缩	沉闷	清晰	凹陷	硬滑	0.608	0.318	是
5	浅白	蜷缩	浊响	清晰	凹陷	硬滑	0.556	0.215	是
6	青绿	稍蜷	浊响	清晰	稍凹	软粘	0.403	0.237	是
7	乌黑	稍蜷	浊响	稍糊	稍凹	软粘	0.481	0.149	是
8	乌黑	稍蜷	浊响	清晰	稍凹	硬滑	0.437	0.211	是
9	乌黑	稍蜷	沉闷	稍糊	稍凹	硬滑	0.666	0.091	否
10	青绿	硬挺	清脆	清晰	平坦	软粘	0.243	0.267	否
11	浅白	硬挺	清脆	模糊	平坦	硬滑	0.245	0.057	否
12	浅白	蜷缩	浊响	模糊	平坦	软粘	0.343	0.099	否
13	青绿	稍蜷	浊响	稍糊	凹陷	硬滑	0.639	0.161	否
14	浅白	稍蜷	沉闷	稍糊	凹陷	硬滑	0.657	0.198	否
15	乌黑	稍蜷	浊响	清晰	稍凹	软粘	0.36	0.37	否
16	浅白	蜷缩	浊响	模糊	平坦	硬滑	0.593	0.042	否
17	青绿	蜷缩	沉闷	稍糊	稍凹	硬滑	0.719	0.103	否
1 0