【模式识别】贝叶斯分类器的C++实现

来源:互联网 发布:java线程yield 编辑:程序博客网 时间:2024/05/29 16:28
  • 分类问题

    在介绍贝叶斯分类器之前我们首先需要知道分类问题是什么。

    给你一个训练数据集,数据集中有一些样本,每个样本有一个或若干特征x(这里x的是向量),每个样本都属于一个类别wi,现在让你通过这个训练集得到一个分类器,这个分类器具有这样的功能:

    给定一个新的样本(也就是给定一个特征x),你的分类器能判断它属于哪个类别

    其实我们可以把分类器看成是一个函数f(x),函数的返回值就是对应特征所属的类别,所有的分类器的目的都是去寻找一个比较好的函数能更好的描述特征与类别之间的关系。

  • 贝叶斯定理:

    P(wi|x)=P(x|wi)P(wi)kj=1P(x|wj)P(wj)

    P(wi)是,表示在不知道样本特征的情况下,某个样本属于wi类的概率。

    P(x|wi)是类条件概率密度,可以看成是在某一类别的情况下特征的分布(概率密度函数)

    P(wi|x)是后验概率,就是在知道样本的特征的情况下该样本属于某一个特征的概率。

  • 贝叶斯分类器的设计思路

    一句话概括贝叶斯分类器:

    在知道先验概率和类条件概率密度的情况下算后验概率,后验概率最大的类别作为最终类别

    我们的问题是”给定特征判断这个特征所对应的类别”,一个容易想到的思路是算出这个特征属于每个类别的概率,然后取最大的那个类别作为最终的分类。其中”属于某个类别的概率”可以表示成条件概率的形式P(wi|x),也就是贝叶斯定理中的后验概率,接下来就是去求得P(x|wi)P(wi)就行了.

  • 例子

    将人的身高(height)和体重(weight)作为分类特征,分为两组类别,男性(male)和女性(female),现在给你一组男性的身高和体重的数据和女性的身高体重数据作为训练数据集训练出一个贝叶斯分类器。

    男性的数据保存在MALE.TXT,包含若干行数据,每行两个数x1 x2分别表示身高和体重

    女性的数据类似

    假设二者相关,在正态分布假设下估计概率密度函数

    由于假设了两个特征满足正态分布,所以我们可以通过μ1μ2θ1θ2ρ五个参数确定某个类别的样本分布,也就是类条件概率密度P(x|wi)

  • 代码实现

#include<iostream>#include<fstream>#include<cmath>#include<cstdio>using namespace std;const int MAXN=1000;const double pi=3.1415926;ifstream cin1("FEMALE.TXT");ifstream cin2("MALE.TXT");ifstream cin3("test2.txt");ofstream cout1("result.txt");struct HUMAN{    double height;    double weight;};HUMAN female[MAXN];HUMAN male[MAXN];int female_num;int male_num;double P_female;double P_male;struct NORMAL{  double mu1;  double mu2;  double delta1;  double delta2;  double rho;};NORMAL female_normal;NORMAL male_normal;/*读入文件数据*/void In(){    male_num=0;    female_num=0;    while(cin1>>female[female_num+1].height>>female[female_num+1].weight)    {        female_num++;    }    while(cin2>>male[male_num+1].height>>male[male_num+1].weight)    {        male_num++;    }}void Init(){      P_female=0.5;      P_male=0.5;}/*读入样本数量个样本,并求出该样本的二维正态分布*/void Normalization(struct HUMAN *human,int human_num,struct NORMAL &human_normal){    double mu1,mu2,delta1,delta2,rho;    mu1=0;mu2=0;delta1=0;delta2=0,rho=0;    for(int i=1;i<=human_num;i++)    {          mu1+=human[i].height;          mu2+=human[i].weight;    }    mu1/=human_num;    mu2/=human_num;    for(int i=1;i<=human_num;i++)    {          delta1+=(human[i].height-mu1)*(human[i].height-mu1);          delta2+=(human[i].weight-mu2)*(human[i].weight-mu2);    }    delta1/=human_num;    delta2/=human_num;    delta1=sqrt(delta1);    delta2=sqrt(delta2);    for(int i=1;i<=human_num;i++)    {          rho+=human[i].height*human[i].weight;    }    rho/=human_num;    rho-=mu1*mu2;    rho/=(delta1*delta2);    human_normal.mu1=mu1;    human_normal.mu2=mu2;    human_normal.delta1=delta1;    human_normal.delta2=delta2;    human_normal.rho=rho;    cout<<mu1<<" "<<delta1<<" "<<mu2<<" "<<delta2<<" "<<rho<<endl;}/*在分布为normal的条件下特征为(x1,x2)的条件概率*/double P(struct NORMAL &normal,double x1,double x2){    double ans;    double mu1=normal.mu1;    double mu2=normal.mu2;    double delta1=normal.delta1;    double delta2=normal.delta2;    double rho=normal.rho;     rho=0;    ans=(1/(2*pi*delta1*delta2*        sqrt(1-rho*rho)            ))*exp(-1/(2*sqrt(1-rho*rho)   )*(             ((x1-mu1)*(x1-mu1))/(delta1*delta1)   +   ((x2-mu2)*(x2-mu2))/(delta2*delta2)   -    (2*rho*(x1-mu1)*(x2-mu2))/(delta1*delta2)                                   )   );    return ans;}/*归为normal的后验概率t为0表示female,1表示male*/double Posterior_probability1(double x1,double x2,bool t){    double Pw;    struct NORMAL normal;    if(t==0)return (P(female_normal,x1,x2)*P_female)/(P(female_normal,x1,x2)*P_female+P(male_normal,x1,x2)*P_male);    else return (P(male_normal,x1,x2)*P_male)/(P(female_normal,x1,x2)*P_female+P(male_normal,x1,x2)*P_male);}/*判断是哪个类别,返回0表示female,1表示male*/bool Classify(double x1,double x2){    //cout<<Posterior_probability1(x1,x2,0)<<" "<<Posterior_probability1(x1,x2,1)<<endl;    if(Posterior_probability1(x1,x2,0)>=Posterior_probability1(x1,x2,1))return 0;    else return 1;}/*得到错误率并将错误率输出到result.txt中*/void Find_error_rate(){    double height,weight;    char c;    int right_num=0;    int wrong_num=0;    while(cin3>>height>>weight>>c)    {        if(   (c=='f' || c=='F')   && Classify(height,weight)==0)//分类为女性并且正确             right_num++;        else if(   (c=='m' || c=='M')   && Classify(height,weight)==1)//分类为男性并且正确             right_num++;        else            wrong_num++;        cout1<<height<<" "<<weight<<" "<<Classify(height,weight)<<endl;    }    cout<<"error rate is "<<(double)wrong_num/(double)(wrong_num+right_num)<<endl;}int main(){    In();    Init();    Normalization(female,female_num,female_normal);    Normalization(male,male_num,male_normal);    //female_normal.rho=0;    //male_normal.rho=0;    Find_error_rate();    return 0;}
原创粉丝点击