基于pycaffe从零开始写mnist(第四篇)——生成train_loss图片

来源:互联网 发布:mac自带远程桌面连接 编辑:程序博客网 时间:2024/06/05 17:12

将train_loss可视化,有利于更好的查看train_loss的变化趋势

主要参考这位大神的博客:http://adilmoujahid.com/posts/2016/06/introduction-deep-learning-python-caffe/



step1:生成训练日志:

主要参考这个链接:http://blog.csdn.net/u012746763/article/details/51823974


将caffe_HOME/examples/imagenet/train_caffenet.sh复制到当前项目的文件夹下,然后根据上面这个csdn这个博客进行修改:


修改结果如下:

#!/usr/bin/env sh#从/home/xuy/caffe/examples/imagenet/train_caffenet.sh进行更改set -e#你写的每个脚本都应该在文件开头加上set -e,这句语句告诉bash如果任何语句的执行结果不是true则应该退出。这样的好处是防止错误像滚雪球般变大导致一个致命的错误,而这些错误本应该在之前就被处理掉。如果要增加可读性,可以使用set -o errexit,它的作用与set -e相同。LOG=mnist/train.log#用来保存生成的log路径/home/xuy/caffe/build/tools/caffe train \    --solver=mnist/solver.prototxt  -gpu all  2>&1   | tee $LOG#/home/xuy/caffe/build/tools/caffe train \#    --solver=mnist/solver.prototxt $@

训练好之后,会在mnist文件夹下生成train.log文件,并且会产生caffemodel,因为在prototxt里面写了每多少次训练,保存一次caffemodel,以及保存的路径


step2:进行读取log日志,并且绘制train_loss.png图片


在这里需要注意的是:在读取log的时候,会调用caffe自带的shell脚本,刚刚安装好的时候,/home/xuy/caffe/tools/extra/parse_log.sh,以及该文件夹下面的其他python脚本,shell脚本,文件读写权限不够,因此需要chmod 777 *.sh,*.py来修改文件读写权限


下面的python代码首先通过调用shell 命令对于log文件进行解析,然后绘制出train_log图像


莫名其妙,显示不出来图片,不过保存在某一个路径下,直接打开看也是可以的吧

# -*- coding: utf-8 -*-__author__ = 'xuy''''Title           :plot_learning_curve.pyDescription     :This script generates learning curves for caffe modelsAuthor          :Adil MoujahidDate Created    :20160619Date Modified   :20160619version         :0.1usage           :python plot_learning_curve.py model_1_train.log ./caffe_model_1_learning_curve.pngpython_version  :2.7.11'''import osimport sysimport subprocessimport pandas as pdfrom PIL import Imageimport matplotlibmatplotlib.use('Agg')import matplotlib.pylab as pltplt.style.use('ggplot')import numpy as npcaffe_path = '/home/xuy/caffe/'# model_log_path = sys.argv[1]#需要训练的日志的路径,因此通过系统命令记录一下训练的日志model_log_path = '/home/xuy/桌面/code/python/caffe/python_mnist/mnist/train.log'# learning_curve_path = sys.argv[2]#将训练的train_loss保存到某一个路径当中learning_curve_path = '/home/xuy/桌面/code/python/caffe/python_mnist/mnist/my_train_loss.png'#Get directory where the model logs is saved, and move to itmodel_log_dir_path = os.path.dirname(model_log_path)os.chdir(model_log_dir_path)'''Generating training and test logs'''#Parsing training/validation logscommand = caffe_path + 'tools/extra/parse_log.sh ' + model_log_path#利用系统命令来读取并且解析生成的日志文件process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE)process.wait()#Read training and test logstrain_log_path = model_log_path + '.train'test_log_path = model_log_path + '.test'train_log = pd.read_csv(train_log_path, delim_whitespace=True)test_log = pd.read_csv(test_log_path, delim_whitespace=True)'''Making learning curve'''fig, ax1 = plt.subplots()#Plotting training and test lossestrain_loss, = ax1.plot(train_log['#Iters'], train_log['TrainingLoss'], color='red',  alpha=.5)#用红线表示train_losstest_loss, = ax1.plot(test_log['#Iters'], test_log['TestLoss'], linewidth=2, color='green')#用绿线表示test_lossax1.set_ylim(ymin=0, ymax=1)ax1.set_xlabel('Iterations', fontsize=15)ax1.set_ylabel('Loss', fontsize=15)ax1.tick_params(labelsize=15)#Plotting test accuracyax2 = ax1.twinx()test_accuracy, = ax2.plot(test_log['#Iters'], test_log['TestAccuracy'], linewidth=2, color='blue')#用蓝线表示accuracyax2.set_ylim(ymin=0, ymax=1)ax2.set_ylabel('Accuracy', fontsize=15)ax2.tick_params(labelsize=15)#Adding legendplt.legend([train_loss, test_loss, test_accuracy], ['Training Loss', 'Test Loss', 'Test Accuracy'],  bbox_to_anchor=(1, 0.8))plt.title('Training Curve', fontsize=18)#Saving learning curve'''_, ax1 = plt.subplots()ax2 = ax1.twinx()ax1.plot(np.arange(niter), train_loss,'g')ax2.plot(test_interval * np.arange(len(test_acc)), test_acc, 'r')ax1.set_xlabel('iteration')ax1.set_ylabel('train loss')ax2.set_ylabel('test accuracy')plt.savefig('/home/xuy/桌面/loss_train.png')#这个一定要写到plt.show()之前plt.show()from PIL import Imageimport matplotlib.pyplot as pltimg=Image.open('d:/dog.png')plt.figure("dog")plt.imshow(img)plt.show()'''plt.savefig(learning_curve_path)train_loss_img=Image.open(learning_curve_path)plt.imshow(train_loss_img)plt.show()'''Deleting training and test logs'''command = 'rm ' + train_log_pathprocess = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE)process.wait()command = command = 'rm ' + test_log_pathprocess = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE)process.wait()


阅读全文
0 0
原创粉丝点击