朴素贝叶斯实现

来源:互联网 发布:怎么查看淘宝店铺扣分 编辑:程序博客网 时间:2024/05/16 04:52

关于朴素贝叶斯的理论理解起来不难,且网上有很多这方面的资源,但是关于朴素贝叶斯的实现代码太少了。这里结合朴素贝叶斯在文本分类中的应用,实现一个朴素贝叶斯分类器,并用这个分类器来预测文本分类结果。


package bayes;


import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Vector;




/**
* 朴素贝叶斯分类器
*/
public class BayesClassifier 
{
private TrainingDataManager tdm; //训练集管理器
private static double zoomFactor = 10.0f;

public BayesClassifier() 
{
tdm =new TrainingDataManager();
}

/**
* 对给定的文本进行分类
* @param text 给定的文本
* @return 分类结果
*/
@SuppressWarnings("unchecked")
public String classify(String text) 
{
String[] terms = null;
terms=ChineseSpliter.split(text, "|").split("\\|");//中文分词处理(分词后结果可能还包含有停用词)
terms = StopWordsHandler.DropStopWords(terms);//去掉停用词,以免影响分类

System.out.println("正在计算...");
String[] Classes = tdm.getTraningClassifications();//分类
float probility = 0.0F;
List<ClassifyResult> crs = new ArrayList<ClassifyResult>();//分类结果
for (int i=1;i<Classes.length;i++) 
{
String Ci = Classes[i];//第i个分类
probility = ClassifyProbability.calcProd(terms, Ci);//计算给定的文本属性集合terms在给定的分类Ci中的分类条件概率
//保存分类结果
ClassifyResult cr = new ClassifyResult();
cr.classification = Ci;//分类
cr.probility = probility;//关键字在分类的条件概率
crs.add(cr);
}

//对最后概率结果进行排序
java.util.Collections.sort(crs,new Comparator() 
{
public int compare(final Object o1,final Object o2) 
{
final ClassifyResult m1 = (ClassifyResult) o1;
final ClassifyResult m2 = (ClassifyResult) o2;
final double ret = m1.probility - m2.probility;
if (ret < 0) 
{
return 1;

else 
{
return -1;
}
}
});
//返回概率最大的分类
return crs.get(0).classification;
}
}



package bayes;


import java.io.IOException;
import jeasy.analysis.MMAnalyzer;  


/**
* 中文分词器
*/
public class ChineseSpliter 
{
/**
* 对给定的文本进行中文分词
* @param text 给定的文本
* @param splitToken 用于分割的标记,如"|"
* @return 分词完毕的文本
*/
public static String split(String text,String splitToken)
{
String result = null;
MMAnalyzer analyzer = new MMAnalyzer();  
try  
        {
result = analyzer.segment(text, splitToken);
}  
        catch (IOException e)  
        {
        e.printStackTrace();
        }
        return result;
}
}


package bayes;


public class ClassConditionalProbability {


private static TrainingDataManagertdm = new TrainingDataManager();

/**

* 计算类条件概率

* @param x 给定的文本属性

* @param c 给定的分类

* @return 给定条件下的类条件概率

*/

public staticfloat calculatePxc(String x, Stringc

{

float ret = 0F;

float Nxc = tdm.getCountContainKeyOfClassification(c,x);//这里的计算效果并不好,只是统计了单词在某篇文章中是否出现了

float Nc = tdm.getTrainingFileCountOfClassification(c);

float V = tdm.getTraningClassifications().length;

ret = (Nxc+1) / (Nc+V);

returnret;

}

}



package bayes;


public class Classification {


public staticvoid main(String[] args) {

// TODO Auto-generated method stub


String text ="据新德里电视台报道,印度内政部一名消息人士透露说,中国近期一直在印中巡逻线附近修建一 座观察哨所,印度对此持反对态度。"+

"11日,印度军队和印藏边防警察部队派人越界拆毁了中国在建的哨所,双方军队遂在这一地区发生对峙。该报道称,这条巡逻线为双方所接受,"+

"中国所修建的哨所位于巡逻线中国一侧";

//String text1="据新德里电视台报道,印度内政部一名消息人士透露说";

BayesClassifier classifier =new BayesClassifier();//构造Bayes分类器

String result = classifier.classify(text);//进行分类

System.out.println("此项属于["+result+"]");

}


}



package bayes;


public class ClassifyProbability {


private staticfloat zoomFactor = 10.0f;

/**

* 计算某篇文档X属于类别Cj得概率,即求p(X|Cj)*p(Cj)的值

* @param X 给定的文本属性集合

* @param Cj 给定的类别

* @return 分类条件概率连乘值

*/

public staticfloat calcProd(String[] X, StringCj)

{

float ret = 1.0f;

// 类条件概率连乘

for (inti=0;i<X.length;i++)

{

//计算p(X|Cj)值

ret *=ClassConditionalProbability.calculatePxc(X[i],Cj)*zoomFactor;

}

// 再乘以先验概率

ret *= PriorProbability.calculatePc(Cj);

returnret;

}

}



package bayes;


/**

* 分类结果

*/

public class ClassifyResult {


public doubleprobility;//分类的概率

public Stringclassification;//分类

public ClassifyResult()

{

this.probility = 0;

this.classification =null;

}

}



package bayes;


public class ClassPriorProbability {


private static TrainingDataManagertdm = new TrainingDataManager();

private staticfinal float M = 0F;

/**

* 计算类条件概率

* @param x 给定的文本属性

* @param c 给定的分类

* @return 给定条件下的类条件概率

*/

public staticfloat calculatePxc(String x, Stringc

{

float ret = 0F;

float Nxc = tdm.getCountContainKeyOfClassification(c,x);

float Nc = tdm.getTrainingFileCountOfClassification(c);

float V = tdm.getTraningClassifications().length;

ret = (Nxc + 1) / (Nc +M + V);

returnret;

}

}



package bayes;


/**

 * 计算先验概率

 * @author wtj

 */


public class PriorProbability 

{

private static TrainingDataManagertdm =new TrainingDataManager();


/**

* 先验概率p=类c下文本数量/训练集的文本数量

* @param c 给定的分类

* @return 给定条件下的先验概率

*/

public staticfloat calculatePc(String c)

{

float ret = 0F;

float Nc = tdm.getTrainingFileCountOfClassification(c);

float N = tdm.getTrainingFileCount();

ret = Nc / N;

returnret;

}

}



package bayes;


import java.util.Vector;


/**

 * 停用词处理

 * @author wtj

 *

 */

public class StopWordsHandler {


privatestatic String[] stopSordList={"|","的","我们","要","自己","之","将","“","”",",","(",")","后","应","到",

"某","后","个","是","位","新","一","两","在","中","或","有","更","好",""};

//停用词判断

public staticboolean isStopWord(String word)

{

for (inti = 0; i < stopSordList.length;i++) 

{

if (word.equalsIgnoreCase(stopSordList[i])) 

{

returntrue;

}

}

returnfalse;

}

/**

* 去掉停用词

* @param text 给定的文本

* @return 去停用词后结果

*/

public static String[] DropStopWords(String[]oldWords)

{

Vector<String> v1 =new Vector<String>();

for(inti=0;i<oldWords.length;++i)

{

if(StopWordsHandler.isStopWord(oldWords[i])==false)

{

v1.add(oldWords[i]);

}

}

String[] newWords =new String[v1.size()];

v1.toArray(newWords);

returnnewWords;

}

}



package bayes;


import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Properties;
import java.util.logging.Level;
import java.util.logging.Logger;


/**
* 训练集管理器
*/


public class TrainingDataManager 
{
private String[] traningFileClassifications;//训练语料分类集合
private File traningTextDir;//训练语料存放目录,因为linux/Unix下将一切看作文件
private static String defaultPath = "/Users/wtj/Documents/workspace/www/trainingData";

public TrainingDataManager() 
{
traningTextDir = new File(defaultPath);
if (!traningTextDir.isDirectory()) 
{
throw new IllegalArgumentException("训练语料库搜索失败! [" +defaultPath + "]");
}
traningFileClassifications = traningTextDir.list();//构造函数就是用来初始化定义的成员
}


/**
* 返回训练文本类别,这个类别就是目录名
* @return 训练文本类别
*/
public String[] getTraningClassifications() 
{
return traningFileClassifications;
}

/**
* 根据训练文本类别返回这个类别下的所有训练文本路径(full path)
* @param classification 给定的分类
* @return 给定分类下所有文件的路径(full path)
*/
public String[] getFilesPath(String classification) 
{
File classDir = new File(traningTextDir.getPath() +File.separator +classification);
String[] ret = classDir.list();
for (int i=0;i<ret.length;i++) 
{
ret[i] = traningTextDir.getPath() +File.separator +classification +File.separator +ret[i];
}
return ret;
}


/**
* 计算某个目录下的文件个数
* @param classification 目录名
* @return 该目录下的文件个数
*/
public int getTrainingFileCountOfClassification(String classification)
{
File classDir = new File(traningTextDir.getPath() +File.separator +classification);
return classDir.list().length;
}

/**
* 返回训练文本集中所有的文本数目
* @return 训练文本集中所有的文本数目
*/
public int getTrainingFileCount()
{
int ret = 0;
for (int i = 1; i <traningFileClassifications.length; i++)
{
ret +=getTrainingFileCountOfClassification(traningFileClassifications[i]);
}
return ret;
}

/**
* 返回给定路径的文本文件内容
* @param filePath 给定的文本文件路径
* @return 文本内容
* @throws java.io.FileNotFoundException
* @throws java.io.IOException
*/
public static String getText(String filePath) throws FileNotFoundException,IOException 
{

InputStreamReader isReader =new InputStreamReader(new FileInputStream(filePath),"GBK");
BufferedReader reader = new BufferedReader(isReader);
String aline;
StringBuilder sb = new StringBuilder();

while ((aline = reader.readLine()) != null)
{
sb.append(aline + " ");
}
isReader.close();
reader.close();
return sb.toString();
}





/**
* 返回给定分类中包含关键字/词的训练文本的数目
* @param classification 给定的分类
* @param key 给定的关键字/词
* @return 给定分类中包含关键字/词的训练文本的数目
*/
public int getCountContainKeyOfClassification(String classification,String key) 
{
int ret = 0;
try 
{
String[] filePath = getFilesPath(classification);
for (int j = 0; j < filePath.length; j++) 
{
String text = getText(filePath[j]);
if (text.contains(key)) //这里这样做并不好,实际统计出词频更好
{
ret++;
}
}
}
catch (FileNotFoundException ex) 
{
Logger.getLogger(TrainingDataManager.class.getName()).log(Level.SEVERE, null,ex);


catch (IOException ex)
{
Logger.getLogger(TrainingDataManager.class.getName()).log(Level.SEVERE, null,ex);

}
return ret;
}
}

0 0
原创粉丝点击