weka之ID3

来源:互联网 发布:卡通农场数据恢复 编辑:程序博客网 时间:2024/04/30 00:46
@Override    public void buildClassifier(Instances data) throws Exception     {        //检验算法能否直接处理数据        getCapabilities().testWithFail(data);        //删除带有缺失class标记的数据        data=new Instances(data);        data.deleteWithMissingClass();        makeTree(data);    }
private void makeTree(Instances data) throws Exception     {        //如果没有是数据集到达这个节点,返回叶子节点        if(data.numInstances()==0)        {            //当前分裂属性为空            m_Attribute=null;            //当前类别属性为空            m_ClassAttribute=null;            //分类比例            m_Distribution=new double[data.numClasses()];            return;        }        //保存每个属性的信息增益        double infoGains[]=new double[data.numAttributes()];        Enumeration attEnum=data.enumerateAttributes();        while (attEnum.hasMoreElements())         {            Attribute att = (Attribute) attEnum.nextElement();            infoGains[att.index()]=computeInfoGain(data, att);        }        //选取InfoGain最大的属性作为分裂属性        m_Attribute=data.attribute(Utils.maxIndex(infoGains));        System.err.println("我要打印InfoGain了");        for(int i=0;i<infoGains.length;i++)        {            System.err.print(infoGains[i]+" ");        }        System.err.println();        //当前分裂属性信息增益为0,说明是叶子节点        if(infoGains[m_Attribute.index()]==0)        {            m_Attribute=null;            m_Distribution=new double[data.numClasses()];            //下面是投票表决            Enumeration instEnum=data.enumerateInstances();            while (instEnum.hasMoreElements())             {                Instance instance = (Instance) instEnum.nextElement();                //类别取值下表对应的++                m_Distribution[(int)instance.classValue()]++;            }            //归一化成0-1之间            Utils.normalize(m_Distribution);            //类别取值            m_ClassValue=Utils.maxIndex(m_Distribution);            m_ClassAttribute=data.classAttribute();        }        else//对数据进行划分        {            Instances[] splitData=splitData(data, m_Attribute);            m_Successors=new Id3[m_Attribute.numValues()];            for(int i=0;i<m_Attribute.numValues();i++)            {                m_Successors[i]=new Id3();                m_Successors[i].makeTree(splitData[i]);            }        }    }    /*     * 计算信息增益     */    private double computeInfoGain(Instances data, Attribute att) throws Exception     {        //先计算整体的信息熵        double infoGain=computeEntropy(data);        //打印熵,用于调试        System.err.println("我要打印熵1了");        System.err.println(infoGain);        //根据属性划分数据集        Instances[] splitData=splitData(data, att);        System.err.println("下面打印出划分好的数据集");        for (int i = 0; i < splitData.length; i++)         {            System.out.println(splitData[i].numInstances());        }        for (int i = 0; i < splitData.length; i++)         {            if(splitData[i].numInstances()>0)            {                double temp1=((double)splitData[i].numInstances()/                        (double)data.numInstances());                double tempEntropy=computeEntropy(splitData[i]);                double temp2=temp1*tempEntropy;                infoGain-=temp1*temp2;                System.err.println(infoGain);            }        }        return infoGain;    }    //计算熵的函数    private double computeEntropy(Instances data)     {        double[] classCounts=new double[data.numClasses()];        Enumeration instEnum=data.enumerateInstances();        while (instEnum.hasMoreElements())         {            Instance inst = (Instance) instEnum.nextElement();            classCounts[(int)inst.classValue()]++;        }        double entropy=0;        double numInstances=data.numInstances();        for (int i = 0; i < data.numClasses(); i++)         {            if(classCounts[i]>0)            {                entropy-=((double)classCounts[i]/(double)numInstances)*Utils.log2((double)classCounts[i]/(double)numInstances);            }        }        return entropy;    }    private Instances[] splitData(Instances data, Attribute att)     {        //创建子数据集数组        Instances[] splitData=new Instances[att.numValues()];        //为数组分配空间        for (int i = 0; i < splitData.length; i++)         {            splitData[i]=new Instances(data,data.numInstances());        }        Enumeration instEnum=data.enumerateInstances();        while (instEnum.hasMoreElements())         {            Instance inst = (Instance) instEnum.nextElement();            splitData[(int)inst.value(att)].add(inst);        }        //处理空间,删除多余的没有用到的空间        for (int i = 0; i < splitData.length; i++)         {            splitData[i].compactify();        }        //返回划分好的数据集        return splitData;    }    public double[] distrbutionForInstance(Instance instance) throws Exception    {        if(instance.hasMissingValue())        {            throw new Exception("Id3"+ "算法不能处理缺失值");        }        if(m_Attribute==null)        {            return m_Distribution;        }        else        {            return m_Successors[(int)instance.value(m_Attribute)].distrbutionForInstance(instance);        }    }     public double classifyInstance(Instance instance) throws Exception      {        if(instance.hasMissingValue())        {            throw new Exception("Id3"+ "算法不能处理缺失值");        }        if(m_Attribute==null)        {            return Utils.maxIndex(m_Distribution);        }        else        {            return m_Successors[(int)instance.value(m_Attribute)].classifyInstance(instance);        }     }
原创粉丝点击