weka之NB算法

来源:互联网 发布:soa软件架构 编辑:程序博客网 时间:2024/05/17 04:51
    @Override    public void buildClassifier(Instances data) throws Exception     {        //检测分类器能否处理数据        getCapabilities().testWithFail(data);        //删除具有类别缺失值的实例        data=new Instances(data);        data.deleteWithMissingClass();        //保存类别的数量        m_NumClasses=data.numClasses();        //复制训练集        m_Instances=new Instances(data);        //如果指定,就对数据进行离散化        if(m_UseDiscretization)        {            m_Disc=new weka.filters.supervised.attribute.Discretize();            m_Disc.setInputFormat(data);            m_Instances=weka.filters.Filter.useFilter(m_Instances, m_Disc);        }        else        {            m_Disc=null;        }        //为概率分布预留空间        //类别条件概率分布P(X|Y)        m_Distributions=new Estimator[m_Instances.numAttributes()-1][m_Instances.numClasses()];        //类别分布P(Y)        m_ClassDistribution=new DiscreteEstimator(m_Instances.numClasses(), true);        int attIndex=0;        Enumeration enumeration=m_Instances.enumerateAttributes();        //循环处理每一个属性        while(enumeration.hasMoreElements())        {            Attribute attribute=(Attribute) enumeration.nextElement();            //如果属性是数值型,根据相邻值之间的差异,测定估计器数值精度            double numPrecision=DEFAULT_NUM_PRECISION;            if(attribute.type()==Attribute.NUMERIC)            {                //根据当前属性的值对数据集排序                m_Instances.sort(attribute);                //排序之后,当前属性缺失值的实例就排到最前                //这样,判断第一个样本是否有缺失值,就知道整体样本是否有缺失值                //如果有,就没有必要执行if后面的代码块                if((m_Instances.numInstances()>0) && !m_Instances.instance(0).isMissing(attribute))                {                    //lastVal为后一个实例的当前属性值                    double lastVal=m_Instances.instance(0).value(attribute);                    //currentVal,为每个实例的当前属性值,deltaSum为差值                    double currentVal,deltaSum=0;                    //distinct为当前属性取不同值的数量                    int distinct=0;                    for(int i=1;i<m_Instances.numInstances();i++)                    {                        Instance currentInst=m_Instances.instance(i);                        if(currentInst.isMissing(attribute))                        {                            break;                        }                        currentVal=currentInst.value(attribute);                        //如果当前值与最后值不相等,则相减并将差值累加到deltaSum                        if(currentVal!=lastVal)                        {                            deltaSum+=currentVal-lastVal;                            lastVal=currentVal;                            distinct++;                        }                    }                    //最终的numPrecision就是deltaSum/distinct                    if(distinct>0)                    {                        numPrecision=deltaSum/distinct;                    }                }            }            //循环处理每一个类别标签            for(int j=0;j<m_Instances.numClasses();j++)            {                //判断当前属性的类型                switch(attribute.type())                {                //如果为连续的数值型属性,根据是否使用核估计器的选项,选择构建Kernelstimator对象还是NormalEstimator对象                //两者的构造函数都是使用numPrecision作为参数                case Attribute.NUMERIC:                    if(m_UseKernelEstimator)                    {                        m_Distributions[attIndex][j]=new KernelEstimator(numPrecision);                    }                    else                    {                        m_Distributions[attIndex][j]=new NormalEstimator(numPrecision);                    }                    break;                case Attribute.NOMINAL:                    m_Distributions[attIndex][j]=new DiscreteEstimator(attribute.numValues(), true);                    break;                default:                    throw new Exception("Attribute type unkown to my NB");                }            }            attIndex++;        }        //统计每一个实例        Enumeration enumInsts=m_Instances.enumerateInstances();        while (enumInsts.hasMoreElements())         {            Instance instance=(Instance) enumInsts.nextElement();            //调用updateClassifier方法,用实例更新分离器            updateClassifier(instance);        }        //节省空间        m_Instances=new Instances(m_Instances,0);    }    public void updateClassifier(Instance instance)     {        if(!instance.classIsMissing())        {            Enumeration enumAtts=m_Instances.enumerateAttributes();            int attIndex=0;            //循环处理没一个属性            while (enumAtts.hasMoreElements())             {                Attribute attribute = (Attribute) enumAtts.nextElement();                if(!instance.isMissing(attribute))                {                    //m_Distributons第一个下标记为当亲属性下标记,第二个下标为类别值                    //统计样本实例对应类别属性值的分布                    //调用Estimator的AddValue方法将新数据值加入到当前评估器中                    m_Distributions[attIndex][(int)instance.classValue()].addValue(instance.value(attribute),                            instance.weight());                }                attIndex++;            }            //统计类别分布            m_ClassDistribution.addValue(instance.classValue(), instance.weight());        }    }    public double[] distributionForInstance(Instance instance) throws Exception    {        //如果使用useSupervisedDiscretization选项,就对实例进行离散化        if(m_UseDiscretization)        {            m_Disc.input(instance);            instance=m_Disc.output();        }        //类别的概率P(Y)        double probs[]=new double[m_NumClasses];        //循环得到每个类别的概率        for(int j=0;j<m_NumClasses;j++)        {            probs[j]=m_ClassDistribution.getProbability(j);        }        Enumeration enumAtts=instance.enumerateAttributes();        int attIndex=0;        //循环处理每个属性        while(enumAtts.hasMoreElements())        {            Attribute attribute=(Attribute) enumAtts.nextElement();            if(!instance.isMissing(attribute))            {                //temp为临时概率,max为当前最大概率                double temp,max=0;                for (int j = 0; j < m_NumClasses; j++)                {                    //计算每个类别的条件概率P(X|Y)                    temp=Math.max(1e-75, Math.pow(m_Distributions[attIndex][j].getProbability(instance.value(attribute)),                             m_Instances.attribute(attIndex).weight()));                    probs[j]*=temp;                    //更新最大概率值                    if(probs[j]>max)                    {                        max=probs[j];                    }                    if(Double.isNaN(probs[j]))                    {                        throw new Exception(                                "Nan returned from estimator for atrribute "+                                attribute.name()+":\n"+                                m_Distributions[attIndex][j].toString());                    }                }                if(max>0 && max<1e-75)                {                    //防止概率下溢的危险                    for(int j=0;j<m_NumClasses;j++)                    {                        probs[j]*=1e75;                    }                }            }            attIndex++;        }        //概率规范化        Utils.normalize(probs);        return probs;    }
0 0
原创粉丝点击