BP算法(python)

来源:互联网 发布:中国对外贸易数据分析 编辑:程序博客网 时间:2024/06/10 17:11
import numpy as npdef nonlin(x, deriv=False):if (deriv == True):return x * (1 - x) #如果deriv为true,求导数return 1 / (1 + np.exp(-x))X = np.array([[0.35],[0.9]]) #输入层y = np.array([[0.5]]) #输出值np.random.seed(1)W0 = np.array([[0.1,0.8],[0.4,0.6]])W1 = np.array([[0.3,0.9]])print 'original ',W0,'\n',W1for j in xrange(100):l0 = X #相当于文章中x0l1 = nonlin(np.dot(W0,l0)) #相当于文章中y1l2 = nonlin(np.dot(W1,l1)) #相当于文章中y2l2_error = y - l2Error = 1/2.0*(y-l2)**2print "Error:",Errorl2_delta = l2_error * nonlin(l2, deriv=True) #this will backpack#print 'l2_delta=',l2_deltal1_error = l2_delta*W1; #反向传播l1_delta = l1_error * nonlin(l1, deriv=True)W1 += l2_delta*l1.T; #修改权值W0 += l0.T.dot(l1_delta)print W0,'\n',W1