Plotting Loss and Accuracy Curves from Caffe Training Logs with Python 3 on Windows
Source: Internet | Editor: 程序博客网 | Time: 2024/06/11 00:21
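In deep learning, the learning curves let you assess the current state of training:
- train loss keeps falling, test loss keeps falling: the network is still learning properly.
- train loss keeps falling, test loss levels off: the network is overfitting.
- train loss levels off, test loss levels off: learning has hit a bottleneck; reduce the learning rate or the batch size.
- train loss levels off, test loss keeps falling: the dataset is almost certainly at fault.
- train loss keeps rising, test loss keeps rising (eventually NaN): the cause may be a poorly designed network architecture, badly chosen training hyperparameters, a program bug, or a similar problem; it needs further investigation.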
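The rules above can be turned into a rough automatic check. The following is a minimal sketch, assuming each curve is summarized by the least-squares slope of its logged values. The names `trend`, `diagnose` and `DIAGNOSIS`, the relative tolerance `tol`, and the choice to treat NaN as a rising loss are all hypothetical; none of this is part of Caffe. Real curves are noisy, so in practice you would apply it to a recent window of values rather than the whole run.

```python
import math

# Hypothetical mapping from (train trend, test trend) to the rules listed above.
DIAGNOSIS = {
    ("down", "down"): "still learning",
    ("down", "flat"): "overfitting",
    ("flat", "flat"): "bottleneck: lower the learning rate or batch size",
    ("flat", "down"): "suspect the dataset",
    ("up", "up"): "diverging: check architecture, hyperparameters, code",
}


def trend(values, tol=1e-3):
    """Label a loss series 'down', 'flat' or 'up' from its least-squares slope."""
    if len(values) < 2:
        raise ValueError("need at least two points")
    if not all(math.isfinite(v) for v in values):
        return "up"  # assumption: NaN/inf is treated as divergence
    n = len(values)
    mean_x = (n - 1) / 2
    mean_y = sum(values) / n
    num = sum((i - mean_x) * (y - mean_y) for i, y in enumerate(values))
    den = sum((i - mean_x) ** 2 for i in range(n))
    # Slope per logged point, relative to the mean loss so tol is scale-free.
    rel_slope = (num / den) / (abs(mean_y) or 1.0)
    if rel_slope < -tol:
        return "down"
    if rel_slope > tol:
        return "up"
    return "flat"


def diagnose(train_loss, test_loss, tol=1e-3):
    """Map the two trends to one of the rules; unknown combinations are reported."""
    key = (trend(train_loss, tol), trend(test_loss, tol))
    return DIAGNOSIS.get(key, "no rule for train=%s, test=%s" % key)


print(diagnose([1.0, 0.8, 0.6, 0.4], [0.5, 0.5, 0.5, 0.5]))  # overfitting
```

A relative slope is used so the same `tol` works whether the loss is around 2.0 or around 0.02.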
MATLAB code for Linux:
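Shell command that extracts the loss values from the log file: `cat train_log_file | grep "Train net output " | awk '{print $11}'`. awk's `$11` is the 11th whitespace-separated field, which on a Caffe `Train net output #0: loss = ...` line is the loss value.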
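```matlab
clear; clc; close all;

train_log_file = 'train.log';
train_interval = 100;   % iterations between "Train net output" lines (solver display)
test_interval = 200;    % iterations between test passes (solver test_interval)

% Training loss: field 11 of every "Train net output #0" line
% (system() runs the shell command on any platform)
[~, train_string_output] = system(['cat ', train_log_file, ' | grep ''Train net output #0'' | awk ''{print $11}''']);
train_loss = str2num(train_string_output);
n = 1 : length(train_loss);
idx_train = (n - 1) * train_interval;

% Test loss: "Test net output #1" (in the MNIST LeNet example, output #0 is accuracy)
[~, test_string_output] = system(['cat ', train_log_file, ' | grep ''Test net output #1'' | awk ''{print $11}''']);
test_loss = str2num(test_string_output);
m = 1 : length(test_loss);
idx_test = (m - 1) * test_interval;

figure;
plot(idx_train, train_loss);
hold on;
plot(idx_test, test_loss);
grid on;
legend('Train Loss', 'Test Loss');
xlabel('iterations');
ylabel('loss');
title('Train & Test Loss Curve');
```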
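On Windows, `cat`, `grep` and `awk` are usually not available, so the extraction step above can be done in plain Python instead. A minimal sketch of `grep "<marker>" | awk '{print $11}'`; the function name `extract_field` and the sample log lines are hypothetical (the lines follow the glog layout quoted in extract_seconds.py).

```python
def extract_field(lines, marker, field=11):
    """Python version of: grep "<marker>" | awk '{print $<field>}'.

    awk fields are 1-based and whitespace-separated, so $11 is index 10.
    Lines that contain the marker but are too short are skipped.
    """
    values = []
    for line in lines:
        if marker in line:
            parts = line.split()
            if len(parts) >= field:
                values.append(float(parts[field - 1]))
    return values


# Hypothetical, shortened log excerpt in glog format.
log = [
    "I0210 13:39:22.381027 25210 solver.cpp:228] Iteration 100, loss = 0.31",
    "I0210 13:39:22.381100 25210 solver.cpp:244]     Train net output #0: loss = 0.3125 (* 1 = 0.3125 loss)",
]
print(extract_field(log, "Train net output #0"))  # [0.3125]
```

To read a real log, pass the open file: `with open('train.log') as f: train_loss = extract_field(f, 'Train net output #0')`.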
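Python 3 code for Windows (Anaconda3 + PyCharm):

First, train the network with a batch file like the following, run from the Caffe root directory:

```bat
"./bin/caffe.exe" train --solver=./examples/mnist/lenet_solver.prototxt >./examples/mnist/log/mnist_Lenet_train_test.log 2>&1
pause
```

The redirection `>./examples/mnist/log/mnist_Lenet_train_test.log 2>&1` writes the training log to that file. Caffe prints its log through glog to stderr, so `2>&1`, which merges stderr into stdout, is required. The `log` directory must exist before the command runs.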
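The same redirection can also be done from Python, for example to launch training from the Anaconda/PyCharm environment used in the rest of this post. A minimal sketch using `subprocess.run` with `stderr=subprocess.STDOUT`, which plays the role of `2>&1`; the helper name `run_and_log` is hypothetical.

```python
import subprocess
import sys
from pathlib import Path


def run_and_log(cmd, log_path):
    """Run cmd, writing both stdout and stderr to log_path (like `> log 2>&1`).

    Returns the process exit code.
    """
    log_path = Path(log_path)
    log_path.parent.mkdir(parents=True, exist_ok=True)  # create the log directory if missing
    with open(log_path, "w") as log:
        # stderr=STDOUT sends stderr to the same file handle as stdout
        result = subprocess.run(cmd, stdout=log, stderr=subprocess.STDOUT)
    return result.returncode
```

For the batch command above, run from the Caffe root: `run_and_log([r'.\bin\caffe.exe', 'train', '--solver=./examples/mnist/lenet_solver.prototxt'], './examples/mnist/log/mnist_Lenet_train_test.log')`. Unlike the batch file, it creates the `log` directory if it does not exist yet.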
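Two scripts, parse_log.py and extract_seconds.py, parse the training log into CSV files. They are adapted from the scripts of the same names in Caffe's `tools/extra` directory.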
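At its core, parse_log.py matches each log line against a few regular expressions and groups the `Train net output` / `Test net output` values by iteration, one column per output name. A minimal sketch of that idea using the same patterns (written as raw strings), run on a hypothetical, shortened glog-format excerpt; the `solver.cpp` line numbers and values are illustrative, and the plain dict stands in for the script's `OrderedDict` rows.

```python
import re

# Same patterns as parse_log.py
regex_iteration = re.compile(r'Iteration (\d+)')
regex_test_output = re.compile(r'Test net output #(\d+): (\S+) = ([.\deE+-]+)')

# Hypothetical log excerpt in glog format, shortened for illustration.
sample = [
    "I0210 13:39:22.381027 25210 solver.cpp:337] Iteration 500, Testing net (#0)",
    "I0210 13:39:23.000000 25210 solver.cpp:404]     Test net output #0: accuracy = 0.9732",
    "I0210 13:39:23.000000 25210 solver.cpp:404]     Test net output #1: loss = 0.0851 (* 1 = 0.0851 loss)",
]

row = {}
for line in sample:
    m = regex_iteration.search(line)
    if m:
        row['NumIters'] = float(m.group(1))
    m = regex_test_output.search(line)
    if m:
        # group(2) is the output name, group(3) its value
        row[m.group(2)] = float(m.group(3))

print(row)  # {'NumIters': 500.0, 'accuracy': 0.9732, 'loss': 0.0851}
```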
parse_log.py source:
```python
import csv
import os
import re
from collections import OrderedDict

# extract_seconds.py is assumed to sit in the same directory as this script
from extract_seconds import (extract_datetime_from_line, get_log_created_year,
                             get_start_time)


def parse_log(log_file_name):
    """
    Parse log file
    :param log_file_name: the name of log file
    :return: (train_dict_list, test_dict_list)
    """
    regex_iteration = re.compile(r'Iteration (\d+)')
    regex_train_output = re.compile(r'Train net output #(\d+): (\S+) = ([.\deE+-]+)')
    regex_test_output = re.compile(r'Test net output #(\d+): (\S+) = ([.\deE+-]+)')
    regex_learning_rate = re.compile(r'lr = ([-+]?[0-9]*\.?[0-9]+([eE]?[-+]?[0-9]+)?)')

    # Pick out lines of interest
    iteration = -1
    learning_rate = float('NaN')
    train_dict_list = []
    test_dict_list = []
    train_row = None
    test_row = None

    logfile_year = get_log_created_year(log_file_name)
    with open(log_file_name) as f:
        start_time = get_start_time(f, logfile_year)
        last_time = start_time

        for line in f:
            iteration_match = regex_iteration.search(line)
            if iteration_match:
                iteration = float(iteration_match.group(1))
            if iteration == -1:
                # Only start parsing for other stuff if we've found the first iteration
                continue

            try:
                time = extract_datetime_from_line(line, logfile_year)
            except (ValueError, IndexError):
                # Skip blank or badly formatted lines, for example when resuming solver
                continue

            # If the log crosses into a new year, bump the year
            if time.month < last_time.month:
                logfile_year += 1
                time = extract_datetime_from_line(line, logfile_year)
            last_time = time

            seconds = (last_time - start_time).total_seconds()

            learning_rate_match = regex_learning_rate.search(line)
            if learning_rate_match:
                learning_rate = float(learning_rate_match.group(1))

            train_dict_list, train_row = parse_line_for_net_output(
                regex_train_output, train_row, train_dict_list,
                line, iteration, seconds, learning_rate)
            test_dict_list, test_row = parse_line_for_net_output(
                regex_test_output, test_row, test_dict_list,
                line, iteration, seconds, learning_rate)

    fix_initial_nan_learning_rate(train_dict_list)
    fix_initial_nan_learning_rate(test_dict_list)

    return train_dict_list, test_dict_list


def parse_line_for_net_output(regex_obj, row, row_dict_list,
                              line, iteration, seconds, learning_rate):
    """Parse a single line for training or test output

    Returns a tuple with (row_dict_list, row)
    row: may be either a new row or an augmented version of the current row
    row_dict_list: may be either the current row_dict_list or an augmented
    version of the current row_dict_list
    """
    output_match = regex_obj.search(line)
    if output_match:
        if not row or row['NumIters'] != iteration:
            # Push the last row and start a new one
            if row:
                # If we're on a new iteration, push the last row
                # This will probably only happen for the first row; otherwise
                # the full row checking logic below will push and clear full rows
                row_dict_list.append(row)

            row = OrderedDict([
                ('NumIters', iteration),
                ('Seconds', seconds),
                ('LearningRate', learning_rate)
            ])

        # output_num is not used; may be used in the future
        output_name = output_match.group(2)
        output_val = output_match.group(3)
        row[output_name] = float(output_val)

    if row and len(row_dict_list) >= 1 and len(row) == len(row_dict_list[0]):
        # The row is full, based on the fact that it has the same number of
        # columns as the first row; append it to the list
        row_dict_list.append(row)
        row = None

    return row_dict_list, row


def fix_initial_nan_learning_rate(dict_list):
    """Correct initial value of learning rate

    Learning rate is normally not printed until after the initial test and
    training step, which means the initial testing and training rows have
    LearningRate = NaN. Fix this by copying over the LearningRate from the
    second row, if it exists.
    """
    if len(dict_list) > 1:
        dict_list[0]['LearningRate'] = dict_list[1]['LearningRate']


def save_csv_files(logfile, output_dir, train_dict_list, test_dict_list,
                   delimiter=',', verbose=False):
    """Save CSV files to output_dir

    If the input log file is, e.g., caffe.INFO, the names will be
    caffe.INFO.train and caffe.INFO.test
    """
    log_basename = os.path.basename(logfile)
    train_filename = os.path.join(output_dir, log_basename + '.train')
    write_csv(train_filename, train_dict_list, delimiter, verbose)

    test_filename = os.path.join(output_dir, log_basename + '.test')
    write_csv(test_filename, test_dict_list, delimiter, verbose)


def write_csv(output_filename, dict_list, delimiter, verbose=False):
    """Write a CSV file"""
    if not dict_list:
        if verbose:
            print('Not writing %s; no lines to write' % output_filename)
        return

    # Note: without newline='', Windows inserts a blank line after every row;
    # the plotting script below skips these blank lines.
    with open(output_filename, 'w') as f:
        dict_writer = csv.DictWriter(f, fieldnames=dict_list[0].keys(),
                                     delimiter=delimiter)
        dict_writer.writeheader()
        dict_writer.writerows(dict_list)
    if verbose:
        print('Wrote %s' % output_filename)


def main():
    log_file_name = 'mnist_Lenet_train_test.log'
    # Directory where the parsed CSV files are saved
    output_dir = 'C:\\Programming Code\\Caffe\\examples\\mnist\\log\\'
    train_dict_list, test_dict_list = parse_log(log_file_name)
    save_csv_files(log_file_name, output_dir, train_dict_list, test_dict_list,
                   delimiter=',')


if __name__ == '__main__':
    main()
```
extract_seconds.py source:
```python
import datetime
import os


def extract_datetime_from_line(line, year):
    """
    extract datetime from line
    :param line: the lines
    :param year: the year
    :return: datetime
    """
    # Expected format: I0210 13:39:22.381027 25210 solver.cpp:204] Iteration 100, lr = 0.00992565
    line = line.strip().split()
    month = int(line[0][1:3])
    day = int(line[0][3:])
    timestamp = line[1]
    pos = timestamp.rfind('.')
    ts = [int(x) for x in timestamp[:pos].split(':')]
    hour = ts[0]
    minute = ts[1]
    second = ts[2]
    microsecond = int(timestamp[pos + 1:])
    dt = datetime.datetime(year, month, day, hour, minute, second, microsecond)
    return dt


def get_log_created_year(input_file):
    """
    get the year from log file system timestamp
    :param input_file: the input
    :return: the created year of the log file
    """
    log_created_time = os.path.getctime(input_file)
    log_created_year = datetime.datetime.fromtimestamp(log_created_time).year
    return log_created_year


def get_start_time(line_iterable, year):
    """
    find start time from group of lines
    :param line_iterable: the lines of log file
    :param year: the created year of log file
    :return: the start datetime
    """
    start_datetime = None
    for line in line_iterable:
        line = line.strip()
        if line.find('Solving') != -1:
            start_datetime = extract_datetime_from_line(line, year)
            break
    return start_datetime
```
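`extract_datetime_from_line` slices the timestamp apart by hand. The same parse can be written with `datetime.strptime`; the following is a minimal alternative sketch, where the helper name `glog_datetime` and the year 2017 are hypothetical. The Caffe log prefix used here (`I0210 ...`) carries only month and day, so the year must be supplied, as `get_log_created_year` does. `strptime` raises `ValueError` when the timestamp is malformed.

```python
import datetime


def glog_datetime(line, year):
    """Parse the 'Lmmdd hh:mm:ss.uuuuuu' prefix of a Caffe log line with strptime."""
    fields = line.split()
    # fields[0] is e.g. 'I0210' (severity letter + month + day), fields[1] the time
    stamp = "%d%s %s" % (year, fields[0][1:5], fields[1])
    return datetime.datetime.strptime(stamp, "%Y%m%d %H:%M:%S.%f")


start = glog_datetime("I0210 13:39:22.381027 25210 solver.cpp:204] Solving LeNet", 2017)
later = glog_datetime("I0210 13:40:02.881027 25210 solver.cpp:228] Iteration 100", 2017)
print((later - start).total_seconds())  # 40.5
```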
Plotting source:
```python
import random

import matplotlib.pyplot as plt


def load_data(data_file, phase):
    """
    load the data
    :param data_file: the data file
    :param phase: the data of train phase or test phase
    :return: data
    """
    # CSV columns written by parse_log.py: NumIters, Seconds, LearningRate, <net outputs>
    if phase == 'Train':
        columns = [0, 2, 3]       # NumIters, LearningRate, loss
    else:
        columns = [0, 2, 3, 4]    # NumIters, LearningRate, accuracy, loss (LeNet test net)
    data = [[] for _ in columns]
    with open(data_file, 'r') as f:
        next(f)  # skip the CSV header
        for line in f:
            line = line.strip()
            if not line:  # skip blank lines (written after every row on Windows)
                continue
            fields = line.split(',')
            for i, col in enumerate(columns):
                data[i].append(float(fields[col].strip()))
    return data


def plot_one(path_to_png, title, x, y, ylabel, line_width=1.0):
    """Plot one curve in its own figure and save it as <path_to_png><title>.png"""
    color = [random.random(), random.random(), random.random()]  # random line color
    plt.figure(title)
    plt.plot(x, y, color=color, linewidth=line_width)
    plt.title(title)
    plt.xlabel('Iterations')
    plt.ylabel(ylabel)
    plt.savefig(path_to_png + title + '.png')


def plot_chart(path_to_png, data, phase):
    """
    plot the charts from the parsed log data
    :param path_to_png: the save path of the png charts
    :param data: the data of the charts
    :param phase: plot the charts of the train phase or the test phase
    :return: None
    """
    if phase == 'Train':
        iters, learning_rate, loss = data
        plot_one(path_to_png, 'Train Iterations VS Loss', iters, loss, 'Loss')
        plot_one(path_to_png, 'Train Iterations VS LearningRate', iters, learning_rate, 'LearningRate')
    else:
        iters, learning_rate, accuracy, loss = data
        plot_one(path_to_png, 'Test Iterations VS Loss', iters, loss, 'Loss')
        plot_one(path_to_png, 'Test Iterations VS LearningRate', iters, learning_rate, 'LearningRate')
        plot_one(path_to_png, 'Test Iterations VS Accuracy', iters, accuracy, 'Accuracy')


def main():
    train_log = 'mnist_Lenet_train_test.log.train'
    test_log = 'mnist_Lenet_train_test.log.test'
    path_to_png = 'C:\\Programming Code\\Caffe\\examples\\mnist\\log\\'
    # load the train data and plot the train charts
    train_data = load_data(train_log, phase='Train')
    plot_chart(path_to_png, train_data, phase='Train')
    # load the test data and plot the test charts
    test_data = load_data(test_log, phase='Test')
    plot_chart(path_to_png, test_data, phase='Test')


if __name__ == '__main__':
    main()
```
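The MATLAB script above draws train and test loss in one figure, while the Python script saves each curve separately. A minimal sketch of the combined plot that reads the CSV files by column name with `csv.DictReader` instead of by position; the function names `load_columns` and `plot_train_test_loss` are hypothetical, and it assumes the loss column is named `loss`, as in the MNIST LeNet log. `DictReader` skips empty rows, so files with or without the Windows blank-line artifact both load.

```python
import csv


def load_columns(csv_path):
    """Read a parse_log.py CSV into {column name: [float, ...]}."""
    columns = {}
    with open(csv_path, newline='') as f:
        for row in csv.DictReader(f):  # blank rows are skipped automatically
            for name, value in row.items():
                columns.setdefault(name, []).append(float(value))
    return columns


def plot_train_test_loss(train_csv, test_csv, png_path):
    """Overlay train and test loss on one figure and save it to png_path."""
    # Imported here so load_columns can be used without matplotlib installed.
    import matplotlib.pyplot as plt

    train = load_columns(train_csv)
    test = load_columns(test_csv)
    plt.figure()
    plt.plot(train['NumIters'], train['loss'], label='Train Loss')
    plt.plot(test['NumIters'], test['loss'], label='Test Loss')
    plt.grid(True)
    plt.legend()
    plt.xlabel('iterations')
    plt.ylabel('loss')
    plt.title('Train & Test Loss Curve')
    plt.savefig(png_path)
    plt.close()
```

Usage: `plot_train_test_loss('mnist_Lenet_train_test.log.train', 'mnist_Lenet_train_test.log.test', 'Train Test Loss.png')`.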