利用朴素贝叶斯算法识别垃圾邮件

来源:互联网 发布:淘宝助理怎么导入宝贝 编辑:程序博客网 时间:2024/06/06 01:11

转载自:http://blog.csdn.net/wowcplusplus/article/details/25190809

朴素贝叶斯算法是被工业界广泛应用的机器学习算法,它有较强的数学理论基础,在一些典型的应用中效果显著。朴素贝叶斯算法基于概率论的贝叶斯理论。该理论的核心公式如下:

                                                      

式中,表示某种分类,则表示已知的情况下类型为的条件概率。我们求出各个类别下的,然后比较它们的大小,以概率最大的作为最后的类别,以此达到分类的目的。下面我们来看如何计算这些条件概率。

        已知,则。朴素贝叶斯假定互为独立变量,则。而为指示函数,存在则为1,不存在则为0),都可用训练数据直接统计得出。故可依据上述分析求得的大小。又由于对所有类别都是固定大小,所以比较条件概率的大小等同于比较的大小。这就是朴素贝叶斯的数学原理。

        下面,我们以朴素贝叶斯的一个典型应用——过滤垃圾邮件来展示该算法的python实现。

        现在,我们有25件垃圾邮件和25件正常邮件,如何使用这些邮件作为训练数据得到过滤垃圾邮件的朴素贝叶斯模型呢?首先,我们用各邮件的词组成词向量表示在邮件中出现过的词,再计算孰大孰小即可决定这是封正常邮件(ham)还是封垃圾邮件(spam)。

        第一步,我们将邮件转换为numpy的array形式,使用如下函数:

[python] view plaincopy在CODE上查看代码片派生到我的代码片
  1. def file2array(filename):  
  2.     fileReader=open(filename,'r').read()  
  3.     listOfWord=re.split(r'\W*',fileReader)  
  4.     fileArray=[word.lower() for word in listOfWord if len(word)>3]  
  5.     return fileArray  
        然后我们将所有邮件合在一个array里面,并在其中随机选取5封作为测试集:

[python] view plaincopy在CODE上查看代码片派生到我的代码片
  1. def getAllInfo():  
  2.     allTextMat=[]  
  3.     allTypeArray=[]  
  4.     testMat=[]  
  5.     testTypeArray=[]  
  6.     for i in range(1,26):  
  7.         allTextMat.append(file2array('email/spam/%d.txt'%i))  
  8.         allTypeArray.append(1)  
  9.         allTextMat.append(file2array('email/ham/%d.txt'%i))  
  10.         allTypeArray.append(0)  
  11.     for i in range(5):  
  12.         randIndex=int(random.uniform(0,len(allTextMat)))  
  13.         testMat.append(allTextMat[randIndex])  
  14.         testTypeArray.append(allTypeArray[randIndex])  
  15.         del(allTextMat[randIndex])  
  16.         del(allTypeArray[randIndex])  
  17.     return allTextMat,allTypeArray,testMat,testTypeArray  
        接着我们计算所有词的出现次数和在垃圾邮件、正常邮件中分别出现的次数:

[python] view plaincopy在CODE上查看代码片派生到我的代码片
  1. def getWordList(allTextMat):  
  2.     wordSet=set()  
  3.     for textVec in allTextMat:  
  4.         wordSet|=set(textVec)  
  5.     return list(wordSet)  
  6.       
  7. def getCountList(wordList,allTextMat,allTypeArray):  
  8.     wordListLen=len(wordList)  
  9.     totalCntList=ones(wordListLen)  
  10.     totalCntList*=2  
  11.     p0CntList=ones(wordListLen)  
  12.     p1CntList=ones(wordListLen)  
  13.     order=0  
  14.     p0Cnt=0  
  15.     p1Cnt=0  
  16.     for textVec in allTextMat:  
  17.         for word in textVec:  
  18.             wordPos=wordList.index(word)  
  19.             totalCntList[wordPos]+=1  
  20.             if allTypeArray[order]==1:  
  21.                 p1CntList[wordPos]+=1  
  22.                 p1Cnt+=1  
  23.             elif allTypeArray[order]==0:  
  24.                 p0CntList[wordPos]+=1  
  25.                 p0Cnt+=1  
  26.         order+=1  
  27.     p0=float(p0Cnt)/(p0Cnt+p1Cnt)  
  28.     p1=1-p0  
  29.     return totalCntList,p0CntList,p1CntList,p0,p1  
        最后我们进行贝叶斯分类:
[python] view plaincopy在CODE上查看代码片派生到我的代码片
  1. def bayesClassify(testMat,testTypeArray,totalCntList,p0CntList,p1CntList,p0,p1,wordList):  
  2.     docIndex=0  
  3.     errorCnt=0  
  4.     for testVec in testMat:  
  5.         sum0=0.0  
  6.         sum1=0.0  
  7.         for word in testVec:  
  8.             if word not in wordList:  
  9.                 continue  
  10.             wordPos=wordList.index(word)  
  11.             sum0+=log(float(p0CntList[wordPos]/totalCntList[wordPos]))  
  12.             sum1+=log(float(p1CntList[wordPos]/totalCntList[wordPos]))  
  13.         sum0+=log(p0)  
  14.         sum1+=log(p1)  
  15.         decType=0 if sum0>sum1 else 1  
  16.         if decType != testTypeArray[docIndex]:  
  17.             errorCnt+=1  
  18.             print sum0,sum1  
  19.         docIndex+=1  
  20.     return errorCnt  
        整体的代码如下:
[python] view plaincopy在CODE上查看代码片派生到我的代码片
  1. import os  
  2. import re  
  3. from numpy import *  
  4.   
  5. def file2array(filename):  
  6.     fileReader=open(filename,'r').read()  
  7.     listOfWord=re.split(r'\W*',fileReader)  
  8.     fileArray=[word.lower() for word in listOfWord if len(word)>3]  
  9.     return fileArray  
  10.       
  11. def getAllInfo():  
  12.     allTextMat=[]  
  13.     allTypeArray=[]  
  14.     testMat=[]  
  15.     testTypeArray=[]  
  16.     for i in range(1,26):  
  17.         allTextMat.append(file2array('email/spam/%d.txt'%i))  
  18.         allTypeArray.append(1)  
  19.         allTextMat.append(file2array('email/ham/%d.txt'%i))  
  20.         allTypeArray.append(0)  
  21.     for i in range(5):  
  22.         randIndex=int(random.uniform(0,len(allTextMat)))  
  23.         testMat.append(allTextMat[randIndex])  
  24.         testTypeArray.append(allTypeArray[randIndex])  
  25.         del(allTextMat[randIndex])  
  26.         del(allTypeArray[randIndex])  
  27.     return allTextMat,allTypeArray,testMat,testTypeArray  
  28.       
  29. def getWordList(allTextMat):  
  30.     wordSet=set()  
  31.     for textVec in allTextMat:  
  32.         wordSet|=set(textVec)  
  33.     return list(wordSet)  
  34.       
  35. def getCountList(wordList,allTextMat,allTypeArray):  
  36.     wordListLen=len(wordList)  
  37.     totalCntList=ones(wordListLen)  
  38.     totalCntList*=2  
  39.     p0CntList=ones(wordListLen)  
  40.     p1CntList=ones(wordListLen)  
  41.     order=0  
  42.     p0Cnt=0  
  43.     p1Cnt=0  
  44.     for textVec in allTextMat:  
  45.         for word in textVec:  
  46.             wordPos=wordList.index(word)  
  47.             totalCntList[wordPos]+=1  
  48.             if allTypeArray[order]==1:  
  49.                 p1CntList[wordPos]+=1  
  50.                 p1Cnt+=1  
  51.             elif allTypeArray[order]==0:  
  52.                 p0CntList[wordPos]+=1  
  53.                 p0Cnt+=1  
  54.         order+=1  
  55.     p0=float(p0Cnt)/(p0Cnt+p1Cnt)  
  56.     p1=1-p0  
  57.     return totalCntList,p0CntList,p1CntList,p0,p1  
  58.       
  59. def bayesClassify(testMat,testTypeArray,totalCntList,p0CntList,p1CntList,p0,p1,wordList):  
  60.     docIndex=0  
  61.     errorCnt=0  
  62.     for testVec in testMat:  
  63.         sum0=0.0  
  64.         sum1=0.0  
  65.         for word in testVec:  
  66.             if word not in wordList:  
  67.                 continue  
  68.             wordPos=wordList.index(word)  
  69.             sum0+=log(float(p0CntList[wordPos]/totalCntList[wordPos]))  
  70.             sum1+=log(float(p1CntList[wordPos]/totalCntList[wordPos]))  
  71.         sum0+=log(p0)  
  72.         sum1+=log(p1)  
  73.         decType=0 if sum0>sum1 else 1  
  74.         if decType != testTypeArray[docIndex]:  
  75.             errorCnt+=1  
  76.             print sum0,sum1  
  77.         docIndex+=1  
  78.     return errorCnt  
  79.       
  80. def main():  
  81.     allTextMat,allTypeArray,testMat,testTypeArray=getAllInfo()  
  82.     wordList=getWordList(allTextMat)  
  83.     totalCntList,p0CntList,p1CntList,p0,p1=getCountList(wordList,allTextMat,allTypeArray)  
  84.     print bayesClassify(testMat,testTypeArray,totalCntList,p0CntList,p1CntList,p0,p1,wordList)  
  85.       
  86. if __name__=='__main__':  
  87.     main()  
        我使用的邮件数据来源于《机器学习实战》第四章。感兴趣的同学可以去它官网http://www.manning.com/pharrington/下载数据集。

        以上就是贝叶斯算法的基本介绍。作为本系列的开篇之作,我在表述上可能会有不当之处,还请各位同学在评论中指正。


0 0
原创粉丝点击