Windows+Python3下绘制Caffe训练日志中的Loss和Accuracy曲线图

来源:互联网 发布:php 微信群发接口 编辑:程序博客网 时间:2024/06/11 00:21

在深度学习中,可以通过学习曲线评估当前训练状态:

  1. train loss 不断下降,test loss 不断下降,说明网络仍然在认真学习中。
  2. train loss 不断下降,test loss 趋于不变,说明网络过拟合。
  3. train loss 趋于不变,test loss 趋于不变,说明学习遇到瓶颈,需减小学习速率或者批量数据尺寸。
  4. train loss 趋于不变,test loss 不断下降,说明数据集 100% 有问题。
  5. train loss 不断上升,test loss 不断上升(最终为 NaN),可能是由网络结构设计不当、训练超参数设置不当、程序 bug 等某个问题引起的,需要进一步定位。

Linux下的MATLAB代码:
// 提取log文件中loss值的shell命令:cat train_log_file | grep "Train net output" | awk '{print $11}'

% Plot Caffe train/test loss curves from a training log.
% Requires cat, grep and awk on the PATH.
clear; clc; close all;

train_log_file = 'train.log';
train_interval = 100;  % iterations between two logged train losses (presumably the solver's display value - confirm)
test_interval = 200;   % iterations between two logged test losses (presumably the solver's test_interval - confirm)

% Field 11 of "... Train net output #0: loss = <value> ..." is the loss value.
% system() is the cross-platform shell call; dos() targets Windows.
[~, train_string_output] = system(['cat ', train_log_file, ' | grep ''Train net output #0'' | awk ''{print $11}''']);
% sscanf parses the numbers directly instead of eval-ing the output like str2num.
train_loss = sscanf(train_string_output, '%f');
idx_train = (0 : numel(train_loss) - 1) * train_interval;

% Test output #1 is taken as the loss (#0 assumed to be accuracy - confirm for other nets).
[~, test_string_output] = system(['cat ', train_log_file, ' | grep ''Test net output #1'' | awk ''{print $11}''']);
test_loss = sscanf(test_string_output, '%f');
idx_test = (0 : numel(test_loss) - 1) * test_interval;

figure;
plot(idx_train, train_loss);
hold on;
plot(idx_test, test_loss);
grid on;
legend('Train Loss', 'Test Loss');
xlabel('iterations');
ylabel('loss');
title(' Train & Test Loss Curve');

Windows下的Python3(Anaconda3+PyCharm)代码。首先用下面的批处理命令训练网络并保存训练日志:

"./bin/caffe.exe" train --solver=./examples/mnist/lenet_solver.prototxt >./examples/mnist/log/mnist_Lenet_train_test.log 2>&1pause

命令中的 >./examples/mnist/log/mnist_Lenet_train_test.log 2>&1 表示把训练日志写入该文件:> 重定向标准输出,2>&1 把标准错误也一并重定向(Caffe 的训练信息通过 glog 输出到标准错误)。
parse_log.py和extract_seconds.py文件用于解析训练日志:
parse_log.py源码:

import csv
import os
import re
from collections import OrderedDict

from examples.mnist.log.extract_seconds import (
    extract_datetime_from_line,
    get_log_created_year,
    get_start_time,
)


def parse_log(log_file_name):
    """Parse a Caffe training log into per-iteration train and test rows.

    :param log_file_name: path of the log file written by ``caffe train``
    :return: ``(train_dict_list, test_dict_list)``; each element is an
        OrderedDict with 'NumIters', 'Seconds' and 'LearningRate' followed by
        one entry per net output name found in the log
    """
    # Raw strings: '\d', '\S' and '\.' are regex escapes, not string escapes.
    regex_iteration = re.compile(r'Iteration (\d+)')
    regex_train_output = re.compile(r'Train net output #(\d+): (\S+) = ([.\deE+-]+)')
    regex_test_output = re.compile(r'Test net output #(\d+): (\S+) = ([.\deE+-]+)')
    regex_learning_rate = re.compile(r'lr = ([-+]?[0-9]*\.?[0-9]+([eE]?[-+]?[0-9]+)?)')

    # Pick out lines of interest
    iteration = -1
    learning_rate = float('NaN')
    train_dict_list = []
    test_dict_list = []
    train_row = None
    test_row = None

    # Log lines carry month/day but no year; start from the file timestamp's year.
    logfile_year = get_log_created_year(log_file_name)
    with open(log_file_name) as f:
        # Consumes lines up to the first 'Solving' line. If there is none, the
        # file iterator is exhausted and both lists come back empty.
        start_time = get_start_time(f, logfile_year)
        last_time = start_time

        for line in f:
            iteration_match = regex_iteration.search(line)
            if iteration_match:
                iteration = float(iteration_match.group(1))
            if iteration == -1:
                # Only start parsing for other stuff if we've found the first iteration
                continue

            try:
                line_time = extract_datetime_from_line(line, logfile_year)
            except (ValueError, IndexError):
                # Skip lines without a valid timestamp prefix: blank lines
                # (IndexError) or badly formatted ones, e.g. when resuming a solver
                continue

            # The month went backwards, so the log crossed into a new year
            if line_time.month < last_time.month:
                logfile_year += 1
                line_time = extract_datetime_from_line(line, logfile_year)
            last_time = line_time

            seconds = (last_time - start_time).total_seconds()

            learning_rate_match = regex_learning_rate.search(line)
            if learning_rate_match:
                learning_rate = float(learning_rate_match.group(1))

            train_dict_list, train_row = parse_line_for_net_output(
                regex_train_output, train_row, train_dict_list,
                line, iteration, seconds, learning_rate)
            test_dict_list, test_row = parse_line_for_net_output(
                regex_test_output, test_row, test_dict_list,
                line, iteration, seconds, learning_rate)

    fix_initial_nan_learning_rate(train_dict_list)
    fix_initial_nan_learning_rate(test_dict_list)

    return train_dict_list, test_dict_list


def parse_line_for_net_output(regex_obj, row, row_dict_list,
                              line, iteration, seconds, learning_rate):
    """Parse a single line for training or test output.

    Returns a tuple ``(row_dict_list, row)``:
    row: either a new row, an augmented version of the current row, or None
        once a full row has been appended to the list
    row_dict_list: the current list, possibly with one more row appended
    """
    output_match = regex_obj.search(line)
    if output_match:
        if not row or row['NumIters'] != iteration:
            # Push the last row and start a new one
            if row:
                # If we're on a new iteration, push the last row.
                # This will probably only happen for the first row; otherwise
                # the full-row check below pushes and clears full rows.
                row_dict_list.append(row)

            row = OrderedDict([
                ('NumIters', iteration),
                ('Seconds', seconds),
                ('LearningRate', learning_rate),
            ])

        # output_num (group 1) is not used; may be used in the future
        output_name = output_match.group(2)
        output_val = output_match.group(3)
        row[output_name] = float(output_val)

    if row and len(row_dict_list) >= 1 and len(row) == len(row_dict_list[0]):
        # The row is full, based on the fact that it has the same number of
        # columns as the first row; append it to the list
        row_dict_list.append(row)
        row = None

    return row_dict_list, row


def fix_initial_nan_learning_rate(dict_list):
    """Correct the initial value of the learning rate.

    The learning rate is normally not printed until after the initial test and
    training step, so the first rows have LearningRate = NaN. Copy the
    LearningRate from the second row, if it exists.
    """
    if len(dict_list) > 1:
        dict_list[0]['LearningRate'] = dict_list[1]['LearningRate']


def save_csv_files(logfile, output_dir, train_dict_list, test_dict_list,
                   delimiter=',', verbose=False):
    """Save the parsed rows as CSV files in output_dir.

    If the input log file is, e.g., caffe.INFO, the names will be
    caffe.INFO.train and caffe.INFO.test.
    """
    log_basename = os.path.basename(logfile)
    train_filename = os.path.join(output_dir, log_basename + '.train')
    write_csv(train_filename, train_dict_list, delimiter, verbose)

    test_filename = os.path.join(output_dir, log_basename + '.test')
    write_csv(test_filename, test_dict_list, delimiter, verbose)


def write_csv(output_filename, dict_list, delimiter, verbose=False):
    """Write dict_list to output_filename as CSV, using the first row's keys as header.

    Nothing is written when dict_list is empty.
    """
    if not dict_list:
        if verbose:
            print('Not writing %s; no lines to write' % output_filename)
        return

    # NOTE(review): the csv docs recommend open(..., newline=''). Without it,
    # Windows text mode turns the '\r\n' row terminator into '\r\r\n', which
    # readers see as a blank line after every row. It is left unchanged so the
    # on-disk layout stays what existing readers of these files expect.
    with open(output_filename, 'w') as f:
        # Pass the delimiter to the writer instead of assigning to
        # csv.excel.delimiter, which would change the process-wide dialect.
        dict_writer = csv.DictWriter(f, fieldnames=dict_list[0].keys(),
                                     dialect=csv.excel, delimiter=delimiter)
        dict_writer.writeheader()
        dict_writer.writerows(dict_list)
    if verbose:
        print('Wrote %s' % output_filename)


def main():
    log_file_name = 'mnist_Lenet_train_test.log'
    # Directory where the parsed .train / .test CSV files are written
    output_dir = 'C:\\Programming Code\\Caffe\\examples\\mnist\\log\\'
    train_dict_list, test_dict_list = parse_log(log_file_name)
    save_csv_files(log_file_name, output_dir, train_dict_list, test_dict_list, delimiter=',')


if __name__ == '__main__':
    main()

extract_seconds.py源码:

import datetime
import os


def extract_datetime_from_line(line, year):
    """Build a datetime from the glog-style prefix of a Caffe log line.

    Expected layout: ``I0210 13:39:22.381027 25210 solver.cpp:204] ...``.
    The first token is one severity letter followed by MMDD, and the second is
    the wall-clock time HH:MM:SS.ffffff. The year is not in the line, so the
    caller supplies it.

    :param line: one line of the log file
    :param year: the year to combine with the parsed month and day
    :return: the parsed ``datetime.datetime``
    :raises ValueError: if the date/time fields are not valid numbers
    :raises IndexError: if the line has too few tokens or time components
    """
    tokens = line.strip().split()

    date_token = tokens[0]
    month = int(date_token[1:3])
    day = int(date_token[3:])

    clock = tokens[1]
    dot = clock.rfind('.')
    hms = list(map(int, clock[:dot].split(':')))
    hour, minute, second = hms[0], hms[1], hms[2]
    microsecond = int(clock[dot + 1:])

    return datetime.datetime(year, month, day, hour, minute, second, microsecond)


def get_log_created_year(input_file):
    """Return the year of the log file's ctime.

    ``os.path.getctime`` is the creation time on Windows and the last metadata
    change time on Unix.

    :param input_file: path of the log file
    :return: the year as an int
    """
    ctime_seconds = os.path.getctime(input_file)
    return datetime.datetime.fromtimestamp(ctime_seconds).year


def get_start_time(line_iterable, year):
    """Return the timestamp of the first line that contains 'Solving'.

    The iterable is consumed up to and including that line, so a file object
    passed in resumes right after it.

    :param line_iterable: the lines of the log file
    :param year: the year passed on to extract_datetime_from_line
    :return: the start datetime, or None when no line contains 'Solving'
    """
    for raw_line in line_iterable:
        candidate = raw_line.strip()
        if 'Solving' in candidate:
            return extract_datetime_from_line(candidate, year)
    return None

绘图源码:

import matplotlib.pyplot as plt
import random
import itertools


def load_data(data_file, phase):
    """Load the columns needed for plotting from a parsed '.train' / '.test' CSV.

    Columns are read by position: 0 = NumIters and 2 = LearningRate, followed
    by the net outputs. For 'Train' that is 3 = loss. Otherwise it is
    3 = accuracy and 4 = loss, as in the MNIST LeNet example (confirm the
    output order for other nets).

    :param data_file: path of the CSV file
    :param phase: 'Train' for the training file; anything else is treated as test
    :return: list of column lists, [iters, lr, loss] for 'Train' and
        [iters, lr, accuracy, loss] otherwise
    """
    columns = (0, 2, 3) if phase == 'Train' else (0, 2, 3, 4)
    data = [[] for _ in columns]
    with open(data_file, 'r') as f:
        # Skip the header row and ignore blank lines. Blank lines appear between
        # rows when the CSV was written in text mode on Windows ('\r\n' becomes
        # '\r\r\n'), but a correctly written CSV has none. Assuming they
        # alternate with data rows would drop half of the data in that case.
        for line in itertools.islice(f, 1, None):
            line = line.strip()
            if not line:
                continue
            fields = line.split(",")
            for column, index in zip(data, columns):
                column.append(float(fields[index].strip()))
    return data


def _plot_curve(path_to_png, title, x_values, y_values, y_label, line_width):
    """Plot y_values against iterations in a figure named title and save it as path_to_png + title + '.png'."""
    color = [random.random(), random.random(), random.random()]  # random line color
    plt.figure(title)
    plt.plot(x_values, y_values, color=color, linewidth=line_width)
    plt.title(title)
    plt.xlabel('Iterations')
    plt.ylabel(y_label)
    plt.savefig(path_to_png + title + '.png')


def plot_chart(path_to_png, data, phase):
    """Plot and save the charts for one phase of the log.

    :param path_to_png: prefix prepended verbatim to each PNG file name, so a
        directory must end with a path separator
    :param data: the column lists returned by load_data for the same phase
    :param phase: 'Train' plots loss and learning rate; anything else plots
        test loss, learning rate and accuracy
    :return: None
    """
    line_width = 1.0  # the line width
    num_iteration = data[0]
    if phase == 'Train':
        _plot_curve(path_to_png, 'Train Iterations VS Loss',
                    num_iteration, data[2], 'Loss', line_width)
        _plot_curve(path_to_png, 'Train Iterations VS LearningRate',
                    num_iteration, data[1], 'LearningRate', line_width)
    else:
        _plot_curve(path_to_png, 'Test Iterations VS Loss',
                    num_iteration, data[3], 'Loss', line_width)
        _plot_curve(path_to_png, 'Test Iterations VS LearningRate',
                    num_iteration, data[1], 'LearningRate', line_width)
        _plot_curve(path_to_png, 'Test Iterations VS Accuracy',
                    num_iteration, data[2], 'Accuracy', line_width)


def main():
    train_log = 'mnist_Lenet_train_test.log.train'
    test_log = 'mnist_Lenet_train_test.log.test'
    path_to_png = 'C:\\Programming Code\\Caffe\\examples\\mnist\\log\\'

    # load and plot the train data
    train_data = load_data(train_log, phase='Train')
    plot_chart(path_to_png, train_data, phase='Train')

    # load and plot the test data
    test_data = load_data(test_log, phase='Test')
    plot_chart(path_to_png, test_data, phase='Test')


if __name__ == '__main__':
    main()
1 0