使用caffe训练一个多标签分类/回归模型

来源:互联网 发布:剑雨江湖数据晋级 编辑:程序博客网 时间:2024/06/06 00:54

深度学习交流QQ群:116270156


前言

这篇博客和上一篇性质差不多,都是旨在说明使用caffe训练图像分类模型的大概流程。不同的是,上篇博客讲的单标签图像分类问题,顾名思义,其

输入和输出都是单标签或者可以说是单类别的,而此篇则把重点放在如何处理多标签分类/回归问题的输入和输出上。多标签分类/回归问题和单标签的工作流程比较类似,大致分为以下几个步骤,然后本博客再对各个环节做进一步解释。



准备数据(图像整理好放到合适的文件夹中,对应的ground-truth整理到一个txt中)。
编写用于模型输入的代码(使用python data layer来读取图片和label)。
网络模型的定义/编辑(从头编写或修改prototxt文件)。
Solver的定义/编辑(solver是caffe用来训练网络的一个类似配置文件的东西)。
进行网络模型的训练与测试。
可以看出,此博客和上一篇在流程上的区别主要集中在第二步。博客写作时,caffe还不支持将含有多标签ground-truth的txt文件直接转换成LMDB文件。虽然仍然可以使用LMDB的格式,但是需要分别将图片和label放到两个LMDB文件中,操作略复杂,而caffe官网上比较推荐的一种方式是使用python data layer来直接读取图片和label,故本文也采用了这种形式。



准备数据


这一步先按一定的比例、一定的划分策略将数据集划分为训练集、验证集(根据具体情况可有可无)和测试集,如果要在使用caffe前手动地进行数据预处理或数据增强,尽量在数据集划分之前完成。数据划分好之后,对应的还要有标有ground-truth的txt文件,大致如下图所示:




文件的每一行代表一个样本,一行中第一项是图片的名称(名称前还可以有路径),后面是类别标签,具体的任务包含几类图像就对应着有多少个标签,每个标签都是非0即1的,0代表当前图片不属于这个标签,1则相反。图片名称以及各个标签之间以空格间隔。


编写用于模型输入的代码


这一步是这篇博客的重点。caffe自带的例子中也包含了使用python data layer来处理多标签输入的情况。我们更改一下例子给出的源文件大致即可实现自己的需求。首先,找到”caffe_root”/examples/pycaffe/layers/pascal_multilabel_datalayers.py,这个文件就是用于处理多标签输入的代码。此例子是关于PASCAL VOC 2012数据集上的多标签分类任务的,样本情况大致如下所示: 



对应的ground-truth里除了有物体类别标签,还有各物体在图片中对应的bounding box,如果读者的任务和该例子相似,可以仔细研究下这份代码。本人的工作虽然也是处理多标签问题,但是是研究天气识别的,和这个例子略有不同,ground-truth中只有类别标签而已,所以对代码做了一定的修改,其实是变简单了。下面来看一下代码,先贴出全部代码,然后再主要讲一些需要修改的部分。


[python] view plain copy
  1. import scipy.misc  
  2. import caffe  
  3.   
  4. import numpy as np  
  5. import os.path as osp  
  6.   
  7. from random import shuffle  
  8. from PIL import Image  
  9.   
  10. from tools import SimpleTransformer  
  11.   
  12.   
  13. class WeatherMultilabelDataLayerSync(caffe.Layer):  
  14.   
  15.     """ 
  16.     This is a simple syncronous datalayer for training a multilabel model on 
  17.     weather dataset. 
  18.     """  
  19.   
  20.     def setup(self, bottom, top):  
  21.   
  22.         self.top_names = ['data''label']  
  23.   
  24.         # === Read input parameters ===  
  25.   
  26.         # params is a python dictionary with layer parameters.  
  27.         params = eval(self.param_str)  
  28.   
  29.         # Check the paramameters for validity.  
  30.         check_params(params)  
  31.   
  32.         # store input as class variables  
  33.         self.batch_size = params['batch_size']  
  34.   
  35.         # Create a batch loader to load the images.  
  36.         self.batch_loader = BatchLoader(params, None)  
  37.   
  38.         # === reshape tops ===  
  39.         # since we use a fixed input image size, we can shape the data layer  
  40.         # once. Else, we'd have to do it in the reshape call.  
  41.         top[0].reshape(  
  42.             self.batch_size, 3, params['im_shape'][0], params['im_shape'][1])  
  43.         top[1].reshape(self.batch_size, 7)  
  44.   
  45.         print_info("WeatherMultilabelDataLayerSync", params)  
  46.   
  47.     def forward(self, bottom, top):  
  48.         """ 
  49.         Load data. 
  50.         """  
  51.         for itt in range(self.batch_size):  
  52.             # Use the batch loader to load the next image.  
  53.             im, multilabel = self.batch_loader.load_next_image()  
  54.             # Add directly to the caffe data layer  
  55.             top[0].data[itt, ...] = im  
  56.             top[1].data[itt, ...] = multilabel  
  57.   
  58.     def reshape(self, bottom, top):  
  59.         """ 
  60.         There is no need to reshape the data, since the input is of fixed size 
  61.         (rows and columns) 
  62.         """  
  63.         pass  
  64.   
  65.     def backward(self, top, propagate_down, bottom):  
  66.         """ 
  67.         These layers does not back propagate 
  68.         """  
  69.         pass  
  70.   
  71.   
  72. class BatchLoader(object):  
  73.   
  74.     """ 
  75.     This class abstracts away the loading of images. 
  76.     Images can either be loaded singly, or in a batch. The latter is used for 
  77.     the asyncronous data layer to preload batches while other processing is 
  78.     performed. 
  79.     """  
  80.   
  81.     def __init__(self, params, result):  
  82.         self.result = result  
  83.         self.batch_size = params['batch_size']  
  84.         self.weather_root = params['weather_root']  
  85.         self.im_shape = params['im_shape']  
  86.         if params['split'] == 'train':  
  87.             self.isshuffle = True  
  88.             self.iscenter = False  
  89.             self.isflip = True  
  90.         else:  
  91.             self.isshuffle = False  
  92.             self.iscenter = True  
  93.             self.isflip = False  
  94.         # get list of image indexes.  
  95.         self.floder = 'weather_' + params['split'] + '/'  #data folder   
  96.         list_file = params['split'] + '.txt'  
  97.         self.indexlist = [line.rstrip('\n'for line in open(  
  98.             osp.join(self.weather_root, list_file))]  
  99.             #osp.join(self.weather_root, 'MultiLabel/', list_file))]  
  100.         self._cur = 0  # current image  
  101.         # this class does some simple data-manipulations  
  102.         self.transformer = SimpleTransformer(center=self.iscenter)  
  103.         # At the beginning, shuffle the train data  
  104.         if self.isshuffle:  
  105.             shuffle(self.indexlist)  
  106.   
  107.         print "BatchLoader initialized with {} images".format(  
  108.             len(self.indexlist))  
  109.   
  110.     def load_next_image(self):  
  111.         """ 
  112.         Load the next image in a batch. 
  113.         """  
  114.         # Did we finish an epoch?  
  115.         if self._cur == len(self.indexlist):  
  116.             self._cur = 0  
  117.             if self.isshuffle:  
  118.                 shuffle(self.indexlist)  
  119.   
  120.         # Load an image  
  121.         rowdata = self.indexlist[self._cur]  # Get the rowdata  
  122.         rowdata = rowdata.split()  
  123.         image_file_name = rowdata[0]  
  124.         im = np.asarray(Image.open(  
  125.             osp.join(self.weather_root, self.floder, image_file_name)))  
  126.         im = scipy.misc.imresize(im, [256256])  # resize  
  127.   
  128.         # do a simple horizontal flip as data augmentation  
  129.         if self.isflip:  
  130.             flip = np.random.choice(2)*2-1  
  131.             im = im[:, ::flip, :]  
  132.   
  133.         # Load and prepare ground truth  
  134.         multilabel = np.zeros(7).astype(np.float64)  
  135.         for j in range(7):  
  136.             multilabel[j] = np.float64(rowdata[j+1])  
  137.   
  138.         self._cur += 1  
  139.         return self.transformer.preprocess(im), multilabel  
  140.   
  141.   
  142.   
  143. def check_params(params):  
  144.     """ 
  145.     A utility function to check the parameters for the data layers. 
  146.     """  
  147.     assert 'split' in params.keys(  
  148.     ), 'Params must include split (train, val, or test).'  
  149.   
  150.     required = ['batch_size''weather_root''im_shape']  
  151.     for r in required:  
  152.         assert r in params.keys(), 'Params must include {}'.format(r)  
  153.   
  154.   
  155. def print_info(name, params):  
  156.     """ 
  157.     Ouput some info regarding the class 
  158.     """  
  159.     print "{} initialized for split: {}, with bs: {}, im_shape: {}.".format(  
  160.         name,  
  161.         params['split'],  
  162.         params['batch_size'],  
  163.         params['im_shape'])  


代码修改


在setup函数中,有一个地方需要注意:

[python] view plain copy
  1. top[0].reshape(self.batch_size, 3, params['im_shape'][0],params['im_shape'][1])  
  2. # Note the 20 channels (because PASCAL has 20 classes.)  
  3. top[1].reshape(self.batch_size, 20)  

这是例子中的代码,top[0]中存储的是batch_size幅图像,而top[1]中则存储对应的标签。注释中也说了,因为这个PASCAL数据集中包含20类,所以就有了第二句代码的由来。而我的任务中有7类,所以top[1]这个地方就要改一改。

[python] view plain copy
  1. top[1].reshape(self.batch_size, 7)  


另外,紧随top[1]这句代码之后的是一句提示信息,最好也改成和自己任务对应的输出:

[python] view plain copy
  1. print_info("PascalMultilabelDataLayerSync", params)  


[python] view plain copy
  1. print_info("WeatherMultilabelDataLayerSync", params)  

接下来,forward函数不用修改,而reshape与backward函数本身就没有函数体,也不用处理。


然后,BatchLoader这个类中有一些需要更改的地方,首先在 __init__ 函数中,数据的根目录变量最好改一下:

[python] view plain copy
  1. self.weather_root = params['weather_root']  

这个变量和字符串虽不代表真正的路径,但还是改为和自己任务相关的名称比较好。至于具体的路径是在网络结构定义的prototxt文件中指定的,这个在下一个大步骤中还会再次提到。


往下看是本人为了达到一定目的而加的一小段代码:


[python] view plain copy
  1. if params['split'] == 'train':  
  2.     self.isshuffle = True  
  3.     self.iscenter = False  
  4.     self.isflip = True  
  5. else:  
  6.     self.isshuffle = False  
  7.     self.iscenter = True  
  8.     self.isflip = False  

由于我的数据之前并没有手动地进行shuffle,所以这里针对训练集要处理一下。另外,为了进行一定的数据增强,如果是训练集,就进行随机地裁剪而不是裁图片的正中央,并且也会实施随机地镜像翻转。而对于测试集和验证集来说,这些条目则恰好相反。代码中变量的名字和意义都非常好理解,就不再赘述了。


我的训练、验证和测试用的数据是用三个文件夹分别存放的,所以又加了一个变量来准确定位到不同的数据:

self.floder = 'weather_' + params['split'] + '/' #data folder


再往下是打开存有label信息的txt文件的代码,这里的路径需要改一下:

[python] view plain copy
  1. self.indexlist = [line.rstrip('\n'for line in     open(osp.join(self.weather_root, list_file))]  

注意改成自己的目录变量名。

[python] view plain copy
  1. self.transformer = SimpleTransformer(center=self.iscenter)  

这句是创建一个用于图像变换的对象,而对应的类是在其他文件中定义的,为了实现自己的图像变换,我对这个类也做了修改,具体内容一会就会讲到。


前面提到要对数据进行shuffle,以下两句就是具体代码,非常简单,调一下shuffle函数就OK了。

[python] view plain copy
  1. if self.isshuffle:  
  2.     shuffle(self.indexlist)  

再来看load_next_iamge函数,一上来先是对遍历数据时索引的处理:


[python] view plain copy
  1. if self._cur == len(self.indexlist):  
  2.     self._cur = 0  
  3.     if self.isshuffle:  
  4.         shuffle(self.indexlist)  

当索引已经到了txt文件的末尾时,则把它重新置零,以便进行新一轮的迭代。除此之外,还要再做一步操作,就是对训练集进行了再一次的shuffle,避免每次迭代都是一样的数据顺序。


接着是实际读取数据的代码:

[python] view plain copy
  1. rowdata = self.indexlist[self._cur]  # Get the rowdata  
  2. rowdata = rowdata.split()  
  3. image_file_name = rowdata[0]  
  4. im = np.asarray(Image.open(  
  5.     osp.join(self.weather_root, self.floder, image_file_name)))  
  6. im = scipy.misc.imresize(im, [256256])  

先是读取txt文件的一行,然后就按空格对读取的数据进行分割。分割后的第一个数据就是图像名字,接着结合数据的根目录以及数据集的文件夹把图像读到Numpy的一个array中。由于图片大小不一,还要对所有图片统一先resize一下,resize的大小是本工作所用到的AlexNet约定俗成的参数。


然后就是为了数据增强而进行的随机翻转:


[python] view plain copy
  1. if self.isflip:  
  2.     flip = np.random.choice(2)*2-1  
  3.     im = im[:, ::flip, :]  

最后是读取label的代码,顺便对调用__init__ 函数中创建的transformer对象对图像进行一定的变换:

[python] view plain copy
  1. multilabel = np.zeros(7).astype(np.float64)  
  2. multilabel[:] = rowdata[1:]  
  3. self._cur += 1  
  4. return self.transformer.preprocess(im), multilabel  




该python文件中最后两个check_params和print_info函数都不用做任何修改。


对图像转换函数的修改


之后再来看一下定义图像变换函数的相关代码,找到”caffe_root”/examples/pycaffe/tools.py文件,同样先贴出全部代码,其实该文件中还有一个用于定义solver的类,这里就不讲了,solver文件可以直接对文件修改,没必要代码生成。

[python] view plain copy
  1. import numpy as np  
  2.   
  3. class SimpleTransformer:  
  4.   
  5.     """ 
  6.     SimpleTransformer is a simple class for preprocessing and deprocessing 
  7.     images for caffe. 
  8.     """  
  9.   
  10.     def __init__(self, center=False):                                                                                                                                                                                                                                                       
  11.         mean=                                                                                                     
  12.         np.load('/root/wangzhignag/WeatherRecognition/Modified/MultiLabel/Data/  
  13.         mean.npy')  
  14.         center_mean =       
  15.         np.load('/root/wangzhignag/WeatherRecognition/Modified/MultiLabel/Data/  
  16.         center_mean.npy')  
  17.         self.mean = mean.transpose((1,2,0))  
  18.         self.center_mean = center_mean.transpose((1,2,0))  
  19.         self.scale = 1.0  
  20.         self.center=center  
  21.   
  22.     def set_mean(self, mean):  
  23.         """ 
  24.         Set the mean to subtract for centering the data. 
  25.         """  
  26.         self.mean = mean  
  27.   
  28.     def set_scale(self, scale):  
  29.         """ 
  30.         Set the data scaling. 
  31.         """  
  32.         self.scale = scale  
  33.   
  34.     def crop(self, im, cropx=227, cropy=227):  
  35.         ## the image has the HWC channel order  
  36.         y,x,_ = im.shape  
  37.   
  38.         if self.center:  
  39.             startx = x//2-(cropx//2)  
  40.             starty = y//2-(cropy//2)      
  41.         else:  
  42.             startx, starty = np.random.randint(0,29), np.random.randint(0,29)  
  43.   
  44.         return im[starty:starty+cropy,startx:startx+cropx,:]   
  45.   
  46.   
  47.   
  48.     def preprocess(self, im):  
  49.         """ 
  50.         preprocess() emulate the pre-processing occuring in the vgg16 caffe 
  51.         prototxt. 
  52.         """  
  53.   
  54.         im = np.float64(im)  
  55.         im = im[:, :, ::-1]  # change to BGR  
  56.         im -= self.mean  
  57.         im = self.crop(im)  
  58.         im *= self.scale  
  59.         im = im.transpose((201))  
  60.   
  61.         return im  
  62.   
  63.     def deprocess(self, im):  
  64.         """ 
  65.         inverse of preprocess() 
  66.         """  
  67.         im = im.transpose(120)  
  68.         im /= self.scale  
  69.         im += self.center_mean  
  70.         im = im[:, :, ::-1]  # change to RGB  
  71.   
  72.         return np.uint8(im)  


使用caffe训练网络时,数据的格式跟普通RGB图像的通道顺序不一样,所以需要我们手动地更改一下。而在上一篇讲单标签分类任务的博客中,数据是用caffe提供的工具转换成了LMDB文件,其中就包含了这些操作。在tools.py文件中,最主要的是preprocess函数,它包含了对图像所有的变换:

[python] view plain copy
  1. im = im[:, :, ::-1]  # change to BGR  
  2. im -= self.mean  
  3. im = self.crop(im)  
  4. im *= self.scale  
  5. im = im.transpose((201))  


具体的,先转换成BGR的通道顺序,再减去训练集上计算出的均值,接着对图像进行裁剪,得到227*227的用于AlexNet输入的图片大小。乘scale这一句在本工作中其实没用到,因为scale的大小设为了1,可能对其他一些任务,scale可以设成不同的值。最后是对图像进行一下转置变换,也是为了应对caffe的要求。其中的均值mean是在__init__ 函数中设定好的:


[python] view plain copy
  1. mean = np.load('/root/wangzhignag/WeatherRecognition/Modified/MultiLabel/Data/mean.npy')  


需要注意的是,上篇博客中,利用caffe提供的脚本生成的均值文件是mean.binaryproto文件,这里要用另一段代码将其转换成npy文件,以便进行Numpy数组间的计算,转换代码如下:


[python] view plain copy
  1. def convert(binaryproto='/home/wzg/caffe-master/data/weather/weather_mean.binaryproto',savepath='/home/wzg/caffe-master/data/weather/mean'):  
  2.     blob = caffe.proto.caffe_pb2.BlobProto()  
  3.     data = open(binaryproto, 'rb' ).read()  
  4.     blob.ParseFromString(data)  
  5.     arr = np.array(caffe.io.blobproto_to_array(blob))  
  6.     out = arr[0]  
  7.     np.save(savepath, out)  


其实也是利用了caffe提供的工具对mean.binaryproto文件进行了解析并将其转换成Numpy的array。


对图像的裁剪定义在crop函数中,操作还是比较简单的,就不详细解释了,只是相对于源代码来说,我特地区分了训练集和测试集的裁剪方式,一个是在图像内随机裁剪,一个是在中心裁剪。


在__init__ 函数中,除了读取mean.npy文件,我还读取了一个center_mean.npy文件,这个center_mean是为在deprocess函数中做图像反变换而生成的。反变换时,图像要加均值,但是网络中的数据已经是227*227大小了(反变换并不会把图像再变回256*256,因为网络中输入的就是227*227的图像),所以对应的均值矩阵的长宽也应该是227*227,这就需要对生成的mean.npy用crop函数做一下中心裁剪(由于训练集是随机裁剪的,并不能恢复回来,这里只对测试数据进行反变换),然后保存起来。另外读取均值文件后初始化时也要对其转置一下。


deprocess函数是为了在测试时方便显示出测试图像和对应的识别结果,达到一种直观的目的。如不需要这样的效果,完全可以不用定义这个函数,也不用再额外生成center_mean文件。


到此为止,关于网络输入的代码部分就讲完了,然后就是对网络结构定义文件的修改了。



网络模型的定义/编辑



相比于单标签分类任务,多标签的网络结构定义文件主要有两个地方需要更改,一个就是数据输入层:

[python] view plain copy
  1. layer {  
  2.   name: "data"  
  3.   type: "Python"  
  4.   top: "data"  
  5.   top: "label"  
  6.   python_param {  
  7.     module: "weather_multilabel_datalayers"  
  8.     layer: "WeatherMultilabelDataLayerSync"  
  9.     param_str: "{\'im_shape\': [227, 227], \'split\': \'train\',   
  10.     \'batch_size\': 50, \'weather_root\':   
  11.     \'/root/wangzhignag/WeatherRecognition/Modified/MultiLabel/Data/\'}"  
  12.   }  
  13. }  


使用python data layer读取数据的话,相较于LMDB文件,网络定义的数据层没有了transform_param和data_param,取而代之的是python_param。python_param中,module项对应的是处理多标签输入的文件的名字(这里我的文件名为weather_multilabel_datalayers.py),layer项对应的是文件中的类名(class WeatherMultilabelDataLayerSync(caffe.Layer)),可以看到它是继承了caffe.Layer这个基类的。要使用这个文件以及之前提到的用于图像变换的tools.py文件还需要在自己用于训练网络的python文件中把相关路径加到系统路径中去:


[python] view plain copy
  1. import sys  
  2. sys.path.append(caffe_root+"examples/pycaffe")  
  3. sys.path.append(caffe_root+"examples/pycaffe/layers")   
  4. import tools  


数据层中接下来的param_str中定义了一些参数,这些参数会传给WeatherMultilabelDataLayerSync做处理,参数内容都很好理解。其中split就是为区分不同数据集加的一个变量,weather_root是数据存放的根目录,其下有含label信息的txt文件和装有图片数据的文件夹,情况如下: 



这里写图片描述。读者完全可以根据自己的需求组织文件结构,不过在处理输入的文件中就需要做一定的更改了。另外,在数据层的定义中要特别注意其中含有很多反斜杠,应该是为了转义吧。


除了手动更改定义网络的prototxt文件,也可以写代码生成此文件,这里就不讲了,感兴趣者可以参考caffe官网给出的关于多标签分类的例子。




第二个要改的地方是loss层,单标签分类任务中,比较常用的loss函数是softmax loss function,而softmax loss的目的是从多类中选一类,正好适合单标签问题。对于多标签分类任务来说,SigmoidCrossEntropyLoss是比较合适的,因为它能独立地对多个标签进行预测,并且这个loss函数也是caffe已经实现好的。网络定义中的loss层如下:

[python] view plain copy
  1. layer {  
  2.   name: "loss"  
  3.   type: "SigmoidCrossEntropyLoss"  
  4.   bottom: "score"  
  5.   bottom: "label"  
  6.   top: "loss"  
  7. }  


score是输出层的名字,label是输入层两个名字中的一个,另一个是data。这都比较好理解,loss函数就是将网络输出和ground-truth标签做比较,然后计算loss。


我在网络定义中,把accuracy也删除了,多标签的accuracy和单标签是不一样的,之后可以单独写代码来进行统计。我在这一大步骤中讲的主要是针对单标签分类任务的修改,而对于输出单元个数以及其他参数的说明可以参考我的上一篇博客。



Solver的定义/编辑


其实,多标签的solver和单标签的并没有什么区别,同样参考本人上篇博客就能有一个大概的了解,这里还是把相应的solver文件贴一下:

[python] view plain copy
  1. base_lr: 0.001  
  2. display: 100  
  3. gamma: 0.1  
  4. iter_size: 1  
  5. lr_policy: "step"  
  6. stepsize: 5000  
  7. max_iter: 10000  
  8. momentum: 0.9  
  9. snapshot: 5000  
  10. snapshot_prefix: "/root/wangzhignag/WeatherRecognition/Modified/MultiLabel/Snapshot/caffe_weather_train"  
  11. test_interval: 200  
  12. test_iter: 20  
  13. test_net: "/root/wangzhignag/WeatherRecognition/Modified/MultiLabel/Files/valnet.prototxt"  
  14. train_net: "/root/wangzhignag/WeatherRecognition/Modified/MultiLabel/Files/trainnet.prototxt"  
  15. weight_decay: 0.0005  
  16. solver_mode: GPU  

主要就是把train_net、test_net定义文件的路径改一下,snapshot_prefix改一下。不特别指定优化算法,默认就是随机梯度下降方法。如果采用其他优化算法的话,一定注意会有一些特定的参数需要指定,请查阅其他资料进行了解。


网络模型的训练与测试


先来看两个跟计算准确率相关的函数:

[python] view plain copy
  1. def hamming_distance(gt, est):  
  2.     return sum([1 for (g, e) in zip(gt, est) if g == e]) / float(len(gt))  

这个函数虽然叫hamming_distance,但是和真正的海明距离是不一样的。从函数体中可以看出,这个函数就是计算了两个向量中,各个位置标签相同的次数占向量总长的一个比例。这个函数是为后面真正计算准确率做服务的。

[python] view plain copy
  1. def check_accuracy(net, num_batches, batch_size = 50):  
  2.     acc = 0.0  
  3.     for t in range(num_batches):  
  4.         net.forward()  
  5.         gts = net.blobs['label'].data  
  6.         ests = net.blobs['score'].data > 0  
  7.         for gt, est in zip(gts, ests): #for each ground truth and   
  8.             estimated label vector  
  9.             acc += hamming_distance(gt, est)  
  10.     return acc / (num_batches * batch_size)  


从函数名就可以看出,此函数是计算分类准确率的。方法也很简单直观,就是让网络模型在测试集上跑一遍,然后不断累加之前定义的hamming_distance,最后求一个平均就OK了。




最后就是训练网络的run_solver函数了,其实这个函数和上篇博客中单标签任务的run_solver十分相似,这里就简单地贴一下代码:

[python] view plain copy
  1. def run_solvers(niter, solver, disp_interval=100, test_interval=200,                  
  2.     test_iter=20):  
  3.     fig1,ax1=plt.subplots()  #used for draw  
  4.     fig2,ax2=plt.subplots()  #used for draw  
  5.   
  6.     train_loss=np.zeros(np.ceil(niter*1.0/disp_interval))  
  7.     test_acc=np.zeros(np.ceil(niter*1.0/test_interval))  
  8.     atom_train_loss = 0  
  9.     train_count, test_count = 00  
  10.     for it in range(1, niter+1):  
  11.         solver.step(1)  
  12.         atom_train_loss += solver.net.blobs['loss'].data  
  13.   
  14.         if it % disp_interval == 0:  
  15.             train_loss[train_count] = atom_train_loss/disp_interval  
  16.             atom_train_loss=0  
  17.             print '\n##########%d iteration train: loss=%.3f\n' %(it,   
  18.             train_loss[train_count])   
  19.             train_count += 1  
  20.   
  21.   
  22.         if it % test_interval == 0:  
  23.             test_acc[test_count] = check_accuracy(solver.test_nets[0],   
  24.             num_batches=20)  
  25.   
  26.             print '##########%d iteration Test: accuracy=%.3f\n' %(it,   
  27.             test_acc[test_count])   
  28.             test_count += 1  
  29.   
  30.             ################## Draw  
  31.             ax1.cla()  
  32.             ax1.set_title('Display Loss')  
  33.             ax1.set_xlabel('Iteration/100')  
  34.             ax1.set_ylabel('Loss')  
  35.             ax1.set_xlim(0,100)  
  36.             ax1.grid()  
  37.             ax1.plot(train_loss[:train_count],'r',label='train loss')  
  38.             ax1.legend(loc='best')  
  39.   
  40.             ax2.cla()  
  41.             ax2.set_title('Display Accuracy')  
  42.             ax2.set_xlabel('Iteration/100')  
  43.             ax2.set_ylabel('Accuracy')  
  44.             ax2.set_xlim(0,100)  
  45.             ax2.grid()                                                                                                                                                                                                           
  46.             ax2.plot(range(0,test_count*2,2),  
  47.                 test_acc[:test_count],'g',label='test accuracy')                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      
  48.             ax2.legend(loc='best')  
  49.             plt.pause(1)  
  50.   
  51.   
  52.     return train_loss, test_acc  


其实训练还是依靠solver的step()函数进行前向传播计算loss和后向传播计算梯度,然后再进行参数更新。不同于上篇博客里的run_solver函数,这里在测试阶段用到了刚刚定义的check_accuracy函数来计算测试集上的识别准确率。这里没有对测试集统计loss,也没有对训练集统计accuracy,因为整个跑一遍训练集的accuracy时间还是比较长的。不过,读者若是倾向于完备地输出这些数据,只要稍作修改也是完全可以实现的。






对于多标签分类任务,到此为止,基本上整个流程就讲完了,读者如果有不同的需求,可以在每个环节进行一定的修改。另外,多标签回归任务的处理其实和多标签分类任务非常像,流程都是一样的。只不过ground-truth不再是01标签,而是一系列连续的值,这个需要在第一步准备数据时就整理好。还有一点就是网络中的loss函数,此时SigmoidCrossEntropyLoss已经不能胜任多标签的回归任务,我们希望预测值和ground-truth越接近越好,EuclideanLoss是个合适的选择,它计算的是两个向量间的欧氏距离。最后,多标签回归问题也没有准确率而言了,EuclideanLoss可以直接作为评价标准,另外MSE(均方误差)也可以用来评测结果的好坏,实际上EuclideanLoss和MSE形式上也差不多,可以简单地认为它们就是一回事。



深度学习交流QQ群:116270156


阅读全文
0 0
原创粉丝点击