Weka开发[19]——NaiveBayes源代码分析
来源:互联网 发布:为什么选择java 编辑:程序博客网 时间:2024/06/11 04:56
直接看buildClassifier函数,下面我只把我认为重要的代码列出来:
if (m_UseDiscretization) { m_Disc = new weka.filters.supervised.attribute.Discretize(); m_Disc.setInputFormat(m_Instances); m_Instances = weka.filters.Filter.useFilter(m_Instances, m_Disc);} else { m_Disc = null;}
如果需要进行离散化,就进行离散化,有人问过我如何离散化,上面就是。
// Reserve space for the distributionsm_Distributions = new Estimator[m_Instances.numAttributes() - 1][m_Instances.numClasses()];m_ClassDistribution = new DiscreteEstimator(m_Instances.numClasses(),true);
m_Distributions就是P(C),m_ClassDistribution就是P(X|C)。
int attIndex = 0;Enumeration enu = m_Instances.enumerateAttributes();while (enu.hasMoreElements()) { Attribute attribute = (Attribute) enu.nextElement(); …… attIndex++;}
循环对每一个特征进行处理
// If the attribute is numeric, determine the estimator // numeric precision from differences between adjacent valuesdouble numPrecision = DEFAULT_NUM_PRECISION;if (attribute.type() == Attribute.NUMERIC) { m_Instances.sort(attribute); if ((m_Instances.numInstances() > 0) && !m_Instances.instance(0).isMissing(attribute)) { double lastVal = m_Instances.instance(0).value(attribute); double currentVal, deltaSum = 0; int distinct = 0; for (int i = 1; i < m_Instances.numInstances(); i++) { Instance currentInst = m_Instances.instance(i); if (currentInst.isMissing(attribute)) { break; } currentVal = currentInst.value(attribute); if (currentVal != lastVal) { deltaSum += currentVal - lastVal; lastVal = currentVal; distinct++; } } if (distinct > 0) { numPrecision = deltaSum / distinct; } }}
这一大段代码令我惊讶的是只是为了确定精度(其实它是为了可以增量式的学习),精度就是平时说的保留几位小数,不过这里不是保留多少位。先看一下代码,先对m_Instances进行排序,以前也说过排序后,这个属性的上是缺失值的样本就排到了最前面,判断如果第一个样本在这个属性上缺失值,那么就不用执行了(instances.deleteWithMissingClass();这一句已经执行了,所以不太可能发生)。接下来,得到每个样本的在当前属性的属性值currentVal,如果与前一个样本在当前属性的属性值不同,则相减,将每次差值累加至deltaSum中,最后numPrecision就是差值之和deltaSum除所有不同的属性值。
for (int j = 0; j < m_Instances.numClasses(); j++) { switch (attribute.type()) { case Attribute.NUMERIC: if (m_UseKernelEstimator) { m_Distributions[attIndex][j] = new KernelEstimator(numPrecision); } else { m_Distributions[attIndex][j] = new NormalEstimator(numPrecision); } break; case Attribute.NOMINAL: m_Distributions[attIndex][j] = new DiscreteEstimator(attribute.numValues(), true); break; default: throw new Exception("Attribute type unknown to NaiveBayes"); }}
这段代码写的看起来有点怪。判断当前属性的类型,如果是NUMERIC也就是连续属值,你可以选择KernelEstimator也可以用NormalEstimator,都用numPrecision构造参数。区别在论文中已经讲的很清楚了,两者都是用平均值和方差来计算,这也是常识了。如果是NOMINAL也就是离散值,那就用DiscreteEstimator。
// Compute countsEnumeration enumInsts = m_Instances.enumerateInstances();while (enumInsts.hasMoreElements()) { Instance instance = (Instance) enumInsts.nextElement(); updateClassifier(instance);}
终于到了有点意义的代码,对每一个样本进行统计。updateClassifier就是根据样本更新分类器,Naïve Bayes可以是增量式的,这总是知道的吧。
public void updateClassifier(Instance instance) throws Exception { if (!instance.classIsMissing()) { Enumeration enumAtts = m_Instances.enumerateAttributes(); int attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = (Attribute) enumAtts.nextElement(); if (!instance.isMissing(attribute)) { m_Distributions[attIndex][(int) instance.classValue()] .addValue(instance.value(attribute),instance.weight()); } attIndex++; }m_ClassDistribution.addValue(instance.classValue(), instance.weight()); }}
进行统计,m_Distributions第一个下标就是当前属性的下标,第二个下标是类别值。最重要的函数就是addValue了,它对样本的对应类别属性值分布进行统计。最后m_ClassDistribution是对类别进行统计。
下面看一下最简单的DiscreteEstimator的构造函数:
public DiscreteEstimator(int numSymbols, boolean laplace) { m_Counts = new double[numSymbols]; m_SumOfCounts = 0; if (laplace) { for (int i = 0; i < numSymbols; i++) { m_Counts[i] = 1; } m_SumOfCounts = (double) numSymbols; }}public DiscreteEstimator(int nSymbols, double fPrior) { m_Counts = new double[nSymbols]; for (int iSymbol = 0; iSymbol < nSymbols; iSymbol++) { m_Counts[iSymbol] = fPrior; } m_SumOfCounts = fPrior * (double) nSymbols;}
也没什么区别,第一个用Laplace,第二个不知道是什么,反正差不多。
public void addValue(double data, double weight) { m_Counts[(int) data] += weight; m_SumOfCounts += weight;}
离散型的addValue非常简单,就是在对应的属性值上加上这个样本的权重。
再看一下NormalEstimator的构造函数:
public NormalEstimator(double precision) { m_Precision = precision; // Allow at most 3 sd's within one interval m_StandardDev = m_Precision / (2 * 3);}
精度已经解释过了,再说一次,这里的精度不是精准到第几位,而是一个值。下面的2我想我应该是对的,可是我怕我的想法是错的,讲出来会被人笑死,如果有人有想法,讲一声,3的意思是在这个精度范围内最多能有三个标准差。
public void addValue(double data, double weight) { if (weight == 0) { return; } data = round(data); m_SumOfWeights += weight; m_SumOfValues += data * weight; m_SumOfValuesSq += data * data * weight; if (m_SumOfWeights > 0) { m_Mean = m_SumOfValues / m_SumOfWeights; double stdDev = Math.sqrt(Math.abs(m_SumOfValuesSq - m_Mean * m_SumOfValues) / m_SumOfWeights); // If the stdDev ~= 0, we really have no idea of scale yet, // so stick with the default. Otherwise... if (stdDev > 1e-10) { m_StandardDev = Math.max(m_Precision / (2 * 3), // allow at most 3sd's within one interval stdDev); } }}
这段程序没什么好讲的,有兴趣可以去Wiki搜索Algorithms for calculating variance词条,里有Weighted incremental algorithm可能看起来更清楚一点。
下面看一下distributionForInstance函数。
public double[] distributionForInstance(Instance instance) { double[] probs = new double[m_NumClasses]; for (int j = 0; j < m_NumClasses; j++) { probs[j] = m_ClassDistribution.getProbability(j); } Enumeration enumAtts = instance.enumerateAttributes(); int attIndex = 0; while (enumAtts.hasMoreElements()) { Attribute attribute = (Attribute) enumAtts.nextElement(); if (!instance.isMissing(attribute)) { double temp, max = 0; for (int j = 0; j < m_NumClasses; j++) { temp = Math.max(1e-75, Math.pow( m_Distributions[attIndex][j] .getProbability(instance.value(attribute)), m_Instances.attribute(attIndex).weight())); probs[j] *= temp; if (probs[j] > max) { max = probs[j]; } } } } attIndex++; // Display probabilities Utils.normalize(probs); return probs;}
首先得到类别的概率,希望你还记得公式是什么,对于每一个类别,计算在每个类别上的概率,也就是temp,probs[j] *= temp还是公式。最后看一下哪一个类别是最有可能的类别。
DiscreteEstimator和NormalEstimator的getProbability函数分别如下:
public double getProbability(double data) { if (m_SumOfCounts == 0) { return 0; } return (double) m_Counts[(int) data] / m_SumOfCounts;}public double getProbability(double data) { data = round(data); double zLower = (data - m_Mean - (m_Precision / 2)) / m_StandardDev; double zUpper = (data - m_Mean + (m_Precision / 2)) / m_StandardDev; double pLower = Statistics.normalProbability(zLower); double pUpper = Statistics.normalProbability(zUpper); return pUpper - pLower;}
第一个没什么好讲的,直接返回,第二个我只明白+-(m_Precision/2)的意思是根据精度求它可能的最小值和最大值。
如果本科时多学点计算方法,概率统计可能今天不会这么痛苦。
- Weka开发[19]——NaiveBayes源代码分析
- Weka开发[20]——IB1源代码分析
- Weka开发——REPTree源代码分析
- Weka开发[11]—J48源代码介绍
- weka分类器-NaiveBayes
- Weka开发[14]-AdaBoost源代码介绍
- Weka开发[16]-OneR源代码介绍
- Weka开发[-1]——在你的代码中使用Weka
- Weka开发[-1]——在你的代码中使用Weka
- Weka开发[-1]——在你的代码中使用Weka
- Weka开发[9]—KMeans源码介绍
- Weka开发[10]—NBTree源码介绍
- 文本分类——NaiveBayes
- weka源代码分析-总述
- Weka开发[15]-ZeroR源代码介绍(入门篇)
- Weka开发[0]-导入Weka包
- Weka开发[0]-导入Weka包
- weka中Multiple Perceptron源代码分析
- 键盘事件导致view上移下移
- linux 设置ip地址信息
- MYSQL安装到最后一步服务器失败的操作
- hdoj_1162Eddy's picture(最小生成树)
- VS2010 Express 使用小知识
- Weka开发[19]——NaiveBayes源代码分析
- jQuery,javascript获得网页的高度和宽度 .
- 使用Xmanager在CentOs5.5 安装oracle11g r2
- c文件与h文件及包含关系
- 线程编程中用到HttpContext.Current的方法封装
- C/C++堆、栈及静态数据区详解
- SWIFT API 使用文档
- 接口
- 如果在外网的情况下使用Xmanager无法连接到Linux上