朴素贝叶斯算法Java 实现

来源:互联网 发布:电脑图片查看软件 编辑:程序博客网 时间:2024/06/04 17:40

对于朴素贝叶斯算法相信做数据挖掘和推荐系统的小伙们都耳熟能详了,算法原理我就不啰嗦了。我主要想通过java代码实现朴素贝叶斯算法,思想:

1. 用javabean +Arraylist 对于训练数据存储

2. 对于样本数据训练

具体的代码如下:

package NB;/** * 训练样本的属性 javaBean * */public class JavaBean {  int age;  String income;  String student;  String credit_rating;  String buys_computer; public JavaBean(){  }public JavaBean(int age,String income,String student,String credit_rating,String buys_computer){this.age=age;this.income=income;this.student=student;this.credit_rating=credit_rating;this.buys_computer=buys_computer;}    public int getAge() {return age;}public void setAge(int age) {this.age = age;}public String getIncome() {return income;}public void setIncome(String income) {this.income = income;}public String getStudent() {return student;}public void setStudent(String student) {this.student = student;}public String getCredit_rating() {return credit_rating;}public void setCredit_rating(String credit_rating) {this.credit_rating = credit_rating;}public String getBuys_computer() {return buys_computer;}public void setBuys_computer(String buys_computer) {this.buys_computer = buys_computer;}@Overridepublic String toString() {return "JavaBean [age=" + age + ", income=" + income + ", student="+ student + ", credit_rating=" + credit_rating + ", buys_computer="+ buys_computer + "]";}        }
算法实现的部分:

package NB;import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.util.ArrayList;public class TestNB {/**data_length * 算法的思想 */public static  ArrayList<JavaBean> list = new ArrayList<JavaBean>();;static int data_length=0;public static void main(String[] args) {// 1.读取数据,放入list容器中File file = new File("E://test.txt");txt2String(file);//数据测试样本testData(25,"Medium","Yes","Fair");}    // 读取样本数据public static void txt2String(File file) {try {BufferedReader br = new BufferedReader(new FileReader(file));// 构造一个BufferedReader类来读取文件String s = null;while ((s = br.readLine()) != null) {// 使用readLine方法,一次读一行data_length++; splitt(s);}br.close();} catch (Exception e) {e.printStackTrace();}}// 存入ArrayList中  public static void splitt(String str){           String strr = str.trim();        String[] abc = strr.split("[\\p{Space}]+");        int age=Integer.parseInt(abc[0]);        JavaBean bean=new JavaBean(age, abc[1], abc[2], abc[3], abc[4]);        list.add(bean);                  }  // 训练样本,测试  public static void testData(int age,String a,String b,String c){  //训练样本    int number_yes=0;  int bumber_no=0;   // age情况 个数  int num_age_yes=0;  int num_age_no=0;  // income   int num_income_yes=0;  int num_income_no=0;  // student   int num_student_yes=0;  int num_stdent_no=0;  //credit  int num_credit_yes=0;  int num_credit_no=0;    //遍历List 获得数据  for(int i=0;i<list.size();i++){    JavaBean bb=list.get(i);    if(bb.getBuys_computer().equals("Yes")){ //Yes    number_yes++;            if(bb.getIncome().equals(a)){//income            num_income_yes++;            }    if(bb.getStudent().equals(b)){//student    num_student_yes++;    }    if(bb.getCredit_rating().equals(c)){//credit    num_credit_yes++;    }    if(bb.getAge()==age){//age    num_age_yes++;    }            }else {//No    bumber_no++;    if(bb.getIncome().equals(a)){//income            num_income_no++;            }    if(bb.getStudent().equals(b)){//student    num_stdent_no++;    }    if(bb.getCredit_rating().equals(c)){//credit    num_credit_no++;    }    if(bb.getAge()==age){//age    num_age_no++;    }    }    }      System.out.println("购买的历史个数:"+number_yes);    System.out.println("不买的历史个数:"+bumber_no);        System.out.println("购买+age:"+num_age_yes);    System.out.println("不买+age:"+num_age_no);        System.out.println("购买+income:"+num_income_yes);    System.out.println("不买+income:"+num_income_no);        System.out.println("购买+stundent:"+num_student_yes);    System.out.println("不买+student:"+num_stdent_no);        System.out.println("购买+credit:"+num_credit_yes);    System.out.println("不买+credit:"+num_credit_no);        //// 概率判断    double buy_yes=number_yes*1.0/data_length; // 买的概率double buy_no=bumber_no*1.0/data_length; //  不买的概率    System.out.println("训练数据中买的概率:"+buy_yes);    System.out.println("训练数据中不买的概率:"+buy_no);/// 未知用户的判断    double nb_buy_yes=(1.0*num_age_yes/number_yes)*(1.0*num_income_yes/number_yes)*(1.0*num_student_yes/number_yes)*(1.0*num_credit_yes/number_yes)*buy_yes;           double nb_buy_no=(1.0*num_age_no/bumber_no)*(1.0*num_income_no/bumber_no)*(1.0*num_stdent_no/bumber_no)*(1.0*num_credit_no/bumber_no)*buy_no;           System.out.println("新用户买的概率:"+nb_buy_yes);    System.out.println("新用户不买的概率:"+nb_buy_no);    if(nb_buy_yes>nb_buy_no){    System.out.println("新用户买的概率大");    }else {    System.out.println("新用户不买的概率大");}      }  }

对于样本数据:

25  High    No  Fair       No25  High    No  Excellent  No33  High    No  Fair       Yes41  Medium  No  Fair       Yes     41  Low     Yes Fair       Yes41  Low     Yes Excellent  No33  Low     Yes Excellent  Yes25  Medium  No  Fair       No25  Low     Yes Fair       Yes41  Medium  Yes Fair       Yes25  Medium  Yes Excellent  Yes33  Medium  No  Excellent  Yes33  High    Yes Fair       Yes41  Medium  No  Excellent  No

对于未知用户的数据得出的结果:

购买的历史个数:9不买的历史个数:5购买+age:2不买+age:3购买+income:4不买+income:2购买+stundent:6不买+student:1购买+credit:6不买+credit:2训练数据中买的概率:0.6428571428571429训练数据中不买的概率:0.35714285714285715新用户买的概率:0.028218694885361547新用户不买的概率:0.006857142857142858新用户买的概率大








0 1