caffe-multilabel classification
Source: Internet · Published by: Wanfang Economic Statistics Database (万方经济统计数据库) · Editor: 程序博客网 · Time: 2024/05/16 23:53
Multilabel classification on PASCAL using python data-layers
Multilabel classification with PASCAL VOC 2012
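In multilabel classification, each image or instance can belong to several classes at once. In multiclass classification, by contrast, each image or instance has exactly one label.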
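To make the distinction concrete, here is a small NumPy sketch (not part of the original example) of how a multilabel target can be stored as a 20-dimensional 0/1 ("multi-hot") vector over the PASCAL classes and decoded back into class names. The helpers `encode_labels` and `decode_labels` are hypothetical names used only for illustration.

```python
import numpy as np

# The 20 PASCAL VOC classes, in the same order as the `classes` array used later.
CLASSES = np.asarray(['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
                      'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
                      'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'])

def encode_labels(names):
    """Hypothetical helper: map a collection of class names to a 20-dim multi-hot vector."""
    vec = np.zeros(len(CLASSES), dtype=np.float32)
    for name in names:
        vec[np.flatnonzero(CLASSES == name)[0]] = 1.0  # every present class gets a 1
    return vec

def decode_labels(vec):
    """Hypothetical helper: map a multi-hot vector back to class names (in class order)."""
    return [str(c) for c in CLASSES[np.flatnonzero(vec)]]

# Multilabel: one image may contain a person AND a dog AND a sofa.
multi = encode_labels(['person', 'dog', 'sofa'])
# Multiclass, for contrast: exactly one entry is 1 (one-hot).
single = encode_labels(['cat'])

print(decode_labels(multi))                 # ['dog', 'person', 'sofa']
print(int(multi.sum()), int(single.sum()))  # 3 1
```

The label blob produced by the data layer in this tutorial has one such row per image, which is why the network ends in 20 outputs rather than a single class index.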
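In Caffe, multilabel classification is implemented with the SigmoidCrossEntropyLoss layer. Here the data is loaded through a Python data layer. The data could also be provided as HDF5 or LMDB, but the Python data layer is more flexible, so that is what this example uses.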
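SigmoidCrossEntropyLoss suits this problem because it treats each of the 20 outputs as an independent yes/no decision: a sigmoid is applied to every score separately and a binary cross-entropy is summed over all classes, instead of a softmax that forces the classes to compete for a single label. Below is a minimal NumPy sketch of that computation in its numerically stable form. The function name is hypothetical, and dividing the sum by the batch size is an assumption about the layer's default normalization, not something stated in this article.

```python
import numpy as np

def sigmoid_cross_entropy(scores, labels):
    """Sketch of a sigmoid cross-entropy loss over independent labels.

    scores: (N, C) raw outputs (logits); labels: (N, C) 0/1 targets.
    Per element: -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))],
    computed stably as max(x, 0) - x*t + log(1 + exp(-|x|)).
    The total is divided by the batch size N (assumed normalization).
    """
    x = np.asarray(scores, dtype=np.float64)
    t = np.asarray(labels, dtype=np.float64)
    per_element = np.maximum(x, 0) - x * t + np.log1p(np.exp(-np.abs(x)))
    return per_element.sum() / x.shape[0]

# All scores 0 means sigmoid = 0.5 for every class, so each element costs log(2).
scores = np.zeros((2, 20))
labels = np.zeros((2, 20))
labels[0, 14] = 1          # image 0: person
labels[1, [11, 14]] = 1    # image 1: dog and person
print(sigmoid_cross_entropy(scores, labels))  # 20 * log(2), about 13.8629
```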
1 Preparation
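1. First make sure Caffe was compiled with `WITH_PYTHON_LAYER := 1` (set in Makefile.config).
2. Download the PASCAL VOC 2012 data: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html
3. Import the required modules (pycaffe, the data layers and the `tools` helpers):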
```python
import sys
import os

import numpy as np
import os.path as osp
import matplotlib.pyplot as plt

from copy import copy

# IPython/Jupyter magic; drop this line when running as a plain script.
%matplotlib inline
plt.rcParams['figure.figsize'] = (6, 6)

caffe_root = '../'  # this file is expected to be in {caffe_root}/examples
sys.path.append(caffe_root + 'python')
import caffe  # If you get "No module named _caffe", either you have not built pycaffe or you have the wrong path.

from caffe import layers as L, params as P  # Shortcuts to define the net prototxt.

sys.path.append("pycaffe/layers")  # the datalayers we will use are in this directory.
sys.path.append("pycaffe")  # the tools file is in this folder

import tools  # this contains some tools that we need
```
- Set the data path and initialize Caffe
```python
# set data root directory, e.g:
pascal_root = osp.join(caffe_root, 'data/pascal/VOC2012')

# these are the PASCAL classes, we'll need them later.
classes = np.asarray(['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
                      'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
                      'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'])

# make sure we have the caffenet weight downloaded.
if not os.path.isfile(caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'):
    print("Downloading pre-trained CaffeNet model...")
    # IPython shell escape (notebook only); from a terminal, run the command without "!".
    !../scripts/download_model_binary.py ../models/bvlc_reference_caffenet

# initialize caffe for gpu mode
caffe.set_mode_gpu()
caffe.set_device(0)
```
2 Define the network prototxt
The network is defined with caffe.NetSpec. The points to pay attention to are how the SigmoidCrossEntropyLoss layer is used and how the data layer is defined.
```python
# helper function for common structures
def conv_relu(bottom, ks, nout, stride=1, pad=0, group=1):
    conv = L.Convolution(bottom, kernel_size=ks, stride=stride,
                         num_output=nout, pad=pad, group=group)
    return conv, L.ReLU(conv, in_place=True)

# another helper function
def fc_relu(bottom, nout):
    fc = L.InnerProduct(bottom, num_output=nout)
    return fc, L.ReLU(fc, in_place=True)

# yet another helper function
def max_pool(bottom, ks, stride=1):
    return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=ks, stride=stride)

# main netspec wrapper
def caffenet_multilabel(data_layer_params, datalayer):
    # setup the python data layer: two tops, the image batch and the 20-dim label vectors
    n = caffe.NetSpec()
    n.data, n.label = L.Python(module='pascal_multilabel_datalayers', layer=datalayer,
                               ntop=2, param_str=str(data_layer_params))

    # the net itself
    n.conv1, n.relu1 = conv_relu(n.data, 11, 96, stride=4)
    n.pool1 = max_pool(n.relu1, 3, stride=2)
    n.norm1 = L.LRN(n.pool1, local_size=5, alpha=1e-4, beta=0.75)
    n.conv2, n.relu2 = conv_relu(n.norm1, 5, 256, pad=2, group=2)
    n.pool2 = max_pool(n.relu2, 3, stride=2)
    n.norm2 = L.LRN(n.pool2, local_size=5, alpha=1e-4, beta=0.75)
    n.conv3, n.relu3 = conv_relu(n.norm2, 3, 384, pad=1)
    n.conv4, n.relu4 = conv_relu(n.relu3, 3, 384, pad=1, group=2)
    n.conv5, n.relu5 = conv_relu(n.relu4, 3, 256, pad=1, group=2)
    n.pool5 = max_pool(n.relu5, 3, stride=2)
    n.fc6, n.relu6 = fc_relu(n.pool5, 4096)
    n.drop6 = L.Dropout(n.relu6, in_place=True)
    n.fc7, n.relu7 = fc_relu(n.drop6, 4096)
    n.drop7 = L.Dropout(n.relu7, in_place=True)
    # One output per PASCAL class. Named 'score' rather than 'fc8': copy_from() matches
    # layers by name, so the 1000-class ImageNet weights are not loaded into this layer.
    n.score = L.InnerProduct(n.drop7, num_output=20)
    # Independent sigmoid + binary cross-entropy for each of the 20 labels.
    n.loss = L.SigmoidCrossEntropyLoss(n.score, n.label)

    return str(n.to_proto())
```
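To see why the network expects 227×227 inputs and what reaches the fully connected layers, the spatial size after each convolution and pooling layer can be worked out by hand. The sketch below uses floor((W + 2P - K) / S) + 1 for convolution; Caffe's pooling layer rounds up instead of down, which is what `pool_out` models. The layer list mirrors `caffenet_multilabel` above (LRN, ReLU and dropout do not change the size); the helper names are hypothetical.

```python
def conv_out(w, k, s=1, p=0):
    # Convolution output size: floor((w + 2p - k) / s) + 1
    return (w + 2 * p - k) // s + 1

def pool_out(w, k, s=1, p=0):
    # Pooling output size, rounded up: ceil((w + 2p - k) / s) + 1
    return -(-(w + 2 * p - k) // s) + 1

# (name, kind, kernel, stride, pad), following caffenet_multilabel.
LAYERS = [
    ('conv1', 'conv', 11, 4, 0),
    ('pool1', 'pool', 3, 2, 0),
    ('conv2', 'conv', 5, 1, 2),
    ('pool2', 'pool', 3, 2, 0),
    ('conv3', 'conv', 3, 1, 1),
    ('conv4', 'conv', 3, 1, 1),
    ('conv5', 'conv', 3, 1, 1),
    ('pool5', 'pool', 3, 2, 0),
]

def trace_sizes(w=227):
    """Spatial width (= height here) after each layer for a square w x w input."""
    sizes = {}
    for name, kind, k, s, p in LAYERS:
        w = conv_out(w, k, s, p) if kind == 'conv' else pool_out(w, k, s, p)
        sizes[name] = w
    return sizes

sizes = trace_sizes(227)
print(sizes)
# pool5 has 256 channels, so fc6 receives 256 * 6 * 6 inputs.
print(256 * sizes['pool5'] ** 2)  # 9216
```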
3 Write the net and solver files
For the solver file we use the CaffeSolver class from the tools module.
```python
workdir = './pascal_multilabel_with_datalayer'
if not os.path.isdir(workdir):
    os.makedirs(workdir)

solverprototxt = tools.CaffeSolver(trainnet_prototxt_path=osp.join(workdir, "trainnet.prototxt"),
                                   testnet_prototxt_path=osp.join(workdir, "valnet.prototxt"))
solverprototxt.sp['display'] = "1"
solverprototxt.sp['base_lr'] = "0.0001"
solverprototxt.write(osp.join(workdir, 'solver.prototxt'))

# write train net.
with open(osp.join(workdir, 'trainnet.prototxt'), 'w') as f:
    # provide parameters to the data layer as a python dictionary. Easy as pie!
    data_layer_params = dict(batch_size=128, im_shape=[227, 227], split='train', pascal_root=pascal_root)
    f.write(caffenet_multilabel(data_layer_params, 'PascalMultilabelDataLayerSync'))

# write validation net.
with open(osp.join(workdir, 'valnet.prototxt'), 'w') as f:
    data_layer_params = dict(batch_size=128, im_shape=[227, 227], split='val', pascal_root=pascal_root)
    f.write(caffenet_multilabel(data_layer_params, 'PascalMultilabelDataLayerSync'))
```
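`tools.CaffeSolver` keeps the solver settings in its `sp` dictionary with string values (hence `"0.0001"` rather than `0.0001`) and writes them out to `solver.prototxt`. The sketch below is a hypothetical stand-in, `write_solver_text`, showing how such a dictionary maps onto the protobuf text format with one `key: value` line per field. It is not the actual `tools` implementation, and the example keys and paths are only illustrative. Note that string-typed fields such as net paths need their own double quotes inside the Python string.

```python
def write_solver_text(params):
    """Hypothetical helper: serialize a dict of solver fields to prototxt text.

    Values are kept as strings, so the caller decides how each one is rendered:
    numbers as-is, protobuf string fields wrapped in double quotes.
    """
    lines = []
    for key, value in sorted(params.items()):
        if not isinstance(value, str):
            raise TypeError('solver parameter %r must be given as a string' % key)
        lines.append('%s: %s' % (key, value))
    return '\n'.join(lines) + '\n'

sp = {
    'base_lr': '0.0001',
    'display': '1',
    'train_net': '"pascal_multilabel_with_datalayer/trainnet.prototxt"',  # string field: quoted
    'test_net': '"pascal_multilabel_with_datalayer/valnet.prototxt"',
}
text = write_solver_text(sp)
print(text)
```

Keeping every value as a string avoids having to guess, per field, whether Python's `str()` of a number or boolean matches what the protobuf text parser expects.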
PascalMultilabelDataLayerSync is the Python data layer used here; it is defined in `./pycaffe/layers/pascal_multilabel_datalayers.py`.
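Because the net is written with `param_str=str(data_layer_params)`, the data layer receives its settings as the text form of a Python dict and has to turn that back into a dict when it is set up. For the labels, a PASCAL VOC annotation XML lists one `<object>` element per instance, each with a `<name>`; for multilabel training only whether a class is present matters, not how many instances there are. The sketch below illustrates both steps with the standard library and NumPy. It is not the code in `pascal_multilabel_datalayers.py`: the function names are hypothetical, `ast.literal_eval` is simply a safe choice for parsing the dict string, and the XML snippet is a made-up, heavily trimmed example.

```python
import ast
import xml.etree.ElementTree as ET
import numpy as np

CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
           'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
           'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')

def parse_param_str(param_str):
    """Hypothetical helper: recover the dict that was passed in as str(dict)."""
    params = ast.literal_eval(param_str)  # only accepts Python literals, unlike eval()
    missing = [k for k in ('batch_size', 'im_shape', 'split', 'pascal_root') if k not in params]
    if missing:
        raise KeyError('missing data layer parameters: %s' % missing)
    return params

def annotation_to_multilabel(xml_text):
    """Hypothetical helper: 20-dim 0/1 vector of the classes present in a VOC annotation."""
    label = np.zeros(len(CLASSES), dtype=np.float32)
    root = ET.fromstring(xml_text)
    for obj in root.iter('object'):
        name = obj.find('name').text.strip().lower()
        label[CLASSES.index(name)] = 1.0  # presence only; repeated instances change nothing
    return label

# The same kind of dict as in the train-net cell, round-tripped through str().
params = parse_param_str(str(dict(batch_size=128, im_shape=[227, 227],
                                  split='train', pascal_root='data/pascal/VOC2012')))

# Made-up, trimmed annotation: two people and one horse.
xml_text = """<annotation>
  <object><name>person</name></object>
  <object><name>horse</name></object>
  <object><name>person</name></object>
</annotation>"""
label = annotation_to_multilabel(xml_text)
print(params['im_shape'], [CLASSES[i] for i in np.flatnonzero(label)])  # [227, 227] ['horse', 'person']
```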
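Next, load the Caffe solver, copy the pretrained CaffeNet weights into the training net, and let the test net share its weights with the training net: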
```python
solver = caffe.SGDSolver(osp.join(workdir, 'solver.prototxt'))
solver.net.copy_from(caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel')
solver.test_nets[0].share_with(solver.net)
solver.step(1)
```
```
BatchLoader initialized with 5717 images
PascalMultilabelDataLayerSync initialized for split: train, with bs: 128, im_shape: [227, 227].
BatchLoader initialized with 5823 images
PascalMultilabelDataLayerSync initialized for split: val, with bs: 128, im_shape: [227, 227].
```
Check one of the loaded training images together with its ground-truth labels:

```python
transformer = tools.SimpleTransformer()  # This is simply to add back the bias, re-shuffle the color channels to RGB, and so on...
image_index = 0  # First image in the batch.
plt.figure()
plt.imshow(transformer.deprocess(copy(solver.net.blobs['data'].data[image_index, ...])))
gtlist = solver.net.blobs['label'].data[image_index, ...].astype(int)  # np.int was removed in NumPy 1.24
plt.title('GT: {}'.format(classes[np.where(gtlist)]))
plt.axis('off')
```
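`transformer.deprocess` undoes the preprocessing so that a data blob can be displayed; per the comment above, it adds the mean back and reorders the color channels to RGB. The sketch below is an illustrative NumPy version of such a forward/inverse pair, assuming the blob is stored as C×H×W in BGR order with a per-channel mean subtracted. The mean of 128 per channel and the function names are assumptions made for the example, not necessarily what `tools.SimpleTransformer` uses.

```python
import numpy as np

MEAN = np.array([128.0, 128.0, 128.0])  # assumed per-channel mean (BGR order)

def preprocess(im_rgb):
    """Illustrative forward transform: H x W x 3 RGB uint8 -> 3 x H x W BGR float, mean-subtracted."""
    im = np.float32(im_rgb)[:, :, ::-1]  # RGB -> BGR
    im = im - MEAN                       # subtract the per-channel mean
    return im.transpose((2, 0, 1))       # H x W x C -> C x H x W

def deprocess(blob):
    """Illustrative inverse: 3 x H x W BGR float -> H x W x 3 RGB uint8 for plt.imshow."""
    im = blob.transpose((1, 2, 0))       # C x H x W -> H x W x C
    im = im + MEAN                       # add the mean back
    im = im[:, :, ::-1]                  # BGR -> RGB
    return np.uint8(np.clip(im, 0, 255))

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 5, 3)).astype(np.uint8)
blob = preprocess(image)
print(blob.shape, np.array_equal(deprocess(blob), image))  # (3, 4, 5) True
```

The `copy(...)` in the notebook cell matters for the same reason the sketch builds new arrays: the deprocessing should not modify the network's data blob in place.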
4 Train the network
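A common accuracy measure for multilabel problems is based on the Hamming distance: the fraction of labels on which the prediction and the ground truth agree, i.e. one minus the normalized Hamming distance.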
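```python
def hamming_distance(gt, est):
    # Despite the name, this returns the fraction of label positions where gt and est
    # agree, i.e. 1 - (normalized Hamming distance), so higher is better.
    return sum([1 for (g, e) in zip(gt, est) if g == e]) / float(len(gt))

def check_accuracy(net, num_batches, batch_size=128):
    acc = 0.0
    for t in range(num_batches):
        net.forward()
        gts = net.blobs['label'].data
        ests = net.blobs['score'].data > 0  # score > 0 is the same as sigmoid(score) > 0.5
        for gt, est in zip(gts, ests):  # for each ground truth and estimated label vector
            acc += hamming_distance(gt, est)
    return acc / (num_batches * batch_size)
```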
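The per-vector loop in `check_accuracy` can also be written as one NumPy expression over a whole batch: compare the thresholded scores with the labels element-wise and average over the label dimension. This is just an equivalent reformulation for illustration (`batch_hamming_accuracy` is a hypothetical name), and the toy arrays stand in for the `label` and `score` blobs.

```python
import numpy as np

def hamming_distance(gt, est):
    # Same as above: fraction of label positions where gt and est agree.
    return sum([1 for (g, e) in zip(gt, est) if g == e]) / float(len(gt))

def batch_hamming_accuracy(gts, scores):
    """Per-image fraction of correctly predicted labels (prediction: score > 0)."""
    ests = scores > 0
    return np.mean(ests == gts.astype(bool), axis=1)  # one value per image

# Toy batch: 2 images, 4 labels each (the real blobs are 128 x 20).
gts = np.array([[1, 0, 0, 1], [0, 0, 1, 0]], dtype=np.float32)
scores = np.array([[2.5, -1.0, 0.3, 4.0], [-2.0, -0.5, 1.2, -3.0]], dtype=np.float32)

loop = [hamming_distance(g, e) for g, e in zip(gts, scores > 0)]
vec = batch_hamming_accuracy(gts, scores)
print(loop, vec.tolist())  # [0.75, 1.0] [0.75, 1.0]
```

Summing `vec` over all batches and dividing by the number of images gives the same number that `check_accuracy` returns, without the Python-level inner loop.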
```python
for itt in range(6):
    solver.step(100)
    print('itt:{:3d} accuracy:{:.4f}'.format((itt + 1) * 100, check_accuracy(solver.test_nets[0], 50)))
```
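5 Check the baseline accuracy
These accuracy numbers have to be read with care, because the ground truth is sparse: PASCAL has 20 classes and an image usually contains only one or two of them. Predicting all zeros therefore already gives a high Hamming accuracy, and the baseline below measures exactly that.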
```python
def check_baseline_accuracy(net, num_batches, batch_size=128):
    acc = 0.0
    for t in range(num_batches):
        net.forward()
        gts = net.blobs['label'].data
        ests = np.zeros_like(gts)  # all-zeros prediction: every class predicted "absent"
        for gt, est in zip(gts, ests):  # for each ground truth and estimated label vector
            acc += hamming_distance(gt, est)
    return acc / (num_batches * batch_size)

# 5823 validation images in batches of 128 -> 45 full batches
print('Baseline accuracy:{:.4f}'.format(check_baseline_accuracy(solver.test_nets[0], 5823 // 128)))
```
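The baseline can also be predicted in closed form: with an all-zeros prediction, the only mistakes are the labels that are actually 1, so the Hamming accuracy equals 1 minus the average fraction of positive labels. For example, if images had on average 1.5 of the 20 classes present, the baseline would be 1 - 1.5/20 = 0.925. The sketch below checks this identity on a small synthetic label matrix (random data, not PASCAL statistics); `all_zeros_accuracy` is a hypothetical name.

```python
import numpy as np

def all_zeros_accuracy(gts):
    """Hamming accuracy of predicting 'absent' for every class, computed directly."""
    return np.mean(gts == 0)

rng = np.random.default_rng(42)
# Synthetic sparse labels: 1000 images x 20 classes, each label present with p = 0.075.
gts = (rng.random((1000, 20)) < 0.075).astype(np.float32)

direct = all_zeros_accuracy(gts)
closed_form = 1.0 - gts.sum(axis=1).mean() / gts.shape[1]
print(direct, closed_form)  # the two agree up to floating-point rounding
```

A trained model is only doing useful work to the extent that its accuracy rises above this number.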
6 Look at some prediction results
```python
test_net = solver.test_nets[0]
for image_index in range(5):
    plt.figure()
    plt.imshow(transformer.deprocess(copy(test_net.blobs['data'].data[image_index, ...])))
    gtlist = test_net.blobs['label'].data[image_index, ...].astype(int)
    estlist = test_net.blobs['score'].data[image_index, ...] > 0
    plt.title('GT: {} \n EST: {}'.format(classes[np.where(gtlist)], classes[np.where(estlist)]))
    plt.axis('off')
```
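Both `check_accuracy` and this cell turn scores into predictions with `score > 0`. Since the loss applies a sigmoid to each score, this is the same as predicting a class whenever its sigmoid probability exceeds 0.5: sigmoid(0) = 0.5 and the sigmoid is strictly increasing. A short NumPy check of the equivalence with made-up scores follows; more generally, a probability cut-off p corresponds to a raw-score cut-off of log(p / (1 - p)). The function names are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def score_threshold(p):
    """Logit of p: the raw-score cut-off equivalent to a probability cut-off p."""
    return np.log(p / (1.0 - p))

scores = np.array([-3.0, -0.2, 0.1, 0.7, 5.0])
print(np.array_equal(scores > 0, sigmoid(scores) > 0.5))  # True
print(score_threshold(0.5))  # 0.0
# A stricter cut-off, e.g. p = 0.7, keeps only the more confident classes.
print(scores > score_threshold(0.7))
```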