The Simplest Bayesian Classifier: A MATLAB Implementation


The Bayesian classifier [1] is a theoretically rather simple classifier, but combined with different network structures and probability models it can grow into very complex classification systems. This short article mainly demonstrates how Bayesian + Gaussian solves a two-class problem. By Bayes' rule, the posterior probability of a class y given a sample x is

p(y \mid x) = \frac{p(x \mid y)\, p(y)}{p(x)}
Here, the denominator serves mainly as a normalization factor. p(y) is the prior probability, and p(x|y) is the conditional probability, also called the class-conditional density (i.e., the density of x given which class it belongs to). In this article we assume that p(x|y) is Gaussian, namely [2]:

p(x \mid y=k) = \frac{1}{(2\pi)^{n/2}\,|\Sigma_k|^{1/2}} \exp\!\left( -\frac{1}{2}(x-\mu_k)^{\top} \Sigma_k^{-1} (x-\mu_k) \right), \qquad k \in \{0, 1\}
The prior p(y), in turn, follows a Bernoulli distribution [3]:

p(y=0) = \eta, \qquad p(y=1) = 1 - \eta
The maximum-likelihood estimate of \eta is simply the fraction of the training samples that belong to class 0, i.e. \hat{\eta} = N_0 / N. Once the five parameters (\mu_0, \Sigma_0, \mu_1, \Sigma_1, \eta) have been estimated, the class of an arbitrary sample x can be determined by comparing the posterior probabilities:

f(x) = \log \frac{p(y=0 \mid x)}{p(y=1 \mid x)} = \log \frac{p(x \mid y=0)\, p(y=0)}{p(x \mid y=1)\, p(y=1)}

When f(x) > 0, this means that

p(y=0 \mid x) > p(y=1 \mid x),

so the sample x is assigned to class 0; otherwise it is assigned to class 1.
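
For reference, substituting the Gaussian densities above into f(x) gives the quadratic discriminant that the code below evaluates as the variable c, with the x-independent terms collected in bias (a short derivation not spelled out in the original; the (2\pi)^{-n/2} factors and the denominator p(x) cancel in the ratio):

f(x) = \frac{1}{2}(x-\mu_1)^{\top}\Sigma_1^{-1}(x-\mu_1) - \frac{1}{2}(x-\mu_0)^{\top}\Sigma_0^{-1}(x-\mu_0) + \frac{1}{2}\log\frac{|\Sigma_1|}{|\Sigma_0|} + \log\frac{\eta}{1-\eta}

In the code, the subscript-0 quantities correspond to the "positive" model and the subscript-1 quantities to the "negative" model.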


The following MATLAB program demonstrates this.


Training code:

function [model_pos, model_neg] = FindGuassianModel( x, y )
%FINDGUASSIANMODEL Fit a Gaussian model (mean, covariance and class prior)
%   to each of the two classes by maximum likelihood.
%   x is n-by-N (one sample per column); y is N-by-1 with labels +1 / -1.

% positive class (y == 1)
x_pos = x(:, y==1);
model_pos.mu = mean(x_pos, 2);                 % class mean, n-by-1
model_pos.var = cov(x_pos');                   % class covariance, n-by-n
model_pos.prior = size(x_pos,2)/size(x,2);     % fraction of positive samples

% negative class (y ~= 1)
x_neg = x(:, y~=1);
model_neg.mu = mean(x_neg, 2);
model_neg.var = cov(x_neg');
model_neg.prior = size(x_neg,2)/size(x,2);     % fraction of negative samples

end
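
A minimal usage sketch (the toy data and variable names here are illustrative, not part of the original post): fit the per-class Gaussians on a small synthetic 2-D dataset and inspect the estimated parameters.

% Toy example: 100 samples per class, shifted by -2 / +2 along both axes
Xtoy = [randn(2,100) - 2, randn(2,100) + 2];
ytoy = [ones(100,1); -ones(100,1)];          % labels +1 / -1
[mp, mn] = FindGuassianModel(Xtoy, ytoy);
disp(mp.mu);      % should be close to [-2; -2]
disp(mp.var);     % should be close to the 2-by-2 identity matrix
disp(mp.prior);   % 0.5, since the classes are balanced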


Code for computing the classification error:

function [err, h] = FindModelError( model_pos, model_neg, x, y )
%FINDMODELERROR Classify every column of x with the two Gaussian models
%   and count the misclassified samples. h returns the predicted labels
%   (+1 / -1); err is the number of prediction errors.

mu1 = model_pos.mu;
sigma1 = model_pos.var;
p1 = model_pos.prior;

mu2 = model_neg.mu;
sigma2 = model_neg.var;
p2 = model_neg.prior;

% x-independent part of the log posterior ratio
bias = 0.5*log(det(sigma2)) - 0.5*log(det(sigma1)) + log(p1/p2);
err = 0;
h = zeros(size(y));
for i = 1:length(y)
    % log posterior ratio: positive when the positive class is more likely
    c = bias + 0.5*(x(:,i)-mu2)'/sigma2*(x(:,i)-mu2) ...
             - 0.5*(x(:,i)-mu1)'/sigma1*(x(:,i)-mu1);
    if c > 0
        h(i) = 1;
    else
        h(i) = -1;
    end
    if h(i) ~= y(i)
        err = err + 1;
    end
end

end
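
As a sanity check, the same decision can be written with mvnpdf from the Statistics and Machine Learning Toolbox (assuming that toolbox is available; this snippet is not part of the original post). Comparing the log class-conditional densities plus the log priors reproduces the predictions of FindModelError.

% Equivalent prediction using mvnpdf; x is n-by-N, mvnpdf wants rows as observations
lp_pos = log(mvnpdf(x', model_pos.mu', model_pos.var)) + log(model_pos.prior);
lp_neg = log(mvnpdf(x', model_neg.mu', model_neg.var)) + log(model_neg.prior);
h_alt = ones(size(lp_pos));
h_alt(lp_neg >= lp_pos) = -1;    % pick the class with the larger log posterior
% h_alt should agree with the h returned by FindModelError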

Main demo script:


%%
clc;
clear;
close all;




%% generate random data
shift =3.0;
n = 2;%2 dim
sigma = 1;
N = 500;
x = [randn(n,N/2)-shift, randn(n,N/2)*sigma+shift];
y = [ones(N/2,1);-ones(N/2,1)];


%show the data
figure;
plot(x(1,1:N/2),x(2,1:N/2),'rs');
hold on;
plot(x(1,1+N/2:N),x(2,1+N/2:N),'go');
title('2d training data');
legend('Positive samples','Negative samples','Location','SouthEast');




% model fitting using maximum likelihood
[model_pos,model_neg] = FindGuassianModel(x,y);

%% test on a new dataset drawn from a similar distribution (the negative class now has a larger spread)

n = 2;%2 dim
sigma = 2;
N = 500;
x = [randn(n,N/2)-shift, randn(n,N/2)*sigma+shift];
y = [ones(N/2,1);-ones(N/2,1)];
figure;
plot(x(1,1:N/2),x(2,1:N/2),'rs');
hold on;
plot(x(1,1+N/2:N),x(2,1+N/2:N),'go');
title('2d testing data');
hold on;

%% Gaussian model as a baseline
[err,h] = FindModelError(model_pos,model_neg,x,y);
fprintf('Bayesian error on the test data set: %f\n',err/N);
x_pos = x(:,h==1);
x_neg = x(:,h~=1);
plot(x_pos(1,:),x_pos(2,:),'r.');
hold on;
plot(x_neg(1,:),x_neg(2,:),'g.');
legend('Positive samples','Negative samples','Positive samples as predicted','Negative samples as predicted','Location','SouthEast');
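
To see where the learned boundary lies, one could also evaluate the discriminant on a grid and draw its zero contour (a sketch that reuses the fitted model structs from the workspace; it is not part of the original demo):

% Evaluate f(x) on a grid and overlay the decision boundary f(x) = 0
[gx, gy] = meshgrid(linspace(-8, 8, 200), linspace(-8, 8, 200));
g = [gx(:)'; gy(:)'];                       % grid points as 2-by-M columns
f = zeros(1, size(g,2));
bias = 0.5*log(det(model_neg.var)) - 0.5*log(det(model_pos.var)) ...
     + log(model_pos.prior/model_neg.prior);
for k = 1:size(g,2)
    f(k) = bias + 0.5*(g(:,k)-model_neg.mu)'/model_neg.var*(g(:,k)-model_neg.mu) ...
                - 0.5*(g(:,k)-model_pos.mu)'/model_pos.var*(g(:,k)-model_pos.mu);
end
hold on;
contour(gx, gy, reshape(f, size(gx)), [0 0], 'k');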


Final test results (the figure produced by the script above shows the test scatter plot, with the predicted labels drawn as dots over the true-label markers):


From the test results, most samples are classified correctly (the dots fall inside circles or squares of the same color); only 0.8% of the points are misclassified.


All the code in this article can be downloaded from my resources page: http://download.csdn.net/detail/ranchlai/6018299


[1] http://en.wikipedia.org/wiki/Bayes_classifier

[2] http://en.wikipedia.org/wiki/Multivariate_normal_distribution

[3]http://en.wikipedia.org/wiki/Bernoulli_distribution
