caffe 08 win10 用python脚本画loss曲线

来源:互联网 发布:手机ftp软件 编辑:程序博客网 时间:2024/06/17 08:31
# d:\git\DeepLearning\caffe\loss.py# 把loss.py文件放到caffe根目录下运行,否则,要调整相对路径及prototxt里文件路径# 用python画loos的曲线图# 源自 http://edu.csdn.net/course/detail/3506 视频import sys,os;import numpy as np;import matplotlib.pyplot as plt;# 引入caffe的Python接口caffe_root='D:/git/DeepLearning/caffe/build/x64/install/';sys.path.insert(0, caffe_root+'python');import caffe;# 指定计算设备# caffe.set_mode_cpu(); # 指定使用CPU计算caffe.set_mode_gpu(); # 指定使用GPU计算# caffe.set_device(0) # 明确指定使用那个gpu设备solver = caffe.SGDSolver('examples/mnist/lenet_solver.prototxt');niter = 1000;test_interval = 200;train_loss = np.zeros(niter);test_acc = np.zeros(int(np.ceil(niter / test_interval)));# the main solver loopfor it in range(niter):    solver.step(1) # SGD by Caffe    # 保存训练的loss值    train_loss[it] = solver.net.blobs['loss'].data;    solver.test_nets[0].forward(start = 'conv1');    if it % test_interval == 0:        acc = solver.test_nets[0].blobs['accuracy'].data;        print('Iteration', it, 'testing...', 'accuracy:', acc);        test_acc[it // test_interval] = acc;print(test_acc);_, ax1 = plt.subplots();ax2 = ax1.twinx();ax1.plot(np.arange(niter), train_loss);ax2.plot(test_interval * np.arange(len(test_acc)), test_acc, 'r');# 设置曲线标签ax1.set_xlabel('iteration');ax1.set_ylabel('train loss');ax2.set_ylabel('test accuracy');# 显示曲线plt.show();

0 0