CNN+caffe学习4:自己训练网络全过程

来源:互联网 发布:皇甫圣华淘宝 编辑:程序博客网 时间:2024/05/17 22:35

资料下载地址:https://github.com/EmmaW8/caffe.git
branch选择201701,code文件夹里面是需要用到的执行文件,models里面是我自己设计的5层CNN模型配置文件,以及训练结果。

1. 准备数据

2. 数据分类

get-filename-label.py文件
我的文件路径比较复杂,此处很麻烦,如果你不用我的数据,可以不用我这个执行文件,自己写一个就好,很easy~
注意修改内部数据文件路径

$ python get-filename-label.py

会生成3个文件,分别为train.txt,val.txt,test.txt.
train和val里面有label,训练用,test无label,分类时用。
train:val=9:1,你也可以尝试4:1.

这里其实很简单,就是获取数据路径名称,以及其label,我的文件仅供参考,可以根据这个写出你自己的Python提取文件。

3. 生成lmdb文件,生成均值文件(mean.binaryproto)

creat_atypia.sh文件
修改DATA路径。
运行:

$ sh creat_atypia.sh

生成3个lmdb文件和一个mean文件,这里用到caffe自己提供的工具来生成。
示例;

echo "Creating train lmdb..."GLOG_logtostderr=1 $TOOLS/convert_imageset \    --resize_height=$RESIZE_HEIGHT \    --resize_width=$RESIZE_WIDTH \    --shuffle \    $TRAIN_DATA_ROOT \    $DATA/train.txt \    $DATA/atypia_train_initialsize_$DBTYPE

利用convert_imageset来生成lmdb文件

echo "Computing image mean..."$TOOLS/compute_image_mean -backend=$DBTYPE \  $DATA/atypia_train_initialsize_$DBTYPE $DATA/initialsize_mean.binaryproto

利用compute_image_mean和刚生成的train_lmdb文件生成mean文件。

4. 自己设计网络结构

我设计的网络结构在models文件夹里,-solver.prototxt是solver文件,-train_val.prototxt是描述网络框架文件,*.caffemodel是我训练完毕后生成的模型文件。deploy.prototxt是用于test分类时的文件。

在训练之前,需要把solver 和 train_val.prototxt文件配置好。
具体如何配置,请参考我的文件进行修改。注意修改路径。

此处一定要注意是layer {},还是layers{},如果定义的layer不同,那么内部的参数会不同,上次就栽在这个坑里了。layers里面的type都是大写的,并且param的设置不同。一个prototxt文档里,只能出现一种,不能layer和layers混用。

5. 训练

执行code里面的可执行文件train-atypia.sh

$ sh train-atypia.sh

开始训练。

文件说明

#!/usr/bin/env sh# begin trainset -e# 指定log文件路径和名称LOG=/home/emma/software/caffe/models/atypia/log/train-`date +%Y-%M-%d-%H-%M-%S`.log#指定一些变量名,便于后续开发TOOLS=/home/emma/software/caffe/build/toolsMODEL=/home/emma/software/caffe/models/atypiaNAME=atypia# fine-tune 别人的模型,此处不需要#$TOOLS/caffe train  -solver $MODEL/$NAME-solver.prototxt  -weights $MODEL/VGG_16_layers.caffemodel -gpu 0 > VGG2.log# 训练自己的模型,并将控制台输出导入到log文件$TOOLS/caffe train -solver $MODEL/$NAME-solver.prototxt -gpu 0 2>&1 | tee $LOG

6. 可视化,画出训练loss,以及test accuracy

训练完毕后,可以得到log文件,我的log文件位于models/log/
对log文件进行处理,提取出accuracy和loss的数值,然后画图。
可执行文件:plot.sh

./plot_training_log.py.example 0 /home/emma/software/caffe/models/atypia/log/curve/test_accuracy.png /home/emma/software/caffe/models/atypia/log/train-2017-52-07-19-52-39.log

0:代表画出test_accuracy 和iters的曲线
2:代表画出test_loss和iters的曲线
6:代表画出train_loss 和iters的曲线

第二个参数是,图片保存位置,第三个参数log文件所在的位置

我生成的曲线存储在models/log/curve里面,例如
这里写图片描述

注意事项:

如果loss未收敛,记得及时调整学习率。

第一次使用自己设计的网络时,选择10个训练数据,并且训练数据=val数据集,看一下训练几次后准确率是否会达到100%,如果会,就说明模型可用!!!


EMMA
SIAT

0 0
原创粉丝点击