Feedforward Neural Networks: The Error Backpropagation (BP) Algorithm


Core steps of the BP algorithm
1. For a given input, form the squared-error function between the network's actual output and the desired output.
2. Differentiate the error function with respect to every connection weight and threshold in the network, using the chain rule of differentiation.
3. Use gradient descent to approach a minimum of the error, and exit the loop once the stopping condition is satisfied.
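The three steps above can be sketched on a single sigmoid neuron. This is a minimal, hypothetical example (the weight `w`, threshold `theta`, input, target, and learning rate are made-up values for illustration, not taken from the post):

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# Hypothetical setup: one sigmoid neuron with weight w and threshold theta,
# a single input x and a single target output y_true.
w, theta = 0.5, 0.1
x, y_true = 1.0, 1.0
alpha = 0.5  # learning rate

for _ in range(100):
    # Step 1: forward pass and squared error E = 0.5 * (y_true - y_o)^2
    y_o = sigmoid(w * x - theta)
    E = 0.5 * (y_true - y_o) ** 2
    # Step 2: chain rule. g is the negative derivative of E w.r.t. the
    # net input: g = (y_true - y_o) * y_o * (1 - y_o)
    g = y_o * (1 - y_o) * (y_true - y_o)
    # Step 3: gradient-descent update of the weight and threshold
    w += alpha * g * x
    theta -= alpha * g

print(E)  # the error shrinks as the loop runs
```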

For a single sample, the weights and thresholds can be updated repeatedly so the output approaches the target arbitrarily closely: this is standard BP. It is computationally expensive, and an update driven by one sample may undo the progress made on an earlier one.
For multiple samples, the error function is the sum of the per-sample errors, so the weights and thresholds are updated once per evaluation of the accumulated error: this is accumulated BP. It descends quickly at first, but slows down near a minimum and may oscillate there.
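The difference between the two update schemes can be sketched as follows. This is a toy single-unit example with made-up random data, not the network trained later in this post; the threshold term is omitted for brevity:

```python
import numpy as np

# Hypothetical toy data: 8 samples, 3 features, binary targets.
rng = np.random.default_rng(0)
X = rng.random((8, 3))
y = (X.sum(axis=1) > 1.5).astype(float)

w = np.zeros(3)
alpha = 0.1

def grad(w, xi, yi):
    # Gradient of the squared error of one sigmoid unit on one sample.
    y_o = 1 / (1 + np.exp(-xi @ w))
    return -(yi - y_o) * y_o * (1 - y_o) * xi

# Standard BP: update the weights after every single sample.
w_std = w.copy()
for xi, yi in zip(X, y):
    w_std -= alpha * grad(w_std, xi, yi)

# Accumulated BP: sum the gradients over all samples, then update once.
w_acc = w.copy()
total = sum(grad(w_acc, xi, yi) for xi, yi in zip(X, y))
w_acc -= alpha * total
```

Because standard BP changes the weights between samples, one pass through the data generally lands at a different point than the single accumulated update.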
 
The following trains on the watermelon-classification example from Machine Learning by Zhou Zhihua (p. 84).

ID,Color,Root,Knock,Texture,Navel,Touch,Density,Sugar,Good?
1,green,curled,dull,clear,sunken,hard-smooth,0.697,0.460,yes
2,dark,curled,muffled,clear,sunken,hard-smooth,0.774,0.376,yes
3,dark,curled,dull,clear,sunken,hard-smooth,0.634,0.264,yes
4,green,curled,muffled,clear,sunken,hard-smooth,0.608,0.318,yes
5,light,curled,dull,clear,sunken,hard-smooth,0.556,0.215,yes
6,green,slightly-curled,dull,clear,slightly-sunken,soft-sticky,0.403,0.237,yes
7,dark,slightly-curled,dull,slightly-blurry,slightly-sunken,soft-sticky,0.481,0.149,yes
8,dark,slightly-curled,dull,clear,slightly-sunken,hard-smooth,0.437,0.211,yes
9,dark,slightly-curled,muffled,slightly-blurry,slightly-sunken,hard-smooth,0.666,0.091,no
10,green,stiff,crisp,clear,flat,soft-sticky,0.243,0.267,no
11,light,stiff,crisp,blurry,flat,hard-smooth,0.245,0.057,no
12,light,curled,dull,blurry,flat,soft-sticky,0.343,0.099,no
13,green,slightly-curled,dull,slightly-blurry,sunken,hard-smooth,0.639,0.161,no
14,light,slightly-curled,muffled,slightly-blurry,sunken,hard-smooth,0.657,0.198,no
15,dark,slightly-curled,dull,clear,slightly-sunken,soft-sticky,0.360,0.370,no
16,light,curled,dull,blurry,flat,hard-smooth,0.593,0.042,no
17,green,curled,muffled,slightly-blurry,slightly-sunken,hard-smooth,0.719,0.103,no
# Example: train a single-hidden-layer neural network on watermelon dataset 3.0
# with standard BP and with accumulated BP, and compare the results.
'''Encoding of the discrete attributes:
Color:   light 1, green 2, dark 3
Root:    curled 1, slightly-curled 2, stiff 3
Knock:   crisp 1, dull 2, muffled 3
Texture: blurry 1, slightly-blurry 2, clear 3
Navel:   sunken 1, slightly-sunken 2, flat 3
Touch:   hard-smooth 1, soft-sticky 2'''
import numpy as np

x = np.mat('2,3,3,2,1,2,3,3,3,2,1,1,2,1,3,1,2;'
           '1,1,1,1,1,2,2,2,2,3,3,1,2,2,2,1,1;'
           '2,3,2,3,2,2,2,2,3,1,1,2,2,3,2,2,3;'
           '3,3,3,3,3,3,2,3,2,3,1,1,2,2,3,1,2;'
           '1,1,1,1,1,2,2,2,2,3,3,3,1,1,2,3,2;'
           '1,1,1,1,1,2,2,1,1,2,1,2,1,1,2,1,1;'
           '0.697,0.774,0.634,0.608,0.556,0.403,0.481,0.437,0.666,0.243,0.245,0.343,0.639,0.657,0.360,0.593,0.719;'
           '0.460,0.376,0.264,0.318,0.215,0.237,0.149,0.211,0.091,0.267,0.057,0.099,0.161,0.198,0.370,0.042,0.103'
           ).T
x = np.array(x)                     # 17 samples x 8 attributes
xrow, xcol = x.shape
y = np.mat('1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0')
y = np.array(y).T                   # 17 x 1 target vector

# Connection weights: input -> hidden (v), hidden -> output (w)
v = np.random.random((10, 8))
w = np.random.random((1, 10))
# Thresholds of the output layer (theta1) and hidden layer (theta0)
theta1 = np.random.random((1, y.shape[1]))
theta0 = np.random.random((1, 10))
# Learning rate
alpha = 0.01
i = 1

# Nonlinear activation function
def sigmoid(x):
    return 1/(1+np.exp(-x))

# This loop implements the accumulated BP algorithm
while 1:
    i += 1
    b = sigmoid(np.dot(v, x.T) - np.tile(theta0.T, (1, xrow)))    # hidden output, 10 x 17
    y_o = sigmoid(np.dot(w, b) - np.tile(theta1, (1, xrow))).T    # network output, 17 x 1
    if abs(np.sum((y - y_o)**2)) < 0.005:
        print('v=%s\nw=%s\ntheta0=%s\ntheta1=%s\ny.T=%s\ny_o.T=%s\n'
              % (v, w, theta0, theta1, y.T, y_o.T))
        print(i)
        break
    g = y_o*(1 - y_o)*(y - y_o)     # output-layer gradient term, 17 x 1
    e = np.dot(g, w)*(b*(1 - b)).T  # hidden-layer gradient term
    e = e.T                         # 10 x 17
    w += alpha*np.dot(g.T, b.T)
    theta1 -= alpha*sum(g)
    v += alpha*np.dot(e, x)
    theta0 -= alpha*e.sum(axis=1)

Trained connection weights and thresholds, and the actual outputs compared with the targets:

v=[[ 0.37717435  0.36248417  0.94028635  0.72177583  0.01050585  0.6763556   0.81691498  0.24734353]
   [ 0.46664953  0.52817501  0.67726546  0.74109162  0.75499802  0.82495554  0.11996302  0.06666964]
   [ 4.22006568 -0.05567354 -3.19805728 -3.27984296  0.30586882  4.02184602 -1.84917339  2.35060779]
   [ 0.55454507 -0.33480137 -2.3908266   1.31354428  0.20376727  0.84164039 -1.55461145  1.96061611]
   [ 0.37400772  0.04965319  0.1238434   0.70032288 -0.02324559 -0.15205114  0.41107624  0.26961664]
   [-2.37426911 -0.64492402  0.75545279  4.67744918  0.37459649 -1.94822074 -0.23898184  1.39323303]
   [-0.83703887 -0.19784473  0.4722273   1.64318465  0.10852898 -0.52427996  0.53549982  0.58889765]
   [ 0.85043228  0.31970411  0.98993328  0.6941652   0.26571303  0.52882551  0.81486816  0.80428266]
   [ 4.16107957 -5.21812224 -3.55653231  3.84789551 -2.1587295   4.85989479 -1.90026038  0.98648877]
   [ 0.41039923  0.9260909   0.70832484  0.65247399  0.3561349   0.55512193  0.96653516  0.92850265]]
w=[[ 0.35725147  0.77755131 -7.66382868 -3.79248027  0.14450059 -5.92993026 -1.31719148 -0.0350807  12.43810048  0.09992291]]
theta0=[[ 0.85883852  0.41110001  2.52223799  1.84316683  0.96692533  1.19203815  1.05009637  0.65267965  2.13324929  0.97123541]]
theta1=[[ 0.47713315]]
y.T=[[1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0]]
y_o.T=[[ 0.9938711   0.99739941  0.98745967  0.99651209  0.99341251  0.97255214  0.96723021  0.97850395  0.0116238   0.00277448  0.00616275  0.00153108  0.01275759  0.00185673  0.04463849  0.00949311  0.00463293]]
Number of training iterations: 183201
The actual output y_o differs only slightly from the target y, which meets the requirement, so the training succeeded.

I have not yet come up with a good exit condition for the standard BP algorithm; I will think it over and write it up later.
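One possible exit condition, sketched here under my own assumptions rather than as the original author's solution, is to accumulate the per-sample errors over each full pass through the training set and stop when that sum falls below a threshold or a maximum epoch count is reached. The data below are made-up stand-ins with the same shapes as the watermelon inputs:

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# Hypothetical problem: a single sigmoid unit trained with per-sample
# (standard BP) updates on random 17 x 8 data.
rng = np.random.default_rng(1)
X = rng.random((17, 8))
y = (X[:, 0] > 0.5).astype(float)

w = rng.random(8)
theta = rng.random()
alpha = 0.1
max_epochs = 10000

history = []
for epoch in range(max_epochs):
    epoch_error = 0.0
    for xi, yi in zip(X, y):            # standard BP: one update per sample
        y_o = sigmoid(xi @ w - theta)
        g = y_o * (1 - y_o) * (yi - y_o)
        w += alpha * g * xi
        theta -= alpha * g
        epoch_error += 0.5 * (yi - y_o) ** 2
    history.append(epoch_error)
    # Exit when the error over one whole pass is small enough,
    # or when the epoch budget runs out.
    if epoch_error < 0.05:
        break
```

Capping the epoch count guards against the loop never reaching the error threshold, which the pure error test in the accumulated-BP code above does not.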
