Spark 低配版高斯朴素贝叶斯实现
2016-08-24 16:54
288 查看
Motivation
最近有项目用到Scikit-learn上的高斯朴素贝叶斯模型(简称GNB),随着数据量增大,单机上跑GNB肯定会很慢,所以打算转Spark上。然后发现MLlib并没有实现GNB,自己动手,丰衣足食~原理
GNB的原理是基于朴素贝叶斯,所以先交代朴素贝叶斯的原理。朴素贝叶斯
贝叶斯公式P(Y∣X)=P(X∣Y)∗P(Y)P(X)
利用贝叶斯公式我们就可以在已知P(X|Y)和P(Y)的情况下计算得出P(Y|X)。现在把Y看成类别,把X看成特征,那么利用贝叶斯公式,我们在已知“特征X出现的时候类别为Y的概率P(X|Y)” 和 “类别为Y的概率P(Y)”的情况下,我们就可以计算在特征X出现的情况下其类别为Y的概率P(Y|X)。
上面只考虑了只有一种特征的情况,现在考虑模型有N种特征和C种类别的情况。在给定特征X的情况下,求类别为k的概率,公式可以表示成
P(Y=k∣X1,...,XN)=P(X1,...,XN∣Y=k)∗P(Y=k)P(X1,...,XN)=P(Y=k)∗∏NiP(Xi∣Y=k)∑CjP(Y=j)∗∏NiP(Xi∣Y=j)
根据上式,我们可以计算在特征X出现的情况下其类别为Y=k的概率,对于所有的k,我们取概率最大的(最大后验)作为我们的Predict,这就是朴素贝叶斯的思路。
等等,好像有点问题,凭什么说
∏iNP(Xi∣Y=k)=P(X1,...,PN|Y=k)
对的,这就是朴素贝叶斯Naive的地方,它基于一个很强的假设——所有特征的出现是相互独立的,这也是朴素贝叶斯的局限性。
在实际应用中,还需要考虑极端情况——某个类别没有出现在样本集中 or 某个特征没有出现在某类样本集中。这个时候就需要加入平滑因子lambda去调整。
P(Y=k)=样本集中类别为k的样本个数+lambda样本集中的样本个数+类别的种类*lambda
多项式模型下:
P(X=i∣Y=k)=类别为k的样本中特征i出现的次数+lambda类别为k的样本中所有特征出现的次数+特征的种类数*lambda
伯努力模型下:
P(X=i∣Y=k)=类别为k的样本中特征i出现的次数+lambda类别为k的样本数+2*lambda
朴素贝叶斯有两种常用的模型,一种叫伯努利模型,另一种叫多项式模型。两者的区别就在于伯努利模型只考虑在一个样本中,特征是否出现了(例如某个词语是否出现了,0 or 1),而多项式模型则会考虑一个样本中特征出现的次数(例如某个词语出现的次数,一个具体的数字)。两种模型都是面向离散型的特征,如果被建模对象的特征是连续变量时,一般有两个解决方案,一是量化连续型的特征成离散型的,另一种则使用高斯朴素贝叶斯。
高斯朴素贝叶斯
高斯模型下的朴素贝叶斯与上面介绍的两种模型不同的地方是在计算P(X|Y)时,假设其服从高斯分布,这是对于连续型的特征有很友好的表现。P(X∣Y)~N(μ,σ2)P(X=a∣Y=k)=12π‾‾‾√σexp(−(a−μ)22σ2)
对于上式的均值(\mu)和方差(\sigma^{2})都是可以从样本集中统计得出。
上述利用高斯分布,我们把连续变量转变成一个概率,上一小节提到的特征是连续变量的问题解决了,其它一切照搬Naive Bayes即可。
实现
Talk is cheap,show me the code. 接下来讲讲具体实现,由于Spark MLlib中实现的向量对外API甚少,所以自己动手写了个LabeledPointclass LabeledPoint(val label: Double, val denseVector: DenseVector[Double]) extends Serializable { } object LabeledPoint extends Serializable { def apply(label: Double, denseVector: DenseVector[Double]) = { new LabeledPoint(label, denseVector) } }
高斯分布函数,给入均值和方差,生成分布函数,使用柯里化
def distributiveFunc(mean: Double, variance: Double)(x: Double) : Double = { if (variance == 0.0) { if (x == mean) 1.0 else 0.0 } else { 1.0 / sqrt(2 * Pi * variance) * exp(- pow(x - mean, 2.0) / (2 * variance)) } }
核心代码全览
import breeze.linalg.DenseVector import org.apache.spark.Logging import org.apache.spark.rdd.RDD import breeze.numerics._ import scala.math.Pi import xyz.qspring.spark.ml.base.LabeledPoint//注意:就是上面的LabeledPoint /** * Created by qero on 16/8/7. */ class GuassianNaiveBayes private (private val input: RDD[LabeledPoint], private val lambda: Double = 1.0) extends Serializable with Logging{ def distributiveFunc(mean: Double, variance: Double)(x: Double) : Double = { //柯里化分布函数 if (variance == 0.0) { if (x == mean) 1.0 else 0.0 } else { 1.0 / sqrt(2 * Pi * variance) * exp(- pow(x - mean, 2.0) / (2 * variance)) } } def run() = { val sampleN = input.count val grouped = input.map(point => (point.label, point.denseVector)).groupByKey().cache val classN = grouped.count //计算各类的出现概率(注意平滑因子lambda) val pi = grouped.map{case (c, a) => { val p = (a.toList.length * 1.0 + lambda) / (sampleN + lambda * classN) (c, log2(p)) //取对数,防止后期出现连乘(小数连乘容易精度丢失) }} //计算在各类情况下的各特征的均值和方差 val pji = grouped.mapValues(a => { val aSum = a.reduce((v1 ,v2) => v1 + v2) //求总数 val aSampleN = a.toArray.length //求总数 val mean = aSum / (aSampleN * 1.0) //求均值 val variance = a.map(i => { //求方差(去中心化->求和->求均值) (i - mean) :* (i - mean) }).reduce((v1 ,v2) => v1 + v2) / (aSampleN * 1.0) val paras = mean.toArray.zip(variance.toArray) paras.map(p => distributiveFunc(p._1, p._2)_) //返回(类别,[特征1的分布函数, ..., 特征n的分布函数]) }) new GuassianNBModel(pi.collectAsMap(), pji.collectAsMap()) } } class GuassianNBModel(val pi:collection.Map[Double, Double], val pji:collection.Map[Double, Array[Double => Double]]) extends Serializable { def predict(features: DenseVector[Double]) = { pji.map{case (label, models) => { val score = models.zip(features.toArray).map{case (m, v) => { log2(m(v)) //取对数,防止后期出现连乘(小数连乘容易精度丢失) }}.sum + pi(label) (score, label) //返回(log(P(F1...Fn|Label)*P(Label)), Label) }}.max //选概率最大的,其对应的Label就是模型的预测 } } object GuassianNaiveBayes extends Serializable { def fit(input: RDD[LabeledPoint]) = { new GuassianNaiveBayes(input).run() } }
测试文件,训练集train.dat
-0.017612 14.053064 0 -1.395634 4.662541 1 -0.752157 6.538620 0 -1.322371 7.152853 0 0.423363 11.054677 0 0.406704 7.067335 1 0.667394 12.741452 0 -2.460150 6.866805 1 0.569411 9.548755 0 -0.026632 10.427743 0 0.850433 6.920334 1 1.347183 13.175500 0 1.176813 3.167020 1 -1.781871 9.097953 0 -0.566606 5.749003 1 0.931635 1.589505 1 -0.024205 6.151823 1 -0.036453 2.690988 1 -0.196949 0.444165 1 1.014459 5.754399 1 1.985298 3.230619 1 -1.693453 -0.557540 1 -0.576525 11.778922 0 -0.346811 -1.678730 1 -2.124484 2.672471 1 1.217916 9.597015 0 -0.733928 9.098687 0 -3.642001 -1.618087 1 0.315985 3.523953 1 1.416614 9.619232 0 -0.386323 3.989286 1 0.556921 8.294984 1 1.224863 11.587360 0 -1.347803 -2.406051 1 -0.445678 3.297303 1 1.042222 6.105155 1 -0.618787 10.320986 0 1.152083 0.548467 1 0.828534 2.676045 1 -1.237728 10.549033 0 -0.683565 -2.166125 1 0.229456 5.921938 1 -0.959885 11.555336 0 0.492911 10.993324 0 0.184992 8.721488 0 -0.355715 10.325976 0 -0.397822 8.058397 0 0.824839 13.730343 0 1.507278 5.027866 1 0.099671 6.835839 1 -0.344008 10.717485 0 1.785928 7.718645 1 -0.918801 11.560217 0 -0.364009 4.747300 1 -0.841722 4.119083 1 0.490426 1.960539 1 -0.007194 9.075792 0 0.356107 12.447863 0 0.342578 12.281162 0 -0.810823 -1.466018 1 2.530777 6.476801 1 1.296683 11.607559 0 0.475487 12.040035 0 -0.783277 11.009725 0 0.074798 11.023650 0 -1.337472 0.468339 1 -0.102781 13.763651 0 -0.147324 2.874846 1 0.518389 9.887035 0 1.015399 7.571882 0 -1.658086 -0.027255 1 1.319944 2.171228 1 2.056216 5.019981 1 -0.851633 4.375691 1 -1.510047 6.061992 0 -1.076637 -3.181888 1 1.821096 10.283990 0 3.010150 8.401766 1 -1.099458 1.688274 1 -0.834872 -1.733869 1 -0.846637 3.849075 1
测试文件,测试集test.dat
1.400102 12.628781 0 1.752842 5.468166 1 0.078557 0.059736 1 0.089392 -0.715300 1 1.825662 12.693808 0 0.197445 9.744638 0 0.126117 0.922311 1 -0.679797 1.220530 1 0.677983 2.556666 1 0.761349 10.693862 0 -2.168791 0.143632 1 1.388610 9.341997 0 0.275221 9.543647 0 0.470575 9.332488 0 -1.889567 9.542662 0 -1.527893 12.150579 0 -1.185247 11.309318 0
测试程序
object Main extends App { override def main(args: Array[String]) { val conf = new SparkConf().setAppName("naive_bayes") val sc = new SparkContext(conf) val data = sc.textFile("data/train.dat") Logger.getRootLogger.setLevel(Level.WARN) val trainData = data.map(line => { val items = line.split("\\s+") LabeledPoint(items(items.length-1).toDouble, DenseVector(items.slice(0, items.length-1).map(_.toDouble))) }) val model = GuassianNaiveBayes.fit(trainData) val testData = sc.textFile("data/test.dat").foreach(line => { val items = line.split("\\s+") val res = model.predict(DenseVector(items.slice(0, items.length-1).map(_.toDouble))) println("true is " + items(items.length - 1) + ", predict is " + res._2 + ", score = " + pow(2, res._1)) }) } }
结果
true is 0, predict is 0.0, score = 0.007287035226911837 true is 1, predict is 1.0, score = 0.006537938765007012 true is 1, predict is 1.0, score = 0.012801368971056088 true is 1, predict is 1.0, score = 0.00970655657450153 true is 0, predict is 0.0, score = 0.00305462018270487 true is 0, predict is 0.0, score = 0.03716655013066987 true is 1, predict is 1.0, score = 0.01613160178250759 true is 1, predict is 1.0, score = 0.01548224987302873 true is 1, predict is 1.0, score = 0.01784234527209572 true is 0, predict is 0.0, score = 0.029683595996118462 true is 1, predict is 1.0, score = 0.0037636068269885714 true is 0, predict is 0.0, score = 0.011051732411404247 true is 0, predict is 0.0, score = 0.034819190499309864 true is 0, predict is 0.0, score = 0.03027279470621322 true is 0, predict is 0.0, score = 0.003400879969005375 true is 0, predict is 0.0, score = 0.0060605923826227105 true is 0, predict is 0.0, score = 0.014488715477020412
相关文章推荐
- Spark 低配版高斯朴素贝叶斯实现
- Spark 实现 朴素贝叶斯(naiveBayes)
- spark中ml机器学习库的朴素贝叶斯模型实现中文文本信息的文类预测
- 高斯平滑 高斯模糊 高斯滤波器 ( Gaussian Smoothing, Gaussian Blur, Gaussian Filter ) C++ 实现
- 10个重要的算法C语言实现源代码:拉格朗日,牛顿插值,高斯,龙贝格,牛顿迭代,牛顿-科特斯,雅克比,秦九昭,幂法,高斯塞德尔
- 10个重要的算法C语言实现源代码:拉格朗日,牛顿插值,高斯,龙贝格,牛顿迭代,牛顿-科特斯,雅克比,秦九昭,幂法,高斯塞德尔
- 10个重要的算法C语言实现源代码:拉格朗日,牛顿插值,高斯,龙贝格,牛顿迭代,牛顿-科特斯,雅克比,秦九昭,幂法,高斯塞德尔
- 10个重要的算法C语言实现源代码:拉格朗日,牛顿插值,高斯,龙贝格,牛顿迭代,牛顿-科特斯,雅克比,秦九昭,幂法,高斯塞德尔
- 高斯模糊的.net实现
- Matlab实现——严格对角占优三对角方程组求解(高斯赛尔德Gauss-Seidel迭代、超松弛)
- OPENCV中混合高斯背景模型的实现
- 高斯模糊的.net实现 (摘自网络)
- 高斯图像滤波原理及其编程离散化实现方法
- 高斯滤波/高斯平滑/高斯模糊的实现及其快速算法(Gaussian Filter, Gaussian Smooth, Gaussian Blur, Fast implementation)
- 采用spark和openfire实现即时通讯系统
- 高斯模糊原理及实现
- 10个重要的算法C语言实现源代码:拉格朗日,牛顿插值,高斯,龙贝格,牛顿迭代,牛顿-科特斯,雅克比,秦九昭,幂法,高斯塞德尔
- 10个重要的算法C语言实现源代码(8-9-10-----秦九昭和幂法和高斯塞德尔)
- 10个重要的算法C语言实现源代码:拉格朗日,牛顿插值,高斯等等
- 10个重要的算法C语言实现源代码:拉格朗日,牛顿插值,高斯,龙贝格,牛顿迭代,牛顿-科特斯,雅克比,秦九昭,幂法,高斯塞德尔