反向传播网络实现

来源:互联网 发布:用友软件售后电话 编辑:程序博客网 时间:2024/06/04 23:25

下面的代码主要实现了网络训练过程的核心部分。

function bpn(t1,t2,lr,tol,maxIter)
%BPN Train a 2-2-1 back-propagation network with per-sample (online) updates.
%   BPN(T1,T2) trains on the training set T1 (2-by-N, one sample per
%   column) against the training targets T2 (at least N elements, one per
%   sample). It then prints the number of epochs, the targets, the final
%   outputs, the weights and the biases. It also draws two figures: the
%   per-epoch error curve and a target-vs-output comparison.
%
%   BPN(T1,T2,LR,TOL,MAXITER) optionally overrides:
%     LR      - learning rate (default 0.6; 0.2 and 0.05 were also tried).
%     TOL     - stop once every sample in an epoch has |target-output| < TOL
%               (default 0.2).
%     MAXITER - upper bound on the number of epochs (default 10000).
%               Pass Inf to remove the cap.
%
%   Every neuron computes FUN(w*x - theta), where theta is the matching
%   entry of B: B(1) and B(2) for the hidden neurons, B(3) for the output.
%   FUN is defined outside this function. The deltas below use f*(1-f) as
%   its derivative, which assumes FUN is the logistic sigmoid
%   1./(1+exp(-x)) -- NOTE(review): confirm against FUN's definition.

if nargin < 3 || isempty(lr),      lr = 0.6;        end
if nargin < 4 || isempty(tol),     tol = 0.2;       end
if nargin < 5 || isempty(maxIter), maxIter = 10000; end

n = size(t1,2);   % number of training samples (previously hard-coded as 130)
if size(t1,1) ~= 2 || n == 0
    error('bpn:badInput','t1 must be a non-empty 2-by-N matrix (one sample per column).');
end
if numel(t2) < n
    error('bpn:badInput','t2 needs one target per column of t1.');
end
if ~isscalar(maxIter) || ~(maxIter >= 1)
    error('bpn:badInput','maxIter must be a scalar >= 1.');
end
t2 = t2(1:n);     % only the first N targets are used

% Initial parameters.
w11 = [1 0.78];       % input -> hidden neuron 1
w12 = [0.8 0.9];      % input -> hidden neuron 2
w21 = [1 0.6];        % hidden layer -> output neuron
b   = [0.7 0.6 0.5];  % thresholds of hidden 1, hidden 2, output

aout   = zeros(1,n);  % output for each sample during the latest epoch
Xout   = zeros(1,2);  % outputs of the two hidden neurons
errset = [];          % mean of 0.5*err^2 for each epoch

c = 0;                % number of epochs run so far
converged = false;
while c < maxIter
    c = c + 1;
    sumerr = 0;
    maxAbsErr = 0;
    for i = 1:n
        x = t1(:,i);

        % Forward pass: net input is the weighted sum minus the threshold.
        Xout(1) = fun(w11*x - b(1));
        Xout(2) = fun(w12*x - b(2));
        f = fun(w21*Xout' - b(3));
        aout(i) = f;

        err = t2(i) - f;
        sumerr = sumerr + 0.5*err^2;
        maxAbsErr = max(maxAbsErr, abs(err));

        % Backward pass. All three deltas are computed before any weight
        % changes, so the hidden deltas use the same w21 as the forward pass.
        Q3 = f*(1-f)*err;
        Q1 = Xout(1)*(1-Xout(1))*w21(1)*Q3;
        Q2 = Xout(2)*(1-Xout(2))*w21(2)*Q3;

        % Gradient step on E = 0.5*err^2. Thresholds enter the net input
        % with a minus sign, hence "b - lr*Q".
        w21  = w21 + lr*Q3*Xout;
        b(3) = b(3) - lr*Q3;
        w11  = w11 + lr*Q1*x';
        b(1) = b(1) - lr*Q1;
        w12  = w12 + lr*Q2*x';
        b(2) = b(2) - lr*Q2;
    end
    % errset grows per epoch because maxIter may be Inf.
    errset(c) = sumerr/n; %#ok<AGROW>

    % Stop only when the whole epoch is within tolerance, not just the
    % last sample.
    if maxAbsErr < tol
        converged = true;
        break;
    end
end
if ~converged
    warning('bpn:notConverged', ...
        'Stopped after %d epochs without every |err| < %g.', c, tol);
end

% Report the number of epochs and plot the error curve in its own figure.
fprintf('the c:');
fprintf('%3.0f   ',c);
fprintf('\n');

figure;
plot(errset);
xlabel('the number of iterations');
ylabel('the mean squared error (0.5*err^2)');
title('Error rate analysis');
% Normalized units keep the label inside the axes whatever the epoch count.
text(0.5,0.9,['The total iterations are ',num2str(c)],'Units','normalized');

% fprintf recycles the format for every element of a vector argument.
fprintf('the target:');
fprintf('%3.0f   ',t2);
fprintf('\n');
fprintf('the output:');
fprintf('%3.0f   ',aout);
fprintf('\n');
fprintf('the weights:');
fprintf('\n');
disp(w11);
disp(w12);
disp(w21);
fprintf('the bias:');
fprintf('\n');
disp(b);

% Compare outputs (red stars) with targets (circles) in a separate figure
% so the error curve above is not overwritten.
figure;
plot(aout,'r*');
hold on;
plot(t2,'o');
hold off;
end