Theano:线性回归

来源:互联网 发布:c语言输出string 编辑:程序博客网 时间:2024/06/06 03:15
import theanofrom theano import tensor as Timport numpy as npimport matplotlib.pyplot as pltdef model(X, w):    return X * wtrX = np.linspace(-1, 1, 101)trY = 2 * trX + np.random.randn(*trX.shape) * 0.3X = T.scalar()Y = T.scalar()W = T.scalar()w = theano.shared(np.asarray(0., dtype=theano.config.floatX))y = model(X, w)cost = T.mean(T.sqr(y - Y))gradient = T.grad(cost=cost, wrt=w)updates = [[w, w - gradient * 0.01]]#define training functiontrain = theano.function(inputs=[X, Y], outputs=cost, updates=updates, allow_input_downcast=True)for i in range(100):    for x, y in zip(trX, trY):        train(x, y)#define predict functionyy = X * Wpredict = theano.function(inputs = [X, W], outputs = yy)testX = np.linspace(-1, 1, 50)testY = 2 * testX + np.random.randn(*testX.shape) * 0.3y_pred = []for x in testX:    yi = predict(x, w.get_value())    y_pred.append(yi)fig = plt.figure()ax = fig.add_subplot(211)ax.plot(trX, trY, '.')plt.title('training data')ax = fig.add_subplot(212)l1 = ax.plot(testX, y_pred,'.r')l2 = ax.plot(testX, testY,'o')plt.title('predict result')ax.legend(('predict data','real data'), 'upper left')plt.show()print w.get_value() #something around 2

线性回归的训练数据和相应的预测结果如下图

这里写图片描述
用theano训练的线性回归模型,上一幅图为训练数据,下一幅图为真实数据和预测结果的对比

0 0
原创粉丝点击