深度学习之Caffe(一) 用c++接口提取特征后用SVM分类
来源:互联网 发布:数据展示平台 编辑:程序博客网 时间:2024/05/19 20:22
深度学习之Caffe(一) 用c++接口提取特征后用SVM分类
转载请私信联系博主,未经同意请勿转载。
最近因为老师的要求接触了一点深度学习和caffe的东西,其中一个task是用ResNet网络将数据集的特征提取出来然后用SVM做分类。作为一个刚接触深度学习和caffe而且编程能力超级薄弱的小白,真的是各种懵。借鉴了一些博客,下面也会贴出来。
目录如下:
- 准备工作之 现成的模型 和 网络
- 用caffe提供的c++接口提取特征
- 将提出的特征转换成Matlab的.mat格式
- 用SVM(LibSVM)做分类
准备工作之 现成的模型 和 网络
本次提特征用的是ResNet50层的网络(网址链接https://github.com/KaimingHe/deep-residual-networks)和已经训练好的模型(模型链接https://onedrive.live.com/?authkey=%21AAFW2-FVoxeVRck&id=4006CBB8476FF777%2117887&cid=4006CBB8476FF777)。 微软的网盘似乎上不了,要在host里面加路径啥的,这里就自行百度吧。 拿到的网络是deploy.prototxt,为了分别用它提取训练集和测试集的特征,需要在前面加上data layer,分别做两个.prototxt,以备提特征的时候用。
用caffe提供的c++接口提取特征
在caffe路径下,./examples/feature_extraction/文件夹下有个readme,给出了特征提取的使用示例:
./build/tools/extract_features.bin models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel examples/_temp/imagenet_val.prototxt fc7 examples/_temp/features 10 leveldb
'fc7'是提取特征的层,也是样例模型的最高层,也可以用其它层提取。'leveldb'是存储特征的格式,本次实验我用的是lmdb格式。'examples/_temp/features'是存储特征的路径,注意如果保存成lmdb格式的话,**路径必须事先不存在**,不然会报错。'10'是batch number,和batch size相乘须为提取特征的总图片数,不然好像会重复之前的图片还是怎样,有博主写到,没有试过。batch size不能设置过大,不然会out of memory,这个根据自己的机子调整就ok。'imagenet_val.prototxt'是提特征的网络,SVM需要训练集和测试集的特征,所以我们这里要分开求,要用两个网络。
提取lmdb类型的会得到这样的文件夹
里面是这样的
这一步就完成了。
将提出的特征转换成Matlab的.mat格式
这边是借鉴了这篇博客http://m.blog.csdn.net/article/details?id=48180331的方法,十分感谢这位博主!
参考 http://www.cnblogs.com/platero/p/3967208.html 和 lmdb的文档https://lmdb.readthedocs.org/en/release,读取lmdb文件,然后转换成mat文件,再用matlab调用mat进行可视化。安装CAFFE的python依赖库,并使用以下两个辅助文件把lmdb转换为mat。
./feat_helper_pb2.py
# Generated by the protocol buffer compiler. DO NOT EDIT!from google.protobuf import descriptorfrom google.protobuf import messagefrom google.protobuf import reflectionfrom google.protobuf import descriptor_pb2# @@protoc_insertion_point(imports)DESCRIPTOR = descriptor.FileDescriptor( name='datum.proto', package='feat_extract', serialized_pb='\n\x0b\x64\x61tum.proto\x12\x0c\x66\x65\x61t_extract\"i\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02')_DATUM = descriptor.Descriptor( name='Datum', full_name='feat_extract.Datum', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ descriptor.FieldDescriptor( name='channels', full_name='feat_extract.Datum.channels', index=0, number=1, type=5, cpp_type=1, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='height', full_name='feat_extract.Datum.height', index=1, number=2, type=5, cpp_type=1, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='width', full_name='feat_extract.Datum.width', index=2, number=3, type=5, cpp_type=1, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='data', full_name='feat_extract.Datum.data', index=3, number=4, type=12, cpp_type=9, label=1, has_default_value=False, default_value="", message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='label', full_name='feat_extract.Datum.label', index=4, number=5, type=5, cpp_type=1, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), descriptor.FieldDescriptor( name='float_data', full_name='feat_extract.Datum.float_data', index=5, number=6, type=2, cpp_type=6, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, extension_ranges=[], serialized_start=29, serialized_end=134,)DESCRIPTOR.message_types_by_name['Datum'] = _DATUMclass Datum(message.Message): __metaclass__ = reflection.GeneratedProtocolMessageType DESCRIPTOR = _DATUM # @@protoc_insertion_point(class_scope:feat_extract.Datum)# @@protoc_insertion_point(module_scope)
./lmdb2mat.py
import lmdbimport feat_helper_pb2import numpy as npimport scipy.io as sioimport timedef main(argv): lmdb_name = sys.argv[1] print "%s" % sys.argv[1] batch_num = int(sys.argv[2]); batch_size = int(sys.argv[3]); window_num = batch_num*batch_size; start = time.time() if 'db' not in locals().keys(): db = lmdb.open(lmdb_name) txn= db.begin() cursor = txn.cursor() cursor.iternext() datum = feat_helper_pb2.Datum() keys = [] values = [] for key, value in enumerate( cursor.iternext_nodup()): keys.append(key) values.append(cursor.value()) ft = np.zeros((window_num, int(sys.argv[4]))) for im_idx in range(window_num): datum.ParseFromString(values[im_idx]) ft[im_idx, :] = datum.float_data print 'time 1: %f' %(time.time() - start) sio.savemat(sys.argv[5], {'feats':ft}) print 'time 2: %f' %(time.time() - start) print 'done!'if __name__ == '__main__': import sys main(sys.argv)
前面两个文档都不用改,直接贴到.py文件里,然后运行如下.sh就行了。
#!/usr/bin/env shLMDB=./examples/_temp/features_fc7 # lmdb文件路径BATCHNUM=1BATCHSIZE=10# DIM=290400 # feature长度,conv1# DIM=43264 # conv5DIM=4096OUT=./examples/_temp/features_fc7.mat #mat文件保存路径python ./lmdb2mat.py $LMDB $BATCHNUM $BATCHSIZE $DIM $OUT
‘BATCHNUM’和‘BATCHSIZE’就是提特征的时候的batch的数目和大小。
‘DIM’是特征的维度,这个要自己计算,可以用命令查看某一层网络数据参数,然后第一位是batchsize,把剩下几位(一般还剩一位或三位)乘起来就ok了。
用SVM(LibSVM)做分类
终于来到了用SVM做分类,不过时间有限我还是没能学会用SVM做multilabel分类,于是我就只能分别对每个label分类求精度然后取平均值了,这样可能不太科学。 NUS-WIDE是81个concept所以算81次精度,代码贴在下面。
clear clcaddpath D:\dpTask\NUS-WIDE\NUS-WIDE-LiteTrainLabels=importdata('TrainLabels_Lite.mat');TestLabels=importdata('TestLabels_Lite.mat');trfeatures = importdata('train.mat');trfeatures_sparse = sparse(trfeatures); % features must be in a sparse matrixtefeatures = importdata('test.mat');tefeatures_sparse = sparse(tefeatures); % features must be in a sparse matrixfor la=1:81 %获取libsvm格式数据 fprintf('iter=%d,processing data ...\n',la); tic; trlabel = TrainLabels(:,la); telabel = TestLabels(:,la); libsvmwrite('SVMtrain.txt', trlabel, trfeatures_sparse); libsvmwrite('SVMtest.txt', telabel, tefeatures_sparse); toc; %用libsvm进行训练 fprintf('iter=%d,libsvm training ...\n',la); tic; [train_label,train_feature] = libsvmread('SVMtrain.txt'); model = svmtrain(train_label,train_feature,'-h 0'); fprintf('Model got!\n'); %model = svmtrain(train_label,train_feature,'-c 2.0 -g 0.00048828125');这个是带参数的 [test_label,test_feature] = libsvmread('SVMtest.txt'); [predict_label,accur,score] = svmpredict(test_label,test_feature,model); fprintf('Prediction Done!\n'); toc; label(:,la) = predict_label; accuracy(la) = accur(1); fprintf('iter=%d,acurracy=%d\n',la,accur(1)); name=sprintf('result/accuracy.txt'); fid = fopen(name,'at'); fprintf(fid,'No.%d\t%f\n',la,accur(1)); fclose(fid);end
直接用的默认参数没有寻优,有需要的小伙伴可以grid.py寻优一下,或者直接easy.py跑整个实验。
可以参考的资料:
[1] http://m.blog.csdn.net/article/details?id=48180331
[2] http://www.cnblogs.com/jeffwilson/p/5122495.html?utm_source=itdadao&utm_medium=referral
这个是用matlab接口提特征的
[3] http://www.cnblogs.com/denny402/p/5686257.html
这个里面有查看各层参数的代码
- 深度学习之Caffe(一) 用c++接口提取特征后用SVM分类
- 用Caffe提取深度特征
- 深度学习Caffe实战笔记(10)Windows Caffe使用MATLAB接口提取和可视化特征
- 关于caffe的序列 :用Caffe提取深度特征
- 深度学习之-caffe预测、特征可视化python接口调用 (6)
- 毕业设计(一)——基于深度学习的一类图像共性特征提取 (caffe)
- caffe学习笔记之特征提取(win10)
- 提取HOG特征训练SVM分类器(一)HOG篇
- Deep Learning(深度学习) caffe模型 特征提取 (windows/linux)
- Caffe图片特征提取(Python/C++)
- caffe提取特征用svm进行分类
- 深度学习之---Caffe(一)
- caffe c++API特征提取
- caffe的python接口学习(11):特征的批量提取
- caffe 提取可视化特征遇到keyerror(即用matlab显示提取特征)
- caffe初探之-特征提取
- 深度学习(九)caffe预测、特征可视化python接口调用
- 深度学习(九)caffe预测、特征可视化python接口调用
- linux下给目录下所有子目录和文件赋权
- 【js设计模式笔记---装饰者模式】
- Linux 如何判断哪个网卡是否连接网线
- 写一个读取环境变量的Express中间件
- 苹果充值常见的刷单手段和防范方法
- 深度学习之Caffe(一) 用c++接口提取特征后用SVM分类
- Linux无法挂在ntfs格式设备
- <HeadFirst_HTML与CSS> O'REILLY_Chap.2_认识HTML中的"HT"
- linux虚拟机加载U盘
- ruhe解决秒杀的性能问题和超卖的讨论
- windows根据网络时间同步
- request.getParameter的值为空 分析
- 【js设计模式笔记---享元模式】
- git如何添加空文件夹