Machine Learning 学习之决策树 ID3树

来源:互联网 发布:淘宝stttyle 编辑:程序博客网 时间:2024/06/06 02:52
#coding=utf-8#每个类先事先划分属性 数据结构定义为[[[类1,类2,...],结果],...(样本n)]#ID3算法 比较最大增益 增益越大 说明该分类器商相对越小#,分类越显著(概率比越大,类别对结果的影响越大)"""上面为了简便,将特征属性离散化了,其实日志密度和好友密度都是连续的属性。对于特征属性为连续值,可以如此使用ID3算法:先将D中元素按照特征属性排序,则每两个相邻元素的中间点可以看做潜在分裂点,从第一个潜在分裂点开始,分裂D并计算两个集合的期望信息,具有最小期望信息的点称为这个属性的最佳分裂点,其信息期望作为此属性的信息期望。"""import math#import drawTclass ID3:      def __init__(self,data,labels,inpu):            self.data=data            self.labels=labels            self.inpu=inpu            #self.tree=self.getTree(data,labels)            #self.result=self.getR()            self.tree=self.getTree(data,labels,'begin')            self.getR(self.tree)            print self.result      def getTree(self,dataSet,labels,st):            re=[dataSet[i][1] for i in range(len(dataSet))]                 #1、考虑没有分类的情况len(set(re))==1 代表完成分类 无须再进行递归操作            #2、考虑分到底的情况(标签为0)len(labels)==0 这表明不知道该属于哪一类            if len(set(re))==1:                  return list(set(re))[0]            elif len(labels)==1:                  #只剩下一个标签则按当前标签分类                  dicc={}                  #print dataSet,re                  dd=[dataSet[j][0][0] for j in range(len(dataSet))]                  res=list(set(re))                  for i in range(len(list(set(dd)))):                        for k in range(len(res)):                              ma=-99                              index=0                              a=dataSet.count([[list(set(dd))[i]],res[k]])                              if a>ma:                                    ma=a                                    index=res[k]                            dicc[list(set(dd))[i]]=index                  return dicc            else:                  dat=[[dataSet[j][0][i] for j in range(len(dataSet))]\                       for i in range(len(dataSet[0][0]))]                  depth=len(dataSet[0][0])                  s0=self.calS(0,re)                  ma=-99                  index=0                  for k in range(len(dataSet[0][0])):                        s=self.calS(dat[k],re)                        s=s0-s                        if s>ma:                              ma=s                              index=k                  best=labels[index]                  tree={labels[index]:{}}                  nex=list(set(dat[index]))                  labels=labels[0:index]+labels[index+1:len(labels)]                  for i in range(len(nex)):                        data=[]                        for j in range(len(dat[index])):                              if dat[index][j]==nex[i]:                                    li=[dataSet[j][0],re[j]]                                    data.append(li)                        data=[[data[k][0][0:index]+data[k][0][index+1:len(data)],\                               data[k][1]] for k in range(len(data))]                        st=nex[i]                        subTree=self.getTree(data,labels,st)                        tree[best][nex[i]]=subTree                  #去掉friends 的列 并拆分 重新计算            return tree      def getR(self,tree):            #self.inpu            #print tree.keys()[0]            idx=self.inpu[self.getIndex(tree.keys()[0])]            #result=0            try:                  tree=tree[tree.keys()[0]][idx]                  if type(tree)==type({0:1}):                        self.getR(tree)                  else:                        self.result=tree            except:                  tree=tree[tree.keys()[0]]                  self.result=tree      def getIndex(self,ll):            for i in range(len(self.labels)):                  if ll==self.labels[i]:                        return i      def calS(self,liA,liB):            S=0            if liA==0:                  res=list(set(liB))                  for i in range(len(res)):                        pi=1.0*liB.count(res[i])/len(liB)                        #print liB.count(res[i]),pi,math.log(pi,2)                        if pi!=0:                              S+=-1.0*pi*math.log(pi,2)                   return S            else:                  #计算关联熵                  res=list(set(liB))                  liAS=list(set(liA))                  inp=[[] for i in range(len(liAS))]                  for i in range(len(liA)):                        for j in range(len(liAS)):                              if liA[i]==liAS[j]:                                    inp[j].append(i)                              for i in range(len(inp)):                        p1=1.0*len(inp[i])/len(liA)                        ansB=[liB[inp[i][j]] for j in range(len(inp[i]))]                        #print ansB,p1                        for k in range(len(res)):                              pi=1.0*ansB.count(res[k])/len(ansB)                              if pi!=0:                                    S+=-1.0*p1*pi*math.log(pi,2)                  return S#训练的数据类型如下data=[            [['sd','ss','np'],'n'],            [['sd','ls','yp'],'y'],            [['ld','ms','yp'],'y'],            [['md','ms','yp'],'y'],            [['ld','ms','yp'],'y'],            [['md','ls','np'],'y'],            [['md','ss','np'],'n'],            [['ld','ms','np'],'y'],            [['md','ss','np'],'y'],            [['sd','ss','yp'],'n']      ]inpu=['sd','ls','yp']labels=['daily','friends','photo']id3=ID3(data,labels,inpu)sx=id3.treeprint sx#print 0.881+0.4*(0.75*math.log(0.75,2)+0.25*math.log(0.25,2))
原创粉丝点击