Motivation
A recent project of mine used the Gaussian Naive Bayes model (GNB) from scikit-learn. As the data volume grows, running GNB on a single machine is bound to get slow, so I decided to move it to Spark. It turned out that MLlib does not implement GNB, so I rolled my own.
Theory
GNB is built on top of Naive Bayes, so let's first go over how Naive Bayes works.
Naive Bayes
Bayes' theorem:
P(Y \mid X) = \frac{P(X \mid Y) \, P(Y)}{P(X)}
With Bayes' theorem we can compute P(Y|X) once we know P(X|Y) and P(Y). Now treat Y as the class and X as a feature. Given "the probability of feature X appearing when the class is Y, P(X|Y)" and "the probability of class Y, P(Y)", we can compute the probability that the class is Y given that feature X is observed, P(Y|X).
The above considers only a single feature. Now suppose the model has N features and C classes. Given the features X, the probability that the class is k can be written as
P(Y=k \mid X_1, \dots, X_N) = \frac{P(X_1, \dots, X_N \mid Y=k) \, P(Y=k)}{P(X_1, \dots, X_N)} = \frac{P(Y=k) \prod_{i=1}^{N} P(X_i \mid Y=k)}{\sum_{j=1}^{C} P(Y=j) \prod_{i=1}^{N} P(X_i \mid Y=j)}
Using this formula we can compute, for each class k, the probability that the class is k given the observed features. Over all k, we take the class with the largest probability (the maximum a posteriori estimate) as our prediction. That is the whole idea of Naive Bayes.
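As a concrete illustration of this MAP rule, here is a minimal Python sketch; all class names, features, and probabilities below are invented for the example.

```python
# Toy MAP classification under the naive independence assumption.
# Priors P(Y=k) and per-feature likelihoods P(X_i = 1 | Y=k) are made up.
prior = {"spam": 0.4, "ham": 0.6}
likelihood = {
    "spam": [0.8, 0.7],
    "ham":  [0.1, 0.3],
}

def posterior_scores(features):
    """Unnormalized P(Y=k) * prod_i P(X_i | Y=k) for each class."""
    scores = {}
    for k in prior:
        p = prior[k]
        for f, present in zip(likelihood[k], features):
            p *= f if present else (1.0 - f)
        scores[k] = p
    return scores

scores = posterior_scores([1, 1])        # both binary features present
predicted = max(scores, key=scores.get)  # MAP: pick the highest-scoring class
```

Note that the denominator of the formula above is the same for every k, so comparing the unnormalized numerators is enough for prediction.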
Wait, something looks off. What justifies saying
\prod_{i=1}^{N} P(X_i \mid Y=k) = P(X_1, \dots, X_N \mid Y=k)
Right, this is exactly where Naive Bayes is "naive": it rests on a very strong assumption that all features occur independently of each other. This is also the main limitation of Naive Bayes.
In practice we also need to handle edge cases: a class that never appears in the sample set, or a feature that never appears in the samples of some class. This is where a smoothing factor lambda comes in:
P(Y=k) = \frac{\#\{\text{class-}k\text{ samples}\} + \lambda}{\#\{\text{samples}\} + C \lambda}
Under the multinomial model:
P(X=i \mid Y=k) = \frac{\#\{\text{occurrences of feature } i \text{ in class-}k\text{ samples}\} + \lambda}{\#\{\text{occurrences of all features in class-}k\text{ samples}\} + N \lambda}
Under the Bernoulli model:
P(X=i \mid Y=k) = \frac{\#\{\text{class-}k\text{ samples in which feature } i \text{ occurs}\} + \lambda}{\#\{\text{class-}k\text{ samples}\} + 2 \lambda}
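A minimal Python sketch of these smoothed estimates, with made-up counts and lambda = 1:

```python
# Smoothed probability estimates from raw counts (toy counts, lambda = 1).
lam = 1.0

# P(Y=k) = (count_k + lambda) / (n_samples + n_classes * lambda)
class_counts = {"0": 6, "1": 4}
n = sum(class_counts.values())
C = len(class_counts)
prior = {k: (c + lam) / (n + C * lam) for k, c in class_counts.items()}

# Multinomial model: feature-occurrence counts within one class.
# Feature 2 never occurs in this class -- without smoothing its estimated
# probability would be 0 and would zero out the whole product in the posterior.
feature_counts = [5, 3, 0]   # occurrences of features 0..2 in class "0"
total = sum(feature_counts)
N = len(feature_counts)
p_multinomial = [(c + lam) / (total + N * lam) for c in feature_counts]
```

The smoothed multinomial estimates still sum to 1 over the N features, which is easy to verify from the formula.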
Naive Bayes has two commonly used variants: the Bernoulli model and the multinomial model. The difference is that the Bernoulli model only considers whether a feature occurs in a sample (e.g. whether a given word appears, 0 or 1), while the multinomial model considers how many times a feature occurs in a sample (e.g. a word's occurrence count, a concrete number). Both variants are designed for discrete features. When the features of the modeled objects are continuous, there are generally two options: quantize the continuous features into discrete ones, or use Gaussian Naive Bayes.
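The two featurizations can be contrasted on a toy "document"; the vocabulary and words below are hypothetical.

```python
# The same toy document featurized two ways against a small vocabulary.
vocab = ["spark", "bayes", "gaussian"]
doc = ["bayes", "spark", "bayes", "bayes"]

# Multinomial: how many times each vocabulary word occurs in the document.
multinomial = [doc.count(w) for w in vocab]

# Bernoulli: whether each vocabulary word occurs at all (0 or 1).
bernoulli = [1 if w in doc else 0 for w in vocab]
```

Both vectors are discrete, which is why a continuous measurement needs the Gaussian treatment described next.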
Gaussian Naive Bayes
The Gaussian model differs from the two models above in how it computes P(X|Y): it assumes the feature follows a Gaussian distribution, which works well for continuous features.
P(X \mid Y) \sim N(\mu, \sigma^2), \qquad P(X=a \mid Y=k) = \frac{1}{\sqrt{2\pi}\,\sigma} \exp\left(-\frac{(a-\mu)^2}{2\sigma^2}\right)
The mean \mu and variance \sigma^2 in the formula above can both be estimated from the sample set.
With the Gaussian distribution, a continuous value is turned into a density, which solves the continuous-feature problem from the previous section; everything else carries over from plain Naive Bayes unchanged.
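A Python sketch of this density, and of estimating \mu and \sigma^2 from class-conditional samples; the sample values are made up.

```python
import math

def gaussian_pdf(mean, variance, x):
    """N(mean, variance) density; degenerate point mass when variance is 0."""
    if variance == 0.0:
        return 1.0 if x == mean else 0.0
    return math.exp(-(x - mean) ** 2 / (2 * variance)) / math.sqrt(2 * math.pi * variance)

# Estimate mean and (biased) variance from the samples of one class.
samples = [1.0, 2.0, 3.0, 4.0]
mean = sum(samples) / len(samples)
variance = sum((s - mean) ** 2 for s in samples) / len(samples)

# Density of a new feature value under this class's Gaussian.
density = gaussian_pdf(mean, variance, 2.5)
```

Note this is a density, not a probability: it can exceed 1 for small variances, which is fine because only the relative scores across classes matter.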
Implementation
Talk is cheap, show me the code. Now for the implementation. Since the vector types in Spark MLlib expose very little public API, I wrote my own LabeledPoint:
class LabeledPoint(val label: Double, val denseVector: DenseVector[Double]) extends Serializable {}

object LabeledPoint extends Serializable {
  def apply(label: Double, denseVector: DenseVector[Double]) = {
    new LabeledPoint(label, denseVector)
  }
}
The Gaussian density function: given a mean and a variance, it yields a distribution function, using currying:
def distributiveFunc(mean: Double, variance: Double)(x: Double): Double = {
  if (variance == 0.0) {
    if (x == mean) 1.0 else 0.0
  } else {
    1.0 / sqrt(2 * Pi * variance) * exp(-pow(x - mean, 2.0) / (2 * variance))
  }
}
The core code in full:
import breeze.linalg.DenseVector
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import breeze.numerics._
import scala.math.Pi
import xyz.qspring.spark.ml.base.LabeledPoint

/**
 * Created by qero on 16/8/7.
 */
class GuassianNaiveBayes private (private val input: RDD[LabeledPoint],
                                  private val lambda: Double = 1.0) extends Serializable with Logging {

  // Gaussian density, curried over (mean, variance)
  def distributiveFunc(mean: Double, variance: Double)(x: Double): Double = {
    if (variance == 0.0) {
      if (x == mean) 1.0 else 0.0
    } else {
      1.0 / sqrt(2 * Pi * variance) * exp(-pow(x - mean, 2.0) / (2 * variance))
    }
  }

  def run() = {
    val sampleN = input.count
    val grouped = input.map(point => (point.label, point.denseVector)).groupByKey().cache
    val classN = grouped.count

    // log2 of the smoothed class priors P(Y=k)
    val pi = grouped.map { case (c, a) =>
      val p = (a.toList.length * 1.0 + lambda) / (sampleN + lambda * classN)
      (c, log2(p))
    }

    // per class: one Gaussian density function per feature,
    // parameterized by the class-conditional mean and variance
    val pji = grouped.mapValues { a =>
      val aSum = a.reduce((v1, v2) => v1 + v2)
      val aSampleN = a.toArray.length
      val mean = aSum / (aSampleN * 1.0)
      val variance = a.map(i => (i - mean) :* (i - mean))
        .reduce((v1, v2) => v1 + v2) / (aSampleN * 1.0)
      val paras = mean.toArray.zip(variance.toArray)
      paras.map(p => distributiveFunc(p._1, p._2) _)
    }

    new GuassianNBModel(pi.collectAsMap(), pji.collectAsMap())
  }
}

class GuassianNBModel(val pi: collection.Map[Double, Double],
                      val pji: collection.Map[Double, Array[Double => Double]]) extends Serializable {
  // score each class by summing log-densities plus the log prior,
  // and return the (score, label) pair with the highest score
  def predict(features: DenseVector[Double]) = {
    pji.map { case (label, models) =>
      val score = models.zip(features.toArray).map { case (m, v) =>
        log2(m(v))
      }.sum + pi(label)
      (score, label)
    }.max
  }
}

object GuassianNaiveBayes extends Serializable {
  def fit(input: RDD[LabeledPoint]) = {
    new GuassianNaiveBayes(input).run()
  }
}
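Since the motivation was moving off scikit-learn, one convenient sanity check is to fit sklearn's GaussianNB on the same data and compare its predictions with the Spark model's. A sketch, assuming scikit-learn and NumPy are installed; the rows are taken from the train.dat sample below.

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB

# A few rows in train.dat's format: [x1, x2, label]
train = np.array([
    [-0.017612, 14.053064, 0],
    [-1.395634,  4.662541, 1],
    [-0.752157,  6.538620, 0],
    [-1.322371,  7.152853, 0],
    [ 0.931635,  1.589505, 1],
    [-0.024205,  6.151823, 1],
])
X, y = train[:, :-1], train[:, -1]

clf = GaussianNB()
clf.fit(X, y)
preds = clf.predict(X)  # compare these labels against the Spark model's output
```

To check the full data set, the array literal could be replaced with np.loadtxt over the whitespace-separated file.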
Test files: the training set, train.dat
-0.017612 14.053064 0
-1.395634 4.662541 1
-0.752157 6.538620 0
-1.322371 7.152853 0
0.423363 11.054677 0
0.406704 7.067335 1
0.667394 12.741452 0
-2.460150 6.866805 1
0.569411 9.548755 0
-0.026632 10.427743 0
0.850433 6.920334 1
1.347183 13.175500 0
1.176813 3.167020 1
-1.781871 9.097953 0
-0.566606 5.749003 1
0.931635 1.589505 1
-0.024205 6.151823 1
-0.036453 2.690988 1
-0.196949 0.444165 1
1.014459 5.754399 1
1.985298 3.230619 1
-1.693453 -0.557540 1
-0.576525 11.778922 0
-0.346811 -1.678730 1
-2.124484 2.672471 1
1.217916 9.597015 0
-0.733928 9.098687 0
-3.642001 -1.618087 1
0.315985 3.523953 1
1.416614 9.619232 0
-0.386323 3.989286 1
0.556921 8.294984 1
1.224863 11.587360 0
-1.347803 -2.406051 1
-0.445678 3.297303 1
1.042222 6.105155 1
-0.618787 10.320986 0
1.152083 0.548467 1
0.828534 2.676045 1
-1.237728 10.549033 0
-0.683565 -2.166125 1
0.229456 5.921938 1
-0.959885 11.555336 0
0.492911 10.993324 0
0.184992 8.721488 0
-0.355715 10.325976 0
-0.397822 8.058397 0
0.824839 13.730343 0
1.507278 5.027866 1
0.099671 6.835839 1
-0.344008 10.717485 0
1.785928 7.718645 1
-0.918801 11.560217 0
-0.364009 4.747300 1
-0.841722 4.119083 1
0.490426 1.960539 1
-0.007194 9.075792 0
0.356107 12.447863 0
0.342578 12.281162 0
-0.810823 -1.466018 1
2.530777 6.476801 1
1.296683 11.607559 0
0.475487 12.040035 0
-0.783277 11.009725 0
0.074798 11.023650 0
-1.337472 0.468339 1
-0.102781 13.763651 0
-0.147324 2.874846 1
0.518389 9.887035 0
1.015399 7.571882 0
-1.658086 -0.027255 1
1.319944 2.171228 1
2.056216 5.019981 1
-0.851633 4.375691 1
-1.510047 6.061992 0
-1.076637 -3.181888 1
1.821096 10.283990 0
3.010150 8.401766 1
-1.099458 1.688274 1
-0.834872 -1.733869 1
-0.846637 3.849075 1
Test files: the test set, test.dat
1.400102 12.628781 0
1.752842 5.468166 1
0.078557 0.059736 1
0.089392 -0.715300 1
1.825662 12.693808 0
0.197445 9.744638 0
0.126117 0.922311 1
-0.679797 1.220530 1
0.677983 2.556666 1
0.761349 10.693862 0
-2.168791 0.143632 1
1.388610 9.341997 0
0.275221 9.543647 0
0.470575 9.332488 0
-1.889567 9.542662 0
-1.527893 12.150579 0
-1.185247 11.309318 0
The test program:
object Main extends App {
  override def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("naive_bayes")
    val sc = new SparkContext(conf)
    val data = sc.textFile("data/train.dat")
    Logger.getRootLogger.setLevel(Level.WARN)

    val trainData = data.map(line => {
      val items = line.split("\\s+")
      LabeledPoint(items(items.length - 1).toDouble,
        DenseVector(items.slice(0, items.length - 1).map(_.toDouble)))
    })

    val model = GuassianNaiveBayes.fit(trainData)

    sc.textFile("data/test.dat").foreach(line => {
      val items = line.split("\\s+")
      val res = model.predict(DenseVector(items.slice(0, items.length - 1).map(_.toDouble)))
      println("true is " + items(items.length - 1) + ", predict is " + res._2 + ", score = " + pow(2, res._1))
    })
  }
}
Results
true is 0, predict is 0.0, score = 0.007287035226911837
true is 1, predict is 1.0, score = 0.006537938765007012
true is 1, predict is 1.0, score = 0.012801368971056088
true is 1, predict is 1.0, score = 0.00970655657450153
true is 0, predict is 0.0, score = 0.00305462018270487
true is 0, predict is 0.0, score = 0.03716655013066987
true is 1, predict is 1.0, score = 0.01613160178250759
true is 1, predict is 1.0, score = 0.01548224987302873
true is 1, predict is 1.0, score = 0.01784234527209572
true is 0, predict is 0.0, score = 0.029683595996118462
true is 1, predict is 1.0, score = 0.0037636068269885714
true is 0, predict is 0.0, score = 0.011051732411404247
true is 0, predict is 0.0, score = 0.034819190499309864
true is 0, predict is 0.0, score = 0.03027279470621322
true is 0, predict is 0.0, score = 0.003400879969005375
true is 0, predict is 0.0, score = 0.0060605923826227105
true is 0, predict is 0.0, score = 0.014488715477020412