caffe学习系列2—步骤记录
来源:互联网 发布:qq农场辅助软件 编辑:程序博客网 时间:2024/05/17 03:21
1.为训练做准备工作
参考链接:http://blog.csdn.net/Losteng/article/details/50799998?ref=myread
先写上caffe训练命令,如下:
sudo /home/caffe/caffe/build/tools/caffe train -solver data/myfile/solver.prototxt
其中solver.prototxt文件会调用train_net.prototxt网络,并且train_net.prototxt中需要写数据集的路径,所以需要做以下准备工作:
1)准备caffe支持的数据格式(以lmdb格式为例:)
同上,先贴上生成lmdb格式的命令代码:
#!/usr/bin/env shDATA=images/trainrm -rf $DATA/img_train_lmdbbuild/tools/convert_imageset --shuffle \--resize_height=256 --resize_width=256 \/home/caffe/caffe/images/train $DATA/train.txt $DATA/img_train_lmdb
参数解释:
设置参数-shuffle,打乱图片顺序。设置参数-resize_height和-resize_width将所有图片尺寸都变为256*256./home/caffe/caffe/images/train/ 为图片保存的绝对路径。可以根据实际情况将路径修改。运行脚本,img_train_lmdb数据生成。
其中:convert_imageset 参数有:
由上面命令可知,还需要train.txt(文件名(亲测可以含路径)与标签清单):linux下可参考:
##(修改至符合自己需求)#!/usr/bin/env shDATA=images/trainecho "Create train.txt..."rm -rf $DATA/train.txtfind $DATA -name cat*.jpg | cut -d '/' -f4 | sed "s/$/ 1/">>$DATA/train.txtfind $DATA -name dog*.jpg | cut -d '/' -f4 | sed "s/$/ 2/">>$DATA/tmp.txtcat $DATA/tmp.txt>>$DATA/train.txtrm -rf $DATA/tmp.txtecho "Done.."
windows下可参考:
http://blog.csdn.net/u012617944/article/details/78128218
然后按照最初的命令生成就好了!
2)计算均值文件
$TOOLS/compute_image_mean $EXAMPLE/train_lmdb \ $DATA/imagenet_mean.binaryproto
其中:EXAMPLE=/caffe/examples/lmdb_test/train //你的train_lmdb所在路径
DATA=/caffe/examples/lmdb_test/train //生成的均值文件存放路径
TOOLS=/caffe/build/tools //caffe的tools路径
3)写训练需要的xx_train_test.prototxt和solver.prototxt文件
xx_deploy.prototxt:设置网络中间层的结构。data层仅定义4D的input_dim(分别表示batch大小,通道数,滤波器高度,滤波器宽度),最后一层没有loss层。提取特征或预测输出时使用;
xx_solver.prototxt:设置训练网络所需的网络结构文件(xx_train_test.prototxt)和超参数,训练网络时使用;
xx_train_test.prototxt:设置网络每层的结构。data层中include的phase为TRAIN或TEST区分是输入数据是训练数据还是测试数据。data层有完整的定义,最后一层为loss层,训练和测试网络时都用。
网络模型测试见我的另一篇博客笔记:
http://blog.csdn.net/u012617944/article/details/78264639
4)CSV文件转HDF5格式
(LMDB或LevelDB提供训练的标签只能是标量,无法提供向量或是矩阵形式,比如人脸关键点及多标签问题,这时往往考虑用hdf5格式)
贴上我自己的代码:以人脸关键点进行表情识别为例:我的数据102维关键点data和6种表情标签label;
def load_data(): print "ssssssssssssss" dataframe = read_csv(os.path.expanduser(FTRAIN)) # load pandas dataframe print "enter read_csv" #从csv转换成pd支持的dataframe(数据表),分别提取数据和标签 train_feature = dataframe.iloc[:, 0:102] train_label = dataframe.iloc[:, 102] print len(train_label) #将df(的分别数据和标签)转成numpy的array labels_arr = np.zeros((len(train_label),6)) for count,label_arr in enumerate(train_label): labels_arr[count][label_arr] = 1 print(labels_arr) data_arr = np.vstack(train_feature.values) data_arr = data_arr.reshape(-1,1,1,102)#对数据产生的数组reshape # print(data_arr) #new_data = {} #new_data['input'] = np.reshape(train_feature, (-1,1,1,102)) #new_data['output'] = labels_arr #new_data['input'], new_data['output'] = shuffle(new_data['input'], new_data['output'],random_state = 0) data_arr, labels_arr = shuffle(data_arr, labels_arr, random_state = 0) return data_arr, labels_arrdef save_data_as_hdf5(hdf5_data_filename, data, label): f = h5py.File(hdf5_data_filename,'w') f['data'] = data.astype(np.float32) f['label'] = label.astype(np.float32) f.close()def main(): print "this is main" hdf5_data_filename = 'ck_train_data.hdf5' data, label = load_data() print(data) save_data_as_hdf5(hdf5_data_filename, data, label)#最后这句不要忘记,不然找不到主函数if __name__ == "__main__": main()
补充知识点:
dataframe操作:http://blog.csdn.net/u014607457/article/details/51290582reshape函数(操作numpy的array)参数介绍:https://www.zhihu.com/question/52684594numpy.zeros函数使用方法:http://blog.csdn.net/qq_26948675/article/details/54318917
参考代码:http://blog.csdn.net/u011762313/article/details/48851015
参考项目链接:(python+caffe实现的)
https://github.com/FranckDernoncourt/caffe_demos/blob/master/iris/iris_tuto.py
代码讲解:http://m.blog.csdn.net/shadow_guo/article/details/50382446
- caffe学习系列2—步骤记录
- caffe源码学习记录
- caffe 菜鸟学习记录
- 【Caffe安装】caffe安装系列——史上最详细的安装步骤
- caffe学习系列
- Caffe学习系列****
- DL学习笔记【2】caffe使用步骤详解
- 【caffe】Caffe学习系列:solver及其配置
- Caffe学习系列(3):im2col
- caffe 学习系列之finetuning
- caffe 学习系列 视觉层
- caffe学习系列:命令行解析
- Caffe学习系列:caffemodel可视化
- caffe学习系列--层解读
- caffe学习系列:数据增强
- caffe学习系列:网络融合
- caffe系列学习博客网址
- 深度学习Imagenet caffe AlexNet 实验步骤
- \python2.7\lib\site-packages\sklearn\cross_validation.py:41: DeprecationWarning: This module was dep
- conda命令
- Hello world!
- 欢迎使用CSDN-markdown编辑器
- RaspberryPi实验
- caffe学习系列2—步骤记录
- 基于树莓派的apache2服务器搭建
- 利用单例+观察者设计一个简易的分发/订阅消息机制
- GA--sentence match
- Ubuntu14.04安装搜狗输入法
- 欢迎使用CSDN-markdown编辑器
- Java内部类详解
- Redis
- MySQL中utf8和utf8mb4的区别