Python study notes (3): plotting the loss and accuracy curves of a training run

Source: Internet | Editor: 程序博客网 | Time: 2024/06/05 19:18

1. Recording the training log

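Prefix the training command with two glog environment variables so that Caffe writes its log to a file: GLOG_logtostderr=0 turns off stderr-only logging, and GLOG_log_dir sets the directory that receives the log files.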

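Change the directories to your own project's paths (and create the Log folder beforehand if it does not exist). Each training run then leaves its own log file in the Log folder.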

#!/bin/bash
GLOG_logtostderr=0 GLOG_log_dir=fine-grained/Log/ \
caffe.bin train \
    --solver fine-grained/solver.prototxt \
    --weights fine-grained/bvlc_googlenet.caffemodel
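The same launch can also be scripted from Python. The sketch below is a minimal example under stated assumptions, not part of the original workflow: the helper names `glog_env`, `caffe_train_command`, `launch_training` and `newest_info_log` are hypothetical, `caffe.bin` is assumed to be on the PATH, and finding the newest log assumes glog's default file naming, in which INFO logs contain `.log.INFO.` in their names.

```python
import glob
import os
import subprocess


def glog_env(log_dir):
    """Copy the current environment and point glog's file output at log_dir."""
    env = dict(os.environ)
    env["GLOG_logtostderr"] = "0"  # write log files instead of logging only to stderr
    env["GLOG_log_dir"] = log_dir  # directory that receives the log files
    return env


def caffe_train_command(solver, weights):
    """Argument list for `caffe.bin train` (assumes caffe.bin is on the PATH)."""
    return ["caffe.bin", "train", "--solver", solver, "--weights", weights]


def launch_training(log_dir, solver, weights):
    """Create log_dir up front, then run training with glog writing into it."""
    os.makedirs(log_dir, exist_ok=True)
    subprocess.run(caffe_train_command(solver, weights),
                   env=glog_env(log_dir), check=True)


def newest_info_log(log_dir):
    """Most recently modified INFO log in log_dir, or None if there is none.

    Assumes glog's default naming, e.g.
    caffe.bin.<hostname>.<user>.log.INFO.<date>-<time>.<pid>
    """
    candidates = glob.glob(os.path.join(log_dir, "*.log.INFO.*"))
    return max(candidates, key=os.path.getmtime) if candidates else None
```

For the paths used in the shell script, a call would be `launch_training("fine-grained/Log/", "fine-grained/solver.prototxt", "fine-grained/bvlc_googlenet.caffemodel")`, after which `newest_info_log("fine-grained/Log/")` returns the file to copy or rename to log.txt.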

2. Plotting

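Rename the generated log file (glog names it like caffe.bin.<hostname>.<user>.log.INFO.<date>-<time>.<pid>) to log.txt, then plot it in a Jupyter notebook with the following code: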

import re

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import host_subplot

# read the log file
train_iterations = []
train_loss = []
test_iterations = []
test_accuracy = []
with open('log.txt', 'r') as fp:
    for ln in fp:
        # get train_iterations from the "Iteration N, lr = ..." lines
        if '] Iteration ' in ln and 'lr = ' in ln:
            arr = re.findall(r'ion \b\d+\b,', ln)
            train_iterations.append(int(arr[0].strip(',')[4:]))
        # get train_loss from the "Iteration N (... iters), loss = ..." lines
        if 'iters),' in ln and 'loss = ' in ln:
            train_loss.append(float(ln.strip().split(' = ')[-1]))
        # get test_iterations from the "Iteration N, Testing net (#0)" lines
        if '] Iteration' in ln and 'Testing net (#0)' in ln:
            arr = re.findall(r'ion \b\d+\b,', ln)
            test_iterations.append(int(arr[0].strip(',')[4:]))
        # get test_accuracy: test net output #8 (loss3/top-5) of GoogLeNet;
        # change both strings for other networks
        if '#8:' in ln and 'loss3/top-5' in ln:
            test_accuracy.append(float(ln.strip().split(' = ')[-1]))

host = host_subplot(111)
plt.subplots_adjust(right=0.8)  # adjust the right boundary of the plot window
par1 = host.twinx()  # second y axis sharing the same x axis

# set labels
host.set_xlabel("iterations")
host.set_ylabel("log loss")
par1.set_ylabel("validation accuracy")

# plot curves
p1, = host.plot(train_iterations, train_loss, label="training log loss")
p2, = par1.plot(test_iterations, test_accuracy, label="validation accuracy")

# set location of the legend,
# 1->upper right corner, 2->upper left corner, 3->lower left corner
# 4->lower right corner, 5->right middle ...
host.legend(loc=5)

# color each y-axis label like its curve
host.axis["left"].label.set_color(p1.get_color())
par1.axis["right"].label.set_color(p2.get_color())

# set the range of the x axis of host and the y axis of par1
# (the x range fits this particular run; adjust it to your max_iter)
host.set_xlim([-200, 5200])
par1.set_ylim([-0.1, 1.1])

plt.show()
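The parsing half of the notebook can be pulled out into a function so the matching rules can be checked against a few log lines before plotting. This is a minimal sketch that reuses the same string tests as the code above; the name `parse_caffe_log` and its `acc_key`/`acc_name` parameters are hypothetical, and the defaults assume the GoogLeNet test output `#8: loss3/top-5`. The `iters),` rule assumes the loss line looks like `Iteration N (x iter/s, ys/K iters), loss = ...`; logs from other Caffe versions may format that line differently.

```python
import re

# Captures N in "Iteration N," (lr lines and "Testing net" lines)
ITER_RE = re.compile(r"Iteration (\d+),")


def parse_caffe_log(lines, acc_key="#8:", acc_name="loss3/top-5"):
    """Return (train_iterations, train_loss, test_iterations, test_accuracy).

    Same matching rules as the notebook code; acc_key and acc_name select
    the test-net output to read and must be adapted to other networks.
    """
    train_iterations, train_loss = [], []
    test_iterations, test_accuracy = [], []
    for ln in lines:
        # "Iteration N, lr = ..." marks one training display step
        if "] Iteration " in ln and "lr = " in ln:
            train_iterations.append(int(ITER_RE.search(ln).group(1)))
        # "Iteration N (... iters), loss = X" carries the training loss
        if "iters)," in ln and "loss = " in ln:
            train_loss.append(float(ln.strip().split(" = ")[-1]))
        # "Iteration N, Testing net (#0)" marks one test pass
        if "] Iteration" in ln and "Testing net (#0)" in ln:
            test_iterations.append(int(ITER_RE.search(ln).group(1)))
        # "Test net output #8: loss3/top-5 = X" carries the accuracy
        if acc_key in ln and acc_name in ln:
            test_accuracy.append(float(ln.strip().split(" = ")[-1]))
    return train_iterations, train_loss, test_iterations, test_accuracy
```

In the notebook the four lists would then come from `with open('log.txt') as fp: train_iterations, train_loss, test_iterations, test_accuracy = parse_caffe_log(fp)`, and the plotting part stays unchanged. If training was interrupted between a loss line and the following lr line, the two training lists can differ in length by one, so trimming both to the shorter length before plotting avoids a shape error.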

Result:


