机器学习实战朴素贝叶斯的java实现

来源:互联网 发布:切割大小头怎样编程 编辑:程序博客网 时间:2024/05/29 14:28
package com.haolidong.Bayes;import java.util.ArrayList;/** *  * @author haolidong * @Description: [该类主要用于保存特征信息] * @parameter data: [主要保存特征矩阵] */public class Matrix {public ArrayList<ArrayList<String>> data;public Matrix() {// TODO Auto-generated constructor stubdata = new ArrayList<ArrayList<String>>();}}
package com.haolidong.Bayes;import java.util.ArrayList;/** *  * @author haolidong * @Description: [该类主要用于保存特征信息以及标签值] * @parameter labels: [主要保存标签值] */public class CreateDataSet extends Matrix {public ArrayList<String> labels;public CreateDataSet() {// TODO Auto-generated constructor stubsuper();labels = new ArrayList<String>();}/** * @author haolidong * @Description: [机器学习实战决策树第一个案例的数据] */public void initTest() {ArrayList<String> ab1 = new ArrayList<String>();ArrayList<String> ab2 = new ArrayList<String>();ArrayList<String> ab3 = new ArrayList<String>();ArrayList<String> ab4 = new ArrayList<String>();ArrayList<String> ab5 = new ArrayList<String>();ArrayList<String> ab6 = new ArrayList<String>();ab1.add("my");ab1.add("dog");ab1.add("has");ab1.add("flea");ab1.add("problems");ab1.add("help");ab1.add("please");ab2.add("maybe");ab2.add("not");ab2.add("take");ab2.add("him");ab2.add("to");ab2.add("dog");ab2.add("park");ab2.add("stupid");ab3.add("my");ab3.add("dalmation");ab3.add("is");ab3.add("so");ab3.add("cute");ab3.add("I");ab3.add("love");ab3.add("him");ab4.add("stop");ab4.add("posting");ab4.add("stupid");ab4.add("worthless");ab4.add("garbage");ab5.add("mr");ab5.add("licks");ab5.add("ate");ab5.add("my");ab5.add("steak");ab5.add("how");ab5.add("to");ab5.add("stop");ab5.add("him");ab6.add("quit");ab6.add("buying");ab6.add("worthless");ab6.add("dog");ab6.add("food");ab6.add("stupid");data.add(ab1);data.add(ab2);data.add(ab3);data.add(ab4);data.add(ab5);data.add(ab6);labels.add("0");labels.add("1");labels.add("0");labels.add("1");labels.add("0");labels.add("1");}}
package com.haolidong.Bayes;import java.util.ArrayList;/** *  * @parameter p0Vect 类别0的特征向量(概率向量) * @parameter p1Vect 类别1的特征向量(概率向量) * @parameter pAbusive 正样本(为1的样本)的比例 * @author haolidong   * @Description: [该类主要用于保存特征信息] * @parameter data: [主要保存特征矩阵] */public class TrainNB0DataSet {public ArrayList<Double> p0Vect;public ArrayList<Double> p1Vect;public double pAbusive;public TrainNB0DataSet() {p0Vect = new ArrayList<Double>();p1Vect = new ArrayList<Double>();pAbusive = 0.0;}}

package com.haolidong.Bayes;import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.IOException;import java.util.ArrayList;import java.util.HashSet;public class Bayes {public static void main(String[] args) {spamTest();}/** * @param end  从0到end的范围中产生num个不重复的随机数 * @param num  num个随机数 * @return 返回产生的n个随机数 * @author haolidong * @Description: [从0到end的范围中产生num个不重复的随机数] */public static HashSet<Integer> randomdif(int end,int num){HashSet<Integer> rndint = new HashSet<Integer>();rndint.size();while ( rndint.size() < num ) {rndint.add((int) (Math.random()*end));}return rndint;}/** * @author haolidong * @Description: [垃圾邮件分类测试] */public static void spamTest(){ArrayList<String> fullText = new ArrayList<String>();CreateDataSet DataSet = new CreateDataSet();for (int i = 1; i < 26; i++) {ArrayList<String> hamWordList = new ArrayList<String>();ArrayList<String> spamWordList = new ArrayList<String>();String hamPath = new String("I:\\machinelearninginaction\\Ch04\\email\\ham\\"+i+".txt");String spamPath = new String("I:\\machinelearninginaction\\Ch04\\email\\spam\\"+i+".txt");hamWordList = textParse(spamPath, 2);DataSet.data.add(hamWordList);DataSet.labels.add("1");for (int j = 0; j < hamWordList.size(); j++) {fullText.add(hamWordList.get(j));}spamWordList=textParse(hamPath, 2);DataSet.data.add(spamWordList);DataSet.labels.add("0");for (int j = 0; j < spamWordList.size(); j++) {fullText.add(spamWordList.get(j));}}//获取词典HashSet<String> vocabList = new HashSet<String>();vocabList = createVocabList(DataSet);HashSet<Integer> rndint = new HashSet<Integer>();//随机产生10个测试集,其余的为训练集rndint = randomdif(50,10);Matrix testMatrix = new Matrix();Matrix trainMatrix = new Matrix();ArrayList<String> trainLabels = new ArrayList<String>();ArrayList<String> testLabels = new ArrayList<String>();Matrix testMatrixTrans = new Matrix();Matrix trainMatrixTrans = new Matrix();for(Integer i:rndint){testMatrix.data.add(DataSet.data.get(i));testLabels.add(DataSet.labels.get(i));}for (int i = 0; i < DataSet.data.size(); i++) {if(!rndint.contains(i)){trainMatrix.data.add(DataSet.data.get(i));trainLabels.add(DataSet.labels.get(i));}}//转化到0 1矩阵for (int i = 0; i < trainMatrix.data.size(); i++) {trainMatrixTrans.data.add(setOfWords2Vec(vocabList,trainMatrix.data.get(i)));}for (int i = 0; i < testMatrix.data.size(); i++) {testMatrixTrans.data.add(setOfWords2Vec(vocabList,testMatrix.data.get(i)));}//训练集的训练TrainNB0DataSet td = new TrainNB0DataSet();td = trainNB0(trainMatrixTrans,trainLabels);//对测试集进行测试int errorCount=0;for (int i = 0; i < testMatrixTrans.data.size(); i++) {int num=classifyNB(testMatrixTrans.data.get(i), td.p0Vect, td.p1Vect, td.pAbusive);System.out.println("the predict:"+num+" , the real:"+testLabels.get(i));if(num!=Integer.parseInt(testLabels.get(i))){errorCount++;}}System.out.println("the errorRate is:"+1.0*errorCount/testMatrixTrans.data.size());}public static ArrayList<String> textParse(String fileName,int moreThan){ArrayList<String> strSplitList = new ArrayList<String>();String s = readFile(fileName);strSplitList = extractStrlist(s,moreThan);return strSplitList;}/** * @param fileName  输入的完整文件路径 * @return 所有的文件内容的字符串 * @author haolidong * @Description: [一行一行读取文件,然后用字符串全部串起来返回,每一行之间使用空格分割] */public static String readFile(String fileName) {File file = new File(fileName);BufferedReader reader = null;String s = new String();try {reader = new BufferedReader(new FileReader(file));String tempString = null;// 一次读入一行,直到读入null为文件结束while ((tempString = reader.readLine()) != null) {//加上" "是为了和下面一段的字符进行区分s=s+tempString+" ";}reader.close();} catch (IOException e) {e.printStackTrace();} finally {if (reader != null) {try {reader.close();} catch (IOException e1) {}}}return s;}/** * @param inputString 输入的字符串 * @param moreThan    只有超过moreThan的字符串才会被保留 * @return    分割好的数据串 * @author haolidong * @Description: [读取一个字符串,进行分割,去掉除了字母数字以外的字符数组,而且所有的字符都改成小写] */public static ArrayList<String> extractStrlist(String inputString,int moreThan) {ArrayList<String> strSplitList = new ArrayList<String>();String regEx = "\\W*";String sentence="";//String inputString = "This book is the best book on M.L. I have";String[] predel = inputString.split(regEx);for (int i = 0; i < predel.length; i++) {if(predel[i].equals(""))sentence+=" ";elsesentence+=predel[i];}String[] strSplit=sentence.split(" ");for (int i = 0; i < strSplit.length; i++) {if(strSplit[i].length()>moreThan) {strSplitList.add(strSplit[i].toLowerCase());}}return strSplitList;}/** * @param vec2Classify   需要进行分类的向量 * @param p0Vec          类别0的权值向量 * @param p1Vec          类别1的权值向量 * @param pClass1                            类别1所占的比重 * @return               返回最后的分类结果 * @author haolidong      * @Description: [计算在每一类中最后的概率返回最大的所对应的标签] */public static int classifyNB(ArrayList<String> vec2Classify, ArrayList<Double> p0Vec, ArrayList<Double> p1Vec,double pClass1) {double p1 = 0.0;double p0 = 0.0;for (int i = 0; i < vec2Classify.size(); i++) {p1 = p1 + Double.parseDouble(vec2Classify.get(i)) * p1Vec.get(i);p0 = p0 + Double.parseDouble(vec2Classify.get(i)) * p0Vec.get(i);}p1 = p1 + Math.log(pClass1);p0 = p0 + Math.log(1 - pClass1);if (p1 > p0)return 1;elsereturn 0;}/** * @param trainMatrix      训练矩阵 * @param trainCategory    训练目录标签 * @return                 返回最后训练结果,包括每一类的特征矩阵以及每一类的比重情况 * @author haolidong      * @Description: [贝叶斯分类的重点函数,数据集的训练,返回特征矩阵和向量] */public static TrainNB0DataSet trainNB0(Matrix trainMatrix, ArrayList<String> trainCategory) {int numTrainDocs = trainMatrix.data.size();int numWords = trainMatrix.data.get(0).size();TrainNB0DataSet resultSet = new TrainNB0DataSet();ArrayList<Double> p0Num = new ArrayList<Double>();ArrayList<Double> p1Num = new ArrayList<Double>();double trainCategorySum = 0.0;for (int i = 0; i < trainCategory.size(); i++) {trainCategorySum = trainCategorySum + Double.parseDouble(trainCategory.get(i));}resultSet.pAbusive = trainCategorySum / numTrainDocs;for (int i = 0; i < numWords; i++) {p0Num.add(1.0);p1Num.add(1.0);}double p0Denom = 2.0;double p1Denom = 2.0;for (int i = 0; i < numTrainDocs; i++) {if (trainCategory.get(i).equals("1")) {for (int j = 0; j < numWords; j++) {p1Num.set(j, p1Num.get(j) + Double.parseDouble(trainMatrix.data.get(i).get(j)));}} else {for (int j = 0; j < numWords; j++) {p0Num.set(j, p0Num.get(j) + Double.parseDouble(trainMatrix.data.get(i).get(j)));}}}for (int i = 0; i < numWords; i++) {p0Denom += p0Num.get(i);p1Denom += p1Num.get(i);}p0Denom = p0Denom - numWords;p1Denom = p1Denom - numWords;for (int i = 0; i < numWords; i++) {resultSet.p0Vect.add(Math.log(p0Num.get(i) / p0Denom));resultSet.p1Vect.add(Math.log(p1Num.get(i) / p1Denom));}return resultSet;}/** * @param vocabSet       字典 * @param inputSet       输入数据集 * @return               返回与字典一一对应的数据集 * @author haolidong      * @Description: [生成一个全部为0的字典,把字典中数据集中有的字符串设置为1,其他的设置为0,返回设置完的字典] */public static ArrayList<String> setOfWords2Vec(HashSet<String> vocabSet, ArrayList<String> inputSet) {ArrayList<String> returnVec = new ArrayList<String>();boolean flag;for (String value : vocabSet) {flag = false;for (int i = 0; i < inputSet.size(); i++) {if (inputSet.get(i).equals(value)) {returnVec.add("1");flag = true;break;}}if (flag == false) {returnVec.add("0");}}return returnVec;}/** * @param dataSet    输入数据集 * @return           字典 * @author haolidong      * @Description: [输入数据集,数据有比较大的重复,然后去掉重复的数据,最后生成字典] */public static HashSet<String> createVocabList(Matrix dataSet) {HashSet<String> vocabSet = new HashSet<String>();for (int i = 0; i < dataSet.data.size(); i++) {for (int j = 0; j < dataSet.data.get(i).size(); j++) {vocabSet.add(dataSet.data.get(i).get(j));}}return vocabSet;}/** * @author haolidong      * @Description: [对于生成字典功能的测试] */public static void testVocabList() {CreateDataSet dataSet = new CreateDataSet();dataSet.initTest();HashSet<String> vocabSet = new HashSet<String>();vocabSet = createVocabList(dataSet);System.out.println(vocabSet);}/** * @author haolidong      * @Description: [对于输入字符集转化成字典的测试] */public static void testWord2Vec() {CreateDataSet dataSet = new CreateDataSet();dataSet.initTest();HashSet<String> vocabSet = new HashSet<String>();ArrayList<String> returnVec = new ArrayList<String>();vocabSet = createVocabList(dataSet);returnVec = setOfWords2Vec(vocabSet, dataSet.data.get(0));System.out.println(returnVec);}/** * @author haolidong      * @Description: [对于样本训练的测试] */public static void testTrain() {CreateDataSet dataSet = new CreateDataSet();Matrix trainMatrix = new Matrix();dataSet.initTest();HashSet<String> vocabSet = new HashSet<String>();vocabSet = createVocabList(dataSet);for (int i = 0; i < dataSet.data.size(); i++) {trainMatrix.data.add(setOfWords2Vec(vocabSet, dataSet.data.get(i)));}trainNB0(trainMatrix, dataSet.labels);}/** * @author haolidong      * @Description: [对于样本分类的测试] */public static void testingNB() {CreateDataSet dataSet = new CreateDataSet();TrainNB0DataSet td = new TrainNB0DataSet();ArrayList<String> testEntry = new ArrayList<String>();Matrix trainMatrix = new Matrix();dataSet.initTest();HashSet<String> vocabSet = new HashSet<String>();vocabSet = createVocabList(dataSet);for (int i = 0; i < dataSet.data.size(); i++) {trainMatrix.data.add(setOfWords2Vec(vocabSet, dataSet.data.get(i)));}td = trainNB0(trainMatrix, dataSet.labels);testEntry.add("love");testEntry.add("my");testEntry.add("dalmation");testEntry = setOfWords2Vec(vocabSet, testEntry);System.out.println("classified as:"+classifyNB(testEntry,td.p0Vect,td.p1Vect,td.pAbusive));testEntry.clear();testEntry.add("stupid");testEntry.add("garbage");testEntry = setOfWords2Vec(vocabSet, testEntry);System.out.println("classified as:"+classifyNB(testEntry,td.p0Vect,td.p1Vect,td.pAbusive));}}




0 0