bp神经网络 python初探

来源:互联网 发布:u盘数据丢失 编辑:程序博客网 时间:2024/05/17 05:04

最近抽空学习了一下bp神经网络,并且用python实现了一下。

加入了动量项和自调节步长,都是为了拟合得更快一些(这里指定了最大迭代次数)

还有就是误差稍微大了点,估计是没有标准化的缘故

python的numpy包里的矩阵乘法有些坑人,行列不分;被坑了很久,只好看了一下网上大神的博客,发现都是把矩阵拆开计算的。

from numpy import *
import math
import random as rd


def loadDataSet(filename):
    """Load a tab-separated data set from a text file.

    Every non-blank line holds the feature columns followed by the label in
    the last column.  The number of feature columns is taken from the first
    non-blank line.

    Returns:
        (xArr, yArr): ``xArr`` is a list of rows, each ``[1, f1, f2, ...]``
        where the leading constant ``1`` is the bias input and the features
        are floats; ``yArr`` is the list of labels, kept as the raw strings
        read from the file.
    """
    xArr = []
    yArr = []
    num = None
    # A single ``with`` block guarantees the file handle is closed.
    with open(filename) as f:
        for line in f:
            if not line.strip():
                # Skip blank lines (e.g. a trailing newline at end of file).
                continue
            if num is None:
                num = len(line.split('\t')) - 1
            link = line.strip().split('\t')
            current = [1]  # constant bias input prepended to every sample
            current.extend(float(link[i]) for i in range(num))
            xArr.append(current)
            yArr.append(link[-1])
    return xArr, yArr


def get_data_test(xArr, yArr, k):
    """Split the data into a training set and a test set for fold ``k``.

    The data is cut into ten consecutive folds; fold ``k`` (0-9) becomes the
    test set and the other nine folds the training set.

    Returns:
        (xTrain, yTrain, xTest, yTest) as plain lists.
    """
    xTrain = []
    yTrain = []
    xTest = []
    yTest = []
    Num = len(xArr)
    for i in range(Num):
        # Integer form of ``0.1*k*Num <= i < 0.1*(k+1)*Num``.  Doing this in
        # floating point misplaces boundary samples (0.1*3*100 evaluates to
        # 30.000000000000004), which produced folds of unequal size.
        if k * Num <= 10 * i < (k + 1) * Num:
            xTest.append(xArr[i])
            yTest.append(yArr[i])
        else:
            xTrain.append(xArr[i])
            yTrain.append(yArr[i])
    return xTrain, yTrain, xTest, yTest


def sigmoid(x):
    """Activation function: hyperbolic tangent (used instead of the logistic).

    Accepts a scalar (returns a Python float) or a NumPy array (applied
    element-wise).
    """
    if isinstance(x, ndarray):
        return tanh(x)
    return math.tanh(x)


def dsigmoid(x):
    """Derivative of the activation, expressed in terms of its OUTPUT.

    ``x`` must be the activation value ``tanh(z)``, not the pre-activation
    ``z``: d/dz tanh(z) = 1 - tanh(z)**2.
    """
    return 1 - x ** 2


def _random_weights(rows, cols):
    """Return a (rows, cols) float array with entries uniform in [-0.5, 0.5]."""
    values = [[rd.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]
    return array(values, dtype=float).reshape(rows, cols)


class NN(object):
    """Back-propagation network with one hidden layer and tanh activations.

    Attributes:
        ni, nh, no: number of input, hidden and output units.
        wh: (ni, nh) input-to-hidden weights.
        wo: (nh, no) hidden-to-output weights.
        ch, co: previous weight deltas, used for the momentum term.
        wh_copy, wo_copy: weights of the last accepted epoch (for rollback).
        prerror: summed error of the last accepted epoch.
    """

    def __init__(self, ni, nh, no):
        self.ni = ni
        self.nh = nh
        self.no = no
        self.prerror = 0.0
        # Activations are kept as column vectors.
        self.ai = zeros((self.ni, 1))
        self.ah = zeros((self.nh, 1))
        self.ao = zeros((self.no, 1))
        self.wh = _random_weights(self.ni, self.nh)
        self.wo = _random_weights(self.nh, self.no)
        # Independent snapshots; they must never alias the live weights.
        self.wh_copy = self.wh.copy()
        self.wo_copy = self.wo.copy()
        self.ch = zeros((self.ni, self.nh))
        self.co = zeros((self.nh, self.no))

    def forward_propagate(self, input):
        """Forward pass: compute and return the (no, 1) output activations.

        ``input`` is a sequence of ``ni`` numbers (bias input included).

        Raises:
            ValueError: if ``input`` does not contain exactly ``ni`` values.
        """
        ai = asarray(input, dtype=float).reshape(-1, 1)
        if ai.shape[0] != self.ni:
            raise ValueError('expected %d inputs, got %d' % (self.ni, ai.shape[0]))
        self.ai = ai
        self.ah = sigmoid(dot(self.wh.T, self.ai))
        self.ao = sigmoid(dot(self.wo.T, self.ah))
        return self.ao

    def back_propagate(self, targets, n, m):
        """Backward pass for the sample last given to ``forward_propagate``.

        Applies one gradient-descent update (chain rule) with step size ``n``
        and momentum factor ``m``.

        Args:
            targets: the ``no`` target values (scalar, sequence, column
                vector, or numeric strings).
            n: learning rate.
            m: momentum factor applied to the previous weight deltas.

        Returns:
            The squared-error cost ``0.5 * sum((target - output)**2)`` of the
            sample, measured before the weight update, as a float.

        Raises:
            ValueError: if fewer than ``no`` target values are supplied.
        """
        t = asarray(targets, dtype=float).reshape(-1, 1)[:self.no]
        if t.shape[0] != self.no:
            raise ValueError('expected %d targets, got %d' % (self.no, t.shape[0]))
        diff = t - self.ao

        # Error terms; the hidden term uses the output weights as they were
        # BEFORE this update.
        self.eo = dsigmoid(self.ao) * diff
        self.eh = dsigmoid(self.ah) * dot(self.wo, self.eo)

        delta_o = dot(self.ah, self.eo.T)   # delta_o[i, j] = eo[j] * ah[i]
        self.wo += n * delta_o + m * self.co  # last term is the momentum
        self.co = delta_o

        delta_h = dot(self.ai, self.eh.T)   # delta_h[i, j] = eh[j] * ai[i]
        self.wh += n * delta_h + m * self.ch
        self.ch = delta_h

        # Cost function: squared error.
        return float(0.5 * (diff ** 2).sum())

    def test(self, xTest, yTest):
        """Print the expected label and the network output for each test sample."""
        print('Test:')
        for x, y in zip(xTest, yTest):
            ans = self.forward_propagate(x)
            print('%s -- %s' % (y, ans))

    def train(self, xTrain, yTrain, k=1000, n=0.07, m=0.05, tol=0.8):
        """Train for at most ``k`` epochs with an adaptive step size.

        After each epoch the summed error is compared with that of the last
        accepted epoch.  If it grew, the epoch is discarded (weights are
        restored) and the step size shrinks by 10%; otherwise the step size
        grows by 10% and the weights are snapshotted.

        Args:
            xTrain, yTrain: training samples and their labels.
            k: maximum number of epochs.
            n: initial learning rate.
            m: momentum factor.
            tol: training stops once the epoch error drops below this value.
        """
        aplta = n
        for p in range(k):
            error = 0.0
            for x, y in zip(xTrain, yTrain):
                self.forward_propagate(x)
                error += self.back_propagate(y, aplta, m)
            if p != 0 and error > self.prerror:
                # Error grew: discard this epoch and reduce the step size.
                aplta *= 0.9
                # Copy so later in-place updates cannot corrupt the snapshot.
                self.wh = self.wh_copy.copy()
                self.wo = self.wo_copy.copy()
            else:
                # Error did not grow: enlarge the step to converge faster.
                aplta *= 1.1
                self.prerror = error
                # Real copies: the weights are updated in place, so a plain
                # assignment would alias them and make the rollback a no-op.
                self.wh_copy = self.wh.copy()
                self.wo_copy = self.wo.copy()
            if p % 100 == 0:
                print('Error --> %f' % error)
            if error < tol:
                print('Error --> %f' % error)
                print('It\'s enough.')
                break

    def show(self):
        """Print the weight vector feeding each hidden and each output unit."""
        print('Hidden theta:')
        for each in self.wh.T:
            print(each)
        print('Output theta:')
        for each in self.wo.T:
            print(each)


def network(xArr, yArr):
    """Run 10-fold cross-validation: each fold serves once as the test set.

    A fresh network is trained per fold.  The weights of the ten networks are
    not averaged; the original author's note says that averaging would go
    wrong because hidden units are interchangeable (author's remark, not
    verified here).
    """
    # Input width follows the data (3 for the bundled data set: bias + 2 features).
    n_inputs = len(xArr[0]) if len(xArr) > 0 else 3
    for i in range(10):
        xTrain, yTrain, xTest, yTest = get_data_test(xArr, yArr, i)
        print('%dth test:' % (i + 1))
        keys = NN(n_inputs, 3, 1)
        keys.train(xTrain, yTrain)
        keys.test(xTest, yTest)
        keys.show()


if __name__ == '__main__':
    xArr, yArr = loadDataSet('testSet.txt')
    network(xArr, yArr)
以下是数据,一共100个(每行三列:两个特征和一个0/1标签,列之间以制表符分隔;下面的文本在转载时丢失了分隔符和换行):

-0.01761214.0530640-1.3956344.6625411-0.7521576.5386200-1.3223717.15285300.42336311.05467700.4067047.06733510.66739412.7414520-2.4601506.86680510.5694119.5487550-0.02663210.42774300.8504336.92033411.34718313.17550001.1768133.1670201-1.7818719.0979530-0.5666065.74900310.9316351.5895051-0.0242056.1518231-0.0364532.6909881-0.1969490.44416511.0144595.75439911.9852983.2306191-1.693453-0.5575401-0.57652511.7789220-0.346811-1.6787301-2.1244842.67247111.2179169.5970150-0.7339289.0986870-3.642001-1.61808710.3159853.52395311.4166149.6192320-0.3863233.98928610.5569218.29498411.22486311.5873600-1.347803-2.40605111.1966044.95185110.2752219.54364700.4705759.3324880-1.8895679.5426620-1.52789312.1505790-1.18524711.3093180-0.4456783.29730311.0422226.1051551-0.61878710.32098601.1520830.54846710.8285342.6760451-1.23772810.5490330-0.683565-2.16612510.2294565.9219381-0.95988511.55533600.49291110.99332400.1849928.7214880-0.35571510.3259760-0.3978228.05839700.82483913.73034301.5072785.02786610.0996716.8358391-0.34400810.71748501.7859287.7186451-0.91880111.5602170-0.3640094.7473001-0.8417224.11908310.4904261.9605391-0.0071949.07579200.35610712.44786300.34257812.2811620-0.810823-1.46601812.5307776.47680111.29668311.60755900.47548712.0400350-0.78327711.00972500.07479811.0236500-1.3374720.4683391-0.10278113.7636510-0.1473242.87484610.5183899.88703501.0153997.5718820-1.658086-0.02725511.3199442.17122812.0562165.0199811-0.8516334.3756911-1.5100476.0619920-1.076637-3.18188811.82109610.28399003.0101508.4017661-1.0994581.6882741-0.834872-1.7338691-0.8466373.84907511.40010212.62878101.7528425.46816610.0785570.05973610.089392-0.71530011.82566212.69380800.1974459.74463800.1261170.9223111-0.6797971.22053010.6779832.55666610.76134910.6938620-2.1687910.14363211.3886109.34199700.31702914.7390250

0 0
原创粉丝点击