斯坦福机器学习笔记1:GDA高斯判别分析算法的原理及matlab程序实现

来源:互联网 发布:天刀精致女性捏脸数据 编辑:程序博客网 时间:2024/06/04 18:17

ps:我本身没有系统的学过matlab编程,所以有的方法,比如求均值用mean()函数之类的方法都是用很笨的方法实现的,所以有很多需要改进的地方,另外是自学实现的程序,可能有的地方我理解错误,如果有错误请提出来,大家一起学习微笑,本人qq553566286

首先,本文用到的数据是

train=[0 0;2 4;3 3;3 4;4 2;44;4 3;5 3;6 2;7 1;2 9;3 8;4 6;4 7;5 6;5 8;6 6;7 4;8 4;10 10];#训练数据

group=[0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1]';#每一行训练数据所对应的类别

将类别为0的数据用蓝点表示出来,其中第一列作为横坐标,第二列作为纵坐标,类别1的用黑点表示出来

图像表示如下

虽然数据的表示并不太符合高斯分布,一会儿看到结果的时候会发现准确率还是可以的,这里只是讲讲原理。

首先根据极大似然估计得到各参数的运算如下


其中fai是标签为1的像素的比例,mu0是所有标签为0的自变量的平均值,这里是一个二维的向量,因为train的每个样本有两个变量。Sigma是方差


对上面的公式进行编程实现

求fai

fai=num/length(group);

求mu0

sum1=0;sum2=0;for i=1:m    sum1=sum1+(1-group(i))*train(i,1);    sum2=sum2+(1-group(i))*train(i,2);endmu01=sum1/(m*fai);mu02=sum2/(m*fai);mu0=[mu01,mu02];

求mu1

sum1=0;sum2=0;for i=1:m    sum1=sum1+group(i)*train(i,1);    sum2=sum2+group(i)*train(i,2);endmu11=sum1/(m*fai);mu12=sum2/(m*fai);mu1=[mu11,mu12];

求sigma

sigmasum=[0,0;0,0];for i=1:m    sigmasum=sigmasum+(train(i,:)'-(mu1)')*(train(i,:)'-(mu1)')';endsigma=sigmasum/m;

下面将mu0和mu1都用红点plot在刚刚的点图中,可以看到m0和m1都是在中心的

plot(mu11,mu12,'r*');

接下来,在图中画出以m0,sigma和m1,sigma为参数的高斯等高线图像。

[x y]=meshgrid(linspace(0,10,50)',linspace(0,10,50)'); X=[x(:) y(:)]; z1=mvnpdf(X,mu0,sigma);contour(x,y,reshape(z1,50,50),4);hold on;

上述代码具体的含义自行百度,可以改改参数看看对结果的影响

结果如下:




那我们来试试几个点,先看[4 3]这个点,我们可以很明显的看到他应该是属于第0类别的中心部分;

我们要求p(x丨y=0)和p(x丨y=0),观察他们谁大谁小,以确定在哪一类。

公式如下:



对上述式子进行编程

test=[4 3];p0=exp(-(test'-(mu0)')'*sigma^-1*(test'-(mu0)')/2)/(((2*pi)^(2/2))*det(sigma)^(1/2))p1=exp(-(test'-(mu1)')'*sigma^-1*(test'-(mu1)')/2)/(((2*pi)^(2/2))*det(sigma)^(1/2))if p0>p1    disp('属于第零类');else    disp('属于第一类');end
运行结果:

看到这里差别0.008左右,看似不大

再看一个[6.5 5]这个点,可以看到他是属于第一类的,而且和第零类差别很小

可以看到这里只差0.003,由此可见,分类效果还是很好的,这也印证了所说的生成学习所需要的数据量比较少、

 完整程序如下


clear;clc;close all;train=[0 0;2 4;3 3;3 4;4 2;4 4;4 3;5 3;6 2;7 1;2 9;3 8;4 6;4 7;5 6;5 8;6 6;7 4;8 4;10 10];group=[0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1]';num=0;fai=0;m=length(group);for i=1:m    if(group(i,1)==0)        plot(train(i,1),train(i,2),'b*');        hold on    end    if(group(i,1)==1)        plot(train(i,1),train(i,2),'k*');        hold on    endendfor i=1:m    if(group(i,1)==1)        num=num+1;    endendfai=num/length(group);sum1=0;sum2=0;for i=1:m    sum1=sum1+(1-group(i))*train(i,1);    sum2=sum2+(1-group(i))*train(i,2);endmu01=sum1/(m*fai);mu02=sum2/(m*fai);mu0=[mu01,mu02];plot(mu01,mu02,'r*');hold onsum1=0;sum2=0;for i=1:m    sum1=sum1+group(i)*train(i,1);    sum2=sum2+group(i)*train(i,2);endmu11=sum1/(m*fai);mu12=sum2/(m*fai);mu1=[mu11,mu12];plot(mu11,mu12,'r*');% sigma=cov(train(:,1),train(:,2));sigmasum=[0,0;0,0];for i=1:m    sigmasum=sigmasum+(train(i,:)'-(mu1)')*(train(i,:)'-(mu1)')';endsigma=sigmasum/m;[x y]=meshgrid(linspace(0,10,50)',linspace(0,10,50)'); X=[x(:) y(:)]; z1=mvnpdf(X,mu0,sigma);contour(x,y,reshape(z1,50,50),4);hold on;[x y]=meshgrid(linspace(0,10,50)',linspace(0,10,50)'); X=[x(:) y(:)];z2=mvnpdf(X,mu1,sigma);contour(x,y,reshape(z2,50,50),4);hold offtest=[6.5 5];p0=exp(-(test'-(mu0)')'*sigma^-1*(test'-(mu0)')/2)/(((2*pi)^(2/2))*det(sigma)^(1/2))p1=exp(-(test'-(mu1)')'*sigma^-1*(test'-(mu1)')/2)/(((2*pi)^(2/2))*det(sigma)^(1/2))if p0>p1    disp('属于第零类');else    disp('属于第一类');end



阅读全文
1 0