neural-networks-and-deep-learning multiple_eta.py

Source: Internet | Editor: 程序博客网 | Date: 2024/06/06 04:50

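This file is also fairly simple. It first sets three different values of eta, i.e. three different learning rates, and then trains a separate [784, 30, 10] network for each one. Weights are initialized with network2's default method, training uses mini-batches of size 10 and L2 regularization with lmbda = 5.0, and after 30 epochs the cost curve from each run is collected in the list results and saved to multiple_eta.json.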
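The "default" initialization is network2.Network's default_weight_initializer. Assuming the behaviour of Nielsen's network2.py, it draws every bias from a standard Gaussian N(0, 1) and every weight from a Gaussian with mean 0 and standard deviation 1/sqrt(n_in), where n_in is the number of inputs to the neuron. The sketch below reproduces that scheme for the [784, 30, 10] network; default_weight_init is a hypothetical name, and it uses NumPy's Generator API instead of the global np.random.randn that the original calls.

```python
import numpy as np


def default_weight_init(sizes, rng):
    """Hypothetical re-implementation of network2's default initializer (assumed)."""
    # One bias column vector per non-input layer, entries ~ N(0, 1).
    biases = [rng.standard_normal((y, 1)) for y in sizes[1:]]
    # Weight matrix for each layer has shape (n_out, n_in); dividing by
    # sqrt(n_in) gives entries ~ N(0, 1/n_in).
    weights = [rng.standard_normal((y, x)) / np.sqrt(x)
               for x, y in zip(sizes[:-1], sizes[1:])]
    return weights, biases


rng = np.random.default_rng(12345678)
weights, biases = default_weight_init([784, 30, 10], rng)
for w in weights:
    # Print each weight matrix's shape and its empirical standard deviation.
    print(w.shape, round(float(w.std()), 4))
```

The first weight matrix has shape (30, 784) and a standard deviation near 1/sqrt(784) ≈ 0.036. Scaling by 1/sqrt(n_in) keeps the weighted input of each hidden neuron small, so the sigmoid neurons are less likely to start out saturated.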

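make_plot draws only one curve per learning rate: the training cost (training_cost, the third value returned by SGD). Although evaluation_data=validation_data is passed to SGD, no validation cost is recorded, because only monitor_training_cost=True is set.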
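To make the layout of each stored result explicit, here is a small hypothetical helper, select_curve (not part of network2), that picks one per-epoch curve out of an SGD result by name. It assumes the field order evaluation_cost, evaluation_accuracy, training_cost, training_accuracy used by Nielsen's network2.py; with the flags set in multiple_eta.py, only training_cost is non-empty.

```python
# Field order of the tuple returned by network2.Network.SGD in Nielsen's
# neural-networks-and-deep-learning repository (assumed here).
SGD_FIELDS = ("evaluation_cost", "evaluation_accuracy",
              "training_cost", "training_accuracy")


def select_curve(result, field="training_cost"):
    """Return one per-epoch curve from a single SGD result.

    `result` may be the 4-tuple returned by SGD or the 4-element list
    it becomes after a round trip through JSON.
    """
    curve = result[SGD_FIELDS.index(field)]
    if not curve:
        # SGD only records a curve when the matching monitor_* flag is True,
        # e.g. monitor_evaluation_cost=True for "evaluation_cost".
        raise ValueError("%s is empty; rerun SGD with monitor_%s=True"
                         % (field, field))
    return curve
```

Plotting the validation cost instead would mean passing monitor_evaluation_cost=True to SGD and calling select_curve(result, "evaluation_cost") in make_plot.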

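As the plot shows, when eta is too large (2.5) the network fails to learn: the update steps are too large to settle into a minimum, so the cost does not decrease steadily. When eta is too small (0.025) the network learns slowly: the cost falls, but only gradually.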
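The same effect can be seen on a toy problem. The sketch below is a hypothetical illustration, not part of the original program: it runs plain gradient descent on the one-dimensional loss f(w) = 0.5 * c * w², whose gradient is c * w, so each step multiplies w by (1 - eta * c) and the iteration converges only when 0 < eta < 2/c. With c = 1 and the same three learning rates, 0.025 shrinks w by just 2.5% per step, 0.25 shrinks it by 25% per step, and 2.5 gives a factor of -1.5, so w flips sign and grows on every step. A real network's cost surface is not a single quadratic, so this only illustrates the mechanism.

```python
def gd_on_quadratic(eta, w0=1.0, curvature=1.0, steps=30):
    """Plain gradient descent on the toy loss f(w) = 0.5 * curvature * w**2.

    Returns the loss before the first step and after each of `steps` steps.
    """
    w = w0
    losses = [0.5 * curvature * w ** 2]
    for _ in range(steps):
        grad = curvature * w          # f'(w)
        w = w - eta * grad            # equivalent to w *= (1 - eta * curvature)
        losses.append(0.5 * curvature * w ** 2)
    return losses


for eta in [0.025, 0.25, 2.5]:
    losses = gd_on_quadratic(eta)
    print("eta = %-5s  loss after 30 steps = %.3g" % (eta, losses[-1]))
```

The smallest rate still leaves a sizeable loss after 30 steps, the middle one drives it close to zero, and the largest one makes it explode, mirroring "too small learns slowly, too large fails to learn".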

Figure: loss for different learning rates

"""multiple_eta~~~~~~~~~~~~~~~This program shows how different values for the learning rate affecttraining.  In particular, we'll plot out how the cost changes usingthree different values for eta."""# Standard libraryimport jsonimport randomimport sys# My librarysys.path.append('../src/')import mnist_loaderimport network2# Third-party librariesimport matplotlib.pyplot as pltimport numpy as np# ConstantsLEARNING_RATES = [0.025, 0.25, 2.5]COLORS = ['#2A6EA6', '#FFCD33', '#FF7033']NUM_EPOCHS = 30def main():    run_networks()    make_plot()def run_networks():    """Train networks using three different values for the learning rate,    and store the cost curves in the file ``multiple_eta.json``, where    they can later be used by ``make_plot``.    """    # Make results more easily reproducible    random.seed(12345678)    np.random.seed(12345678)    training_data, validation_data, test_data = mnist_loader.load_data_wrapper()    results = []    for eta in LEARNING_RATES:        print "\nTrain a network using eta = "+str(eta)        net = network2.Network([784, 30, 10])        results.append(            net.SGD(training_data, NUM_EPOCHS, 10, eta, lmbda=5.0,                    evaluation_data=validation_data,                     monitor_training_cost=True))    f = open("multiple_eta.json", "w")    json.dump(results, f)    f.close()def make_plot():    f = open("multiple_eta.json", "r")    results = json.load(f)    f.close()    fig = plt.figure()    ax = fig.add_subplot(111)    for eta, result, color in zip(LEARNING_RATES, results, COLORS):        _, _, training_cost, _ = result        ax.plot(np.arange(NUM_EPOCHS), training_cost, "o-",                label="$\eta$ = "+str(eta),                color=color)    ax.set_xlim([0, NUM_EPOCHS])    ax.set_xlabel('Epoch')    ax.set_ylabel('Cost')    plt.legend(loc='upper right')    plt.show()if __name__ == "__main__":    main()