Weka Development [19]: NaiveBayes Source Code Analysis

    I originally did not want to write this NaiveBayes source analysis myself, but I could not talk anyone else into doing it, so here it is. Please read the paper "Estimating Continuous Distributions in Bayesian Classifiers" first. Readers who are new to the topic should start with Tom Mitchell's chapter "Generative and Discriminative Classifiers: Naive Bayes and Logistic Regression".

Let's look directly at the buildClassifier function. Below I list only the code I consider important:

if (m_UseDiscretization) {
    m_Disc = new weka.filters.supervised.attribute.Discretize();
    m_Disc.setInputFormat(m_Instances);
    m_Instances = weka.filters.Filter.useFilter(m_Instances, m_Disc);
} else {
    m_Disc = null;
}

If discretization is requested, it is performed here. People have asked me how to discretize; the code above is the answer.
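
For anyone who wants to discretize their own data the same way, here is a minimal standalone sketch. The file name train.arff is hypothetical; the filter is the same one NaiveBayes uses internally:

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.supervised.attribute.Discretize;

public class DiscretizeDemo {
    public static void main(String[] args) throws Exception {
        Instances train = DataSource.read("train.arff");      // hypothetical input file
        train.setClassIndex(train.numAttributes() - 1);       // supervised filter needs the class

        Discretize disc = new Discretize();
        disc.setInputFormat(train);                           // learn the cut points from the data
        Instances discretized = Filter.useFilter(train, disc);
        System.out.println(discretized.toSummaryString());
    }
}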

// Reserve space for the distributions
m_Distributions = new Estimator[m_Instances.numAttributes() - 1]
                               [m_Instances.numClasses()];
m_ClassDistribution = new DiscreteEstimator(m_Instances.numClasses(), true);


          m_ClassDistribution is the class prior P(C); m_Distributions holds the conditional distributions P(X|C), one estimator per (attribute, class) pair.
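
Written as a formula, this is the standard naive Bayes decomposition (nothing here is specific to Weka):

    P(C = c \mid x_1, \dots, x_n) \;\propto\; P(C = c) \prod_{i=1}^{n} P(x_i \mid C = c)

m_ClassDistribution estimates the prior P(C = c), and m_Distributions[i][c] estimates each factor P(x_i | C = c).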

int attIndex = 0;
Enumeration enu = m_Instances.enumerateAttributes();
while (enu.hasMoreElements()) {
    Attribute attribute = (Attribute) enu.nextElement();
    // ...
    attIndex++;
}


     This loop processes each attribute in turn; the interesting parts of the loop body are shown piece by piece below.

// If the attribute is numeric, determine the estimator
// numeric precision from differences between adjacent values
double numPrecision = DEFAULT_NUM_PRECISION;
if (attribute.type() == Attribute.NUMERIC) {
    m_Instances.sort(attribute);
    if ((m_Instances.numInstances() > 0)
            && !m_Instances.instance(0).isMissing(attribute)) {
        double lastVal = m_Instances.instance(0).value(attribute);
        double currentVal, deltaSum = 0;
        int distinct = 0;
        for (int i = 1; i < m_Instances.numInstances(); i++) {
            Instance currentInst = m_Instances.instance(i);
            if (currentInst.isMissing(attribute)) {
                break;
            }
            currentVal = currentInst.value(attribute);
            if (currentVal != lastVal) {
                deltaSum += currentVal - lastVal;
                lastVal = currentVal;
                distinct++;
            }
        }
        if (distinct > 0) {
            numPrecision = deltaSum / distinct;
        }
    }
}


    What surprised me about this big block is that it exists only to determine a precision (it is there so the estimators can round values consistently, which in turn supports incremental learning). Precision here does not mean "keep so many decimal places"; it is an actual value: the average gap between adjacent distinct values of the attribute. Walking through the code: m_Instances is first sorted on the attribute; Weka's sort places instances with a missing value for that attribute at the end, so if even the first instance is missing, every value of this attribute is missing and there is nothing to compute (and the break inside the loop stops at the first missing value for the same reason). Otherwise, each instance's value currentVal on the current attribute is compared with the previous one; whenever it differs, the difference is added to deltaSum and distinct is incremented. Finally numPrecision is deltaSum divided by the number of distinct gaps. Since the differences telescope, this equals (largest value - smallest value) / distinct.
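
A tiny worked example with made-up numbers: for sorted values 1.0, 1.0, 1.4, 2.2 the distinct jumps are 0.4 and 0.8, so deltaSum = 1.2, distinct = 2, and numPrecision = 0.6. The same loop in isolation (a standalone sketch, not Weka API; 0.01 is the value of DEFAULT_NUM_PRECISION in the Weka source):

public class PrecisionDemo {
    public static void main(String[] args) {
        // Toy attribute values, already sorted, no missing values.
        double[] sorted = {1.0, 1.0, 1.4, 2.2};
        double lastVal = sorted[0], deltaSum = 0;
        int distinct = 0;
        for (int i = 1; i < sorted.length; i++) {
            if (sorted[i] != lastVal) {       // only count changes in value
                deltaSum += sorted[i] - lastVal;
                lastVal = sorted[i];
                distinct++;
            }
        }
        double numPrecision = (distinct > 0) ? deltaSum / distinct : 0.01;
        System.out.println(numPrecision);     // prints ~0.6
    }
}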

for (int j = 0; j < m_Instances.numClasses(); j++) {
    switch (attribute.type()) {
        case Attribute.NUMERIC:
            if (m_UseKernelEstimator) {
                m_Distributions[attIndex][j] = new KernelEstimator(numPrecision);
            } else {
                m_Distributions[attIndex][j] = new NormalEstimator(numPrecision);
            }
            break;
        case Attribute.NOMINAL:
            m_Distributions[attIndex][j] =
                new DiscreteEstimator(attribute.numValues(), true);
            break;
        default:
            throw new Exception("Attribute type unknown to NaiveBayes");
    }
}

        This code looks a little odd at first, but it just switches on the type of the current attribute. If it is NUMERIC, i.e. continuous, you can choose either a KernelEstimator or a NormalEstimator, both constructed with numPrecision. The difference is explained clearly in the paper: the NormalEstimator fits a single Gaussian from the sample mean and variance, while the KernelEstimator places a Gaussian kernel on each distinct observed value. If the attribute is NOMINAL, i.e. discrete, a DiscreteEstimator is used.

// Compute counts
Enumeration enumInsts = m_Instances.enumerateInstances();
while (enumInsts.hasMoreElements()) {
    Instance instance = (Instance) enumInsts.nextElement();
    updateClassifier(instance);
}


    Finally, some code that does real work: every training instance is passed to updateClassifier, which updates the classifier's statistics from a single instance. This is also why naive Bayes can be trained incrementally, as you surely know.

public void updateClassifier(Instance instance) throws Exception {
    if (!instance.classIsMissing()) {
        Enumeration enumAtts = m_Instances.enumerateAttributes();
        int attIndex = 0;
        while (enumAtts.hasMoreElements()) {
            Attribute attribute = (Attribute) enumAtts.nextElement();
            if (!instance.isMissing(attribute)) {
                m_Distributions[attIndex][(int) instance.classValue()]
                    .addValue(instance.value(attribute), instance.weight());
            }
            attIndex++;
        }
        m_ClassDistribution.addValue(instance.classValue(), instance.weight());
    }
}


   

    This is where the counting happens. The first index of m_Distributions is the attribute index, the second is the class value. The most important call is addValue, which updates the estimator for that (attribute, class) pair with the instance's attribute value and weight. Finally, m_ClassDistribution counts the class value itself.
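
Because all the learning happens in updateClassifier, Weka also exposes an incremental variant, NaiveBayesUpdateable, which inherits this method. A minimal sketch of incremental training, assuming a file data.arff whose last attribute is the class:

import java.io.File;
import weka.classifiers.bayes.NaiveBayesUpdateable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;

public class IncrementalDemo {
    public static void main(String[] args) throws Exception {
        ArffLoader loader = new ArffLoader();
        loader.setFile(new File("data.arff"));           // hypothetical file
        Instances structure = loader.getStructure();     // header only, no rows yet
        structure.setClassIndex(structure.numAttributes() - 1);

        NaiveBayesUpdateable nb = new NaiveBayesUpdateable();
        nb.buildClassifier(structure);                   // sets up the estimators

        Instance current;
        while ((current = loader.getNextInstance(structure)) != null) {
            nb.updateClassifier(current);                // one instance at a time
        }
    }
}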

Now let's look at the simplest estimator, the DiscreteEstimator, starting with its two constructors:

public DiscreteEstimator(int numSymbols, boolean laplace) {
    m_Counts = new double[numSymbols];
    m_SumOfCounts = 0;
    if (laplace) {
        for (int i = 0; i < numSymbols; i++) {
            m_Counts[i] = 1;
        }
        m_SumOfCounts = (double) numSymbols;
    }
}

public DiscreteEstimator(int nSymbols, double fPrior) {
    m_Counts = new double[nSymbols];
    for (int iSymbol = 0; iSymbol < nSymbols; iSymbol++) {
        m_Counts[iSymbol] = fPrior;
    }
    m_SumOfCounts = fPrior * (double) nSymbols;
}


 

There is not much difference between the two: the first uses Laplace smoothing, starting every count at 1 (or at 0 when laplace is false); the second initializes every count to an arbitrary prior fPrior, so Laplace smoothing is just the special case fPrior = 1.

public void addValue(double data, double weight) {
    m_Counts[(int) data] += weight;
    m_SumOfCounts += weight;
}


 

      The discrete addValue is very simple: it adds the instance's weight to the count of the corresponding attribute value and to the running total.
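
To see the Laplace counts in action, here is a quick standalone sketch using the estimator directly:

import weka.estimators.DiscreteEstimator;

public class DiscreteDemo {
    public static void main(String[] args) {
        // 3 symbols with Laplace smoothing: counts start at {1, 1, 1}, sum 3.
        DiscreteEstimator est = new DiscreteEstimator(3, true);
        est.addValue(0, 1.0);   // observe symbol 0 with weight 1
        est.addValue(0, 1.0);
        est.addValue(2, 1.0);
        // Counts are now {3, 1, 2}, sum 6.
        System.out.println(est.getProbability(0)); // 3/6 = 0.5
        System.out.println(est.getProbability(1)); // 1/6, never zero thanks to Laplace
    }
}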

      Next, the NormalEstimator constructor:

public NormalEstimator(double precision) {
    m_Precision = precision;
    // Allow at most 3 sd's within one interval
    m_StandardDev = m_Precision / (2 * 3);
}


        The precision was already explained, but to repeat: it is not "accurate to so many decimal places", it is a value. As for the 2 and the 3, my reading (correct me if I am wrong) is that an interval of width precision is split into two half-intervals of precision / 2, and at most 3 standard deviations should fit within a half-interval; that is what the comment "allow at most 3 sd's within one interval" means.
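
In symbols, under that interpretation, with p the precision:

    \sigma_0 = \frac{p/2}{3} = \frac{p}{6}

which is exactly m_Precision / (2 * 3).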

public void addValue(double data, double weight) {
    if (weight == 0) {
        return;
    }
    data = round(data);
    m_SumOfWeights += weight;
    m_SumOfValues += data * weight;
    m_SumOfValuesSq += data * data * weight;

    if (m_SumOfWeights > 0) {
        m_Mean = m_SumOfValues / m_SumOfWeights;
        double stdDev = Math.sqrt(Math.abs(
            m_SumOfValuesSq - m_Mean * m_SumOfValues) / m_SumOfWeights);
        // If the stdDev ~= 0, we really have no idea of scale yet,
        // so stick with the default. Otherwise...
        if (stdDev > 1e-10) {
            m_StandardDev = Math.max(m_Precision / (2 * 3),
                // allow at most 3 sd's within one interval
                stdDev);
        }
    }
}


 

        There is not much to explain in this routine: it maintains running weighted sums from which the mean and standard deviation are recomputed after every update. If you are interested, search Wikipedia for the article "Algorithms for calculating variance"; the weighted incremental algorithm section there may make the idea clearer.
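
For reference, the running sums give the usual weighted mean and (biased) variance:

    \mu = \frac{\sum_i w_i x_i}{\sum_i w_i}, \qquad
    \sigma^2 = \frac{\sum_i w_i x_i^2}{\sum_i w_i} - \mu^2
             = \frac{\sum_i w_i x_i^2 - \mu \sum_i w_i x_i}{\sum_i w_i}

The last form is exactly the expression (m_SumOfValuesSq - m_Mean * m_SumOfValues) / m_SumOfWeights under the square root; the Math.abs guards against tiny negative results caused by floating-point rounding.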

        Now let's look at the distributionForInstance function.

public double[] distributionForInstance(Instance instance) {
    double[] probs = new double[m_NumClasses];
    for (int j = 0; j < m_NumClasses; j++) {
        probs[j] = m_ClassDistribution.getProbability(j);
    }
    Enumeration enumAtts = instance.enumerateAttributes();
    int attIndex = 0;
    while (enumAtts.hasMoreElements()) {
        Attribute attribute = (Attribute) enumAtts.nextElement();
        if (!instance.isMissing(attribute)) {
            double temp, max = 0;
            for (int j = 0; j < m_NumClasses; j++) {
                temp = Math.max(1e-75, Math.pow(
                    m_Distributions[attIndex][j]
                        .getProbability(instance.value(attribute)),
                    m_Instances.attribute(attIndex).weight()));
                probs[j] *= temp;
                if (probs[j] > max) {
                    max = probs[j];
                }
            }
            // (in the full Weka source, max is used here to rescale all
            // probs when they risk underflow; that code is omitted)
        }
        attIndex++;
    }

    // Display probabilities
    Utils.normalize(probs);
    return probs;
}


 

First the class priors are loaded into probs; I hope you still remember the formula. Then, for each non-missing attribute and each class, the conditional probability of the attribute value given the class is computed as temp (floored at 1e-75 to avoid numerical underflow, and raised to the power of the attribute's weight), and probs[j] *= temp multiplies it in; this is just the formula again. After normalization, the largest entry marks the most probable class.
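
Written out, what the loop computes before normalization is, with w_i the attribute weights and the product running over the non-missing attributes:

    \text{probs}[c] = P(C = c) \prod_{i} \max\!\left(10^{-75},\; P(x_i \mid C = c)^{w_i}\right)

Utils.normalize then divides by the sum over classes, turning these scores into a posterior distribution.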

         The getProbability functions of DiscreteEstimator and NormalEstimator are, respectively:

// DiscreteEstimator
public double getProbability(double data) {
    if (m_SumOfCounts == 0) {
        return 0;
    }
    return (double) m_Counts[(int) data] / m_SumOfCounts;
}

// NormalEstimator
public double getProbability(double data) {
    data = round(data);
    double zLower = (data - m_Mean - (m_Precision / 2)) / m_StandardDev;
    double zUpper = (data - m_Mean + (m_Precision / 2)) / m_StandardDev;

    double pLower = Statistics.normalProbability(zLower);
    double pUpper = Statistics.normalProbability(zUpper);
    return pUpper - pLower;
}


The first needs no explanation; it simply returns the relative count. In the second, the +/- (m_Precision / 2) widens the value into the smallest and largest values it could stand for at the given precision, and the function returns the probability mass the fitted Gaussian assigns to that interval.
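
As a formula, with \Phi the standard normal CDF (which is what Statistics.normalProbability computes), \mu and \sigma the estimator's mean and standard deviation, and p the precision:

    P(x) = \Phi\!\left(\frac{x - \mu + p/2}{\sigma}\right) - \Phi\!\left(\frac{x - \mu - p/2}{\sigma}\right)

i.e. the probability mass of the fitted Gaussian over the precision-wide interval centered at x.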

If I had studied more numerical methods and probability as an undergraduate, this might be less painful today.

 
