MatConvNet: Training on Your Own Data (Using a Pre-trained Model)


1. Preparing the data

This post uses the Sort_1000 dataset: 10 classes with 100 images per class, each image 384x256 pixels.


Download link:

Link: http://pan.baidu.com/s/1miPulsO   Password: gt8g

Organize the data: take the first 80 images of each class as the training set and the last 20 as the validation set, then prepare trainLabel.txt and testLabel.txt in the following format (image file name followed by its class label):

300.jpg 4
301.jpg 4
302.jpg 4
303.jpg 4
304.jpg 4
305.jpg 4
306.jpg 4
307.jpg 4
308.jpg 4
309.jpg 4
310.jpg 4
311.jpg 4
312.jpg 4
313.jpg 4
314.jpg 4
315.jpg 4
316.jpg 4
317.jpg 4

classIndex.txt maps each numeric label to a class name:

1 face
2 sky
3 building
4 bus
5 dinosaur
6 elephant
7 flower
8 horse
9 snowmoutain
10 food

The overall dataset layout is as follows:

The train folder holds all training images and the test folder holds the validation images; a script sketch for producing this layout and the label files is given below.
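A minimal preparation sketch, assuming the downloaded archive keeps the original flat naming (0.jpg ... 999.jpg in a single folder, here hypothetically called all, with images 100*(k-1) to 100*k-1 belonging to class k); adjust the paths and naming to the actual archive:

% Split the images into train/test folders and write the label files
dataDir = fullfile('data', 'Sort_1000pics') ;
srcDir  = fullfile(dataDir, 'all') ;          % assumed location of the raw images
mkdir(fullfile(dataDir, 'train')) ;
mkdir(fullfile(dataDir, 'test')) ;
fTrain = fopen(fullfile(dataDir, 'trainLabel.txt'), 'w') ;
fTest  = fopen(fullfile(dataDir, 'testLabel.txt'),  'w') ;
for n = 0:999
  k = floor(n/100) + 1 ;                      % class label, 1..10
  name = sprintf('%d.jpg', n) ;
  if mod(n, 100) < 80                         % first 80 of each class -> train
    copyfile(fullfile(srcDir, name), fullfile(dataDir, 'train', name)) ;
    fprintf(fTrain, '%s %d\n', name, k) ;
  else                                        % last 20 of each class -> test
    copyfile(fullfile(srcDir, name), fullfile(dataDir, 'test', name)) ;
    fprintf(fTest, '%s %d\n', name, k) ;
  end
end
fclose(fTrain) ; fclose(fTest) ;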

2. Writing the functions

The main function, cnn_net.m:

function [net, info] = cnn_net(varargin)
%CNN_NET Fine-tune an ImageNet pre-trained CNN on a new dataset
 
run(fullfile(fileparts(mfilename('fullpath')), ...
   '..', 'matlab', 'vl_setupnn.m')) ;
% Paths to the input data and the experiment (output) folder
opts.dataDir = fullfile('data','Sort_1000pics') ;
opts.expDir  = fullfile('exp', 'Sort_1000pics') ;
% Load the pre-trained model
%opts.modelPath = fullfile('models','imagenet-vgg-f.mat');
opts.modelPath = '../models/imagenet-vgg-f.mat';
[opts, varargin] = vl_argparse(opts, varargin) ;
 
opts.numFetchThreads = 12 ;
 
opts.lite = false ;
opts.imdbPath = fullfile(opts.expDir, 'imdb.mat');
opts.modelType = 'vgg-f' ;
opts.network = [] ;
opts.networkType = 'simplenn' ;
opts.train = struct() ;
opts.train.gpus = [1];
opts.train.batchSize = 8 ;
opts.train.numSubBatches = 1 ;
opts.train.learningRate = 1e-4 * [ones(1,10), 0.1*ones(1,5)];
 
opts = vl_argparse(opts, varargin) ;
if ~isfield(opts.train, 'gpus'), opts.train.gpus = [1]; end;
 
% -------------------------------------------------------------------------
%                                                             Prepare model
% -------------------------------------------------------------------------
net = load(opts.modelPath);
% Adapt the model to the new task (see prepareDINet.m below)
net = prepareDINet(net,opts);
% -------------------------------------------------------------------------
%                                                              Prepare data
% -------------------------------------------------------------------------
% Build (or load) the imdb structure
if exist(opts.imdbPath,'file')
  imdb = load(opts.imdbPath) ;
else
  imdb = cnn_image_setup_data('dataDir', opts.dataDir, 'lite', opts.lite) ;
  mkdir(opts.expDir) ;
  save(opts.imdbPath, '-struct', 'imdb') ;
end
 
imdb.images.set = imdb.images.sets;
 
% Set the class names in the network
net.meta.classes.name = imdb.classes.name ;
net.meta.classes.description = imdb.classes.name ;
 
% Compute the mean image of the training set
imageStatsPath = fullfile(opts.expDir, 'imageStats.mat') ;
if exist(imageStatsPath, 'file')
  load(imageStatsPath, 'averageImage') ;
else
    averageImage = getImageStats(opts, net.meta, imdb) ;
    save(imageStatsPath, 'averageImage') ;
end
% Replace the model's average image with the newly computed one
net.meta.normalization.averageImage = averageImage;
% -------------------------------------------------------------------------
%                                                                     Learn
% -------------------------------------------------------------------------
% Indices of the training set (set==1) and the validation set (set==3)
opts.train.train = find(imdb.images.set==1) ;
opts.train.val = find(imdb.images.set==3) ;
% Train
[net, info] = cnn_train_dag(net, imdb, getBatchFn(opts, net.meta), ...
                      'expDir', opts.expDir, ...
                      opts.train) ;

% -------------------------------------------------------------------------
%                                                                    Deploy
% -------------------------------------------------------------------------
% Save the trained network in deployable form
net = cnn_imagenet_deploy(net) ;
modelPath = fullfile(opts.expDir, 'net-deployed.mat');
 
net_ = net.saveobj() ;
save(modelPath, '-struct', 'net_') ;
clear net_ ;
 
% -------------------------------------------------------------------------
function fn = getBatchFn(opts, meta)
% -------------------------------------------------------------------------
useGpu = numel(opts.train.gpus) > 0 ;
 
bopts.numThreads = opts.numFetchThreads ;
bopts.imageSize = meta.normalization.imageSize ;
bopts.border = meta.normalization.border ;
% bopts.averageImage = [];
bopts.averageImage = meta.normalization.averageImage ;
% bopts.rgbVariance = meta.augmentation.rgbVariance ;
% bopts.transformation = meta.augmentation.transformation ;
 
fn = @(x,y) getDagNNBatch(bopts,useGpu,x,y) ;
 
 
% -------------------------------------------------------------------------
function inputs = getDagNNBatch(opts, useGpu, imdb, batch)
% -------------------------------------------------------------------------
% Build each image path from the train or test folder depending on its set
for i = 1:length(batch)
    if imdb.images.set(batch(i)) == 1 % 1 = training set
        images(i) = strcat([imdb.imageDir.train filesep] , imdb.images.name(batch(i)));
    else
        images(i) = strcat([imdb.imageDir.test filesep] , imdb.images.name(batch(i)));
    end
end;
isVal = ~isempty(batch) && imdb.images.set(batch(1)) ~= 1 ;
 
if ~isVal
  % training
  im = cnn_imagenet_get_batch(images, opts, ...
                              'prefetch', nargout == 0) ;
else
  % validation: disable data augmentation
  im = cnn_imagenet_get_batch(images, opts, ...
                            'prefetch', nargout == 0, ...
                            'transformation', 'none') ;

end
 
if nargout > 0
  if useGpu
    im = gpuArray(im) ;
  end
  labels = imdb.images.label(batch) ;
  inputs = {'input', im, 'label', labels} ;
end
 
% Compute the mean of the training samples
% -------------------------------------------------------------------------
function averageImage = getImageStats(opts, meta, imdb)
% -------------------------------------------------------------------------
train = find(imdb.images.set == 1) ;
batch = 1:length(train);
fn = getBatchFn(opts, meta) ;
train = train(1: 100: end);
avg = {};
for i = 1:length(train)
    temp = fn(imdb, batch(train(i):train(i)+99)) ;
    temp = temp{2};
    avg{end+1} = mean(temp, 4) ;
end
 
averageImage = mean(cat(4,avg{:}),4) ;
% Gather from the GPU to the CPU before saving (if a GPU is used)
averageImage = gather(averageImage);

Note: the pre-trained model used here is imagenet-vgg-f, which must be downloaded separately and placed in the models folder.
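A minimal download sketch, assuming the model is still hosted at the usual MatConvNet pretrained-model URL (see http://www.vlfeat.org/matconvnet/pretrained/ if it has moved); urlwrite works on MATLAB 2014a, newer releases can use websave instead:

% Fetch imagenet-vgg-f.mat into the models folder (run once)
modelDir = fullfile('..', 'models') ;
if ~exist(modelDir, 'dir'), mkdir(modelDir) ; end
modelFile = fullfile(modelDir, 'imagenet-vgg-f.mat') ;
if ~exist(modelFile, 'file')
  urlwrite('http://www.vlfeat.org/matconvnet/models/imagenet-vgg-f.mat', modelFile) ;
end

Once the model and the data are in place, training is started with [net, info] = cnn_net() ;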

Loading the data: cnn_image_setup_data.m

function imdb = cnn_image_setup_data(varargin)
 
opts.dataDir = fullfile('data','Sort_1000pics') ;
opts.lite = false ;
opts = vl_argparse(opts, varargin) ;
 
% ------------------------------------------------------------------------
%                                                  Load categories metadata
% -------------------------------------------------------------------------
 
metaPath = fullfile(opts.dataDir, 'classIndex.txt') ;
 
fprintf('using metadata %s\n', metaPath) ;
tmp = importdata(metaPath);
nCls = numel(tmp);
% Check that the number of classes matches the expected count (10 here; change for your own dataset)
if nCls ~= 10
  error('Wrong meta file %s',metaPath);
end
% Split off the class names
cats = cell(1,nCls);
for i=1:numel(tmp)
  t = strsplit(tmp{i});
  cats{i} = t{2};
end
% Dataset folders
imdb.classes.name = cats ;
imdb.imageDir.train = fullfile(opts.dataDir, 'train') ;
imdb.imageDir.test = fullfile(opts.dataDir, 'test') ;
 
%% -----------------------------------------------------------------
%                                              load image names and labels
% -------------------------------------------------------------------------
 
name = {};
labels = {} ;
imdb.images.sets = [] ;
%%
fprintf('searching training images ...\n') ;
% Load the training labels
train_label_path = fullfile(opts.dataDir, 'trainLabel.txt') ;
train_label_temp = importdata(train_label_path);
temp_l = train_label_temp.data;
for i=1:numel(temp_l)
    train_label{i} = temp_l(i);
end
if length(train_label) ~= length(dir(fullfile(imdb.imageDir.train, '*.jpg')))
    error('the number of training images does not match the number of labels');
end
 
temp_n=train_label_temp.textdata;
for i=1:numel(temp_n)
    name{end+1} = temp_n{i};
    labels{end+1} = train_label{i} ;
    imdb.images.sets(end+1) = 1;
end
fprintf('searching testing images ...\n') ;
% Load the test labels
test_label_path = fullfile(opts.dataDir, 'testLabel.txt') ;
test_label_temp = importdata(test_label_path);
temp_l = test_label_temp.data;
for i=1:numel(temp_l)
    test_label{i} = temp_l(i);
end
if length(test_label) ~= length(dir(fullfile(imdb.imageDir.test, '*.jpg')))
    error('the number of test images does not match the number of labels');
end

temp_n=test_label_temp.textdata;
for i=1:numel(temp_n)
    name{end+1} = temp_n{i};
    labels{end+1} = test_label{i} ;
    imdb.images.sets(end+1) = 3;
end
labels = horzcat(labels{:}) ;
imdb.images.id = 1:numel(name) ;
imdb.images.name = name ;
imdb.images.label = labels ;

Batching images: cnn_imagenet_get_batch.m

function imo = cnn_imagenet_get_batch(images, varargin)
% CNN_IMAGENET_GET_BATCH  Load, preprocess, and pack images for CNN evaluation

opts.imageSize = [227, 227] ;
opts.border = [29, 29] ;
opts.keepAspect = true ;
opts.numAugments = 1 ;
opts.transformation = 'none' ;
opts.averageImage = [] ;
opts.rgbVariance = zeros(0,3,'single') ;
opts.interpolation = 'bilinear' ;
opts.numThreads = 1 ;
opts.prefetch = false ;
opts = vl_argparse(opts, varargin);

% fetch is true if images is a list of filenames (instead of
% a cell array of images)
fetch = numel(images) >= 1 && ischar(images{1}) ;

% prefetch is used to load images in a separate thread
prefetch = fetch & opts.prefetch ;

if prefetch
  vl_imreadjpeg(images, 'numThreads', opts.numThreads, 'prefetch') ;
  imo = [] ;
  return ;
end
if fetch
  im = vl_imreadjpeg(images,'numThreads', opts.numThreads) ;
else
  im = images ;
end

tfs = [] ;
switch opts.transformation
  case 'none'
    tfs = [
      .5 ;
      .5 ;
       0 ] ;
  case 'f5'
    tfs = [...
      .5 0 0 1 1 .5 0 0 1 1 ;
      .5 0 1 0 1 .5 0 1 0 1 ;
       0 0 0 0 0  1 1 1 1 1] ;
  case 'f25'
    [tx,ty] = meshgrid(linspace(0,1,5)) ;
    tfs = [tx(:)' ; ty(:)' ; zeros(1,numel(tx))] ;
    tfs_ = tfs ;
    tfs_(3,:) = 1 ;
    tfs = [tfs,tfs_] ;
  case 'stretch'
  otherwise
    error('Unknown transformation %s', opts.transformation) ;
end
[~,transformations] = sort(rand(size(tfs,2), numel(images)), 1) ;

if ~isempty(opts.rgbVariance) && isempty(opts.averageImage)
  opts.averageImage = zeros(1,1,3) ;
end
if numel(opts.averageImage) == 3
  opts.averageImage = reshape(opts.averageImage, 1,1,3) ;
end

imo = zeros(opts.imageSize(1), opts.imageSize(2), 3, ...
            numel(images)*opts.numAugments, 'single') ;

si = 1 ;
for i=1:numel(images)

  % acquire image
  if isempty(im{i})
    imt = imread(images{i}) ;
    imt = single(imt) ; % faster than im2single (and multiplies by 255)
  else
    imt = im{i} ;
  end
  if size(imt,3) == 1
    imt = cat(3, imt, imt, imt) ;
  end

  % resize
  w = size(imt,2) ;
  h = size(imt,1) ;
  factor = [(opts.imageSize(1)+opts.border(1))/h ...
            (opts.imageSize(2)+opts.border(2))/w];

  if opts.keepAspect
    factor = max(factor) ;
  end
  if any(abs(factor - 1) > 0.0001)
    imt = imresize(imt, ...
                   'scale', factor, ...
                   'method', opts.interpolation) ;
  end

  % crop & flip
  w = size(imt,2) ;
  h = size(imt,1) ;
  for ai = 1:opts.numAugments
    switch opts.transformation
      case 'stretch'
        sz = round(min(opts.imageSize(1:2)' .* (1-0.1+0.2*rand(2,1)), [h;w])) ;
        dx = randi(w - sz(2) + 1, 1) ;
        dy = randi(h - sz(1) + 1, 1) ;
        flip = rand > 0.5 ;
      otherwise
        tf = tfs(:, transformations(mod(ai-1, numel(transformations)) + 1)) ;
        sz = opts.imageSize(1:2) ;
        dx = floor((w - sz(2)) * tf(2)) + 1 ;
        dy = floor((h - sz(1)) * tf(1)) + 1 ;
        flip = tf(3) ;
    end
    sx = round(linspace(dx, sz(2)+dx-1, opts.imageSize(2))) ;
    sy = round(linspace(dy, sz(1)+dy-1, opts.imageSize(1))) ;
    if flip, sx = fliplr(sx) ; end

    if ~isempty(opts.averageImage)
      offset = opts.averageImage ;
      if ~isempty(opts.rgbVariance)
        offset = bsxfun(@plus, offset, reshape(opts.rgbVariance * randn(3,1), 1,1,3)) ;
      end
      imo(:,:,:,si) = bsxfun(@minus, imt(sy,sx,:), offset) ;
    else
      imo(:,:,:,si) = imt(sy,sx,:) ;
    end
    si = si + 1 ;
  end
end

Replacing the last layer of the network: prepareDINet.m

% -------------------------------------------------------------------------
function net = prepareDINet(net,opts)
% -------------------------------------------------------------------------
% replace fc8
fc8l = cellfun(@(a) strcmp(a.name, 'fc8'), net.layers)==1;
 
%% Note: nCls is the number of classes and must match your own dataset (10 here)
nCls = 10;
sizeW = size(net.layers{fc8l}.weights{1});
% Re-initialize the fc8 weights for the new number of classes
if sizeW(4)~=nCls
  net.layers{fc8l}.weights = {zeros(sizeW(1),sizeW(2),sizeW(3),nCls,'single'), ...
    zeros(1, nCls, 'single')};
end
 
% Change the loss: use a softmaxloss layer for training
net.layers{end} = struct('name','loss', 'type','softmaxloss') ;
 
% Convert to a DagNN; the error layers below are also needed for training
net = dagnn.DagNN.fromSimpleNN(net, 'canonicalNames', true) ;
 
net.addLayer('top1err', dagnn.Loss('loss', 'classerror'), ...
    {'prediction','label'}, 'top1err') ;
net.addLayer('top5err', dagnn.Loss('loss', 'topkerror', ...
    'opts', {'topK',5}), ...
    {'prediction','label'}, 'top5err') ;

Deploying the network: cnn_imagenet_deploy.m

function net = cnn_imagenet_deploy(net)
%CNN_IMAGENET_DEPLOY  Deploy a CNN

isDag = isa(net, 'dagnn.DagNN') ;
if isDag
  dagRemoveLayersOfType(net, 'dagnn.Loss') ;
  dagRemoveLayersOfType(net, 'dagnn.DropOut') ;
else
  net = simpleRemoveLayersOfType(net, 'softmaxloss') ;
  net = simpleRemoveLayersOfType(net, 'dropout') ;
end

if isDag
  net.addLayer('prob', dagnn.SoftMax(), 'prediction', 'prob', {}) ;
else
  net.layers{end+1} = struct('name', 'prob', 'type', 'softmax') ;
end

if isDag
  dagMergeBatchNorm(net) ;
  dagRemoveLayersOfType(net, 'dagnn.BatchNorm') ;
else
  net = simpleMergeBatchNorm(net) ;
  net = simpleRemoveLayersOfType(net, 'bnorm') ;
end

if ~isDag
  net = simpleRemoveMomentum(net) ;
end

% Switch to use MatConvNet default memory limit for CuDNN (512 MB)
if ~isDag
  for l = simpleFindLayersOfType(net, 'conv')
    net.layers{l}.opts = removeCuDNNMemoryLimit(net.layers{l}.opts) ;
  end
else
  for name = dagFindLayersOfType(net, 'dagnn.Conv')
    l = net.getLayerIndex(char(name)) ;
    net.layers(l).block.opts = removeCuDNNMemoryLimit(net.layers(l).block.opts) ;
  end
end

% -------------------------------------------------------------------------
function opts = removeCuDNNMemoryLimit(opts)
% -------------------------------------------------------------------------
remove = false(1, numel(opts)) ;
for i = 1:numel(opts)
  if ischar(opts{i}) && strcmpi(opts{i}, 'CudnnWorkspaceLimit')
    remove([i i+1]) = true ;
  end
end
opts = opts(~remove) ;

% -------------------------------------------------------------------------
function net = simpleRemoveMomentum(net)
% -------------------------------------------------------------------------
for l = 1:numel(net.layers)
  if isfield(net.layers{l}, 'momentum')
    net.layers{l} = rmfield(net.layers{l}, 'momentum') ;
  end
end

% -------------------------------------------------------------------------
function layers = simpleFindLayersOfType(net, type)
% -------------------------------------------------------------------------
layers = find(cellfun(@(x)strcmp(x.type, type), net.layers)) ;

% -------------------------------------------------------------------------
function net = simpleRemoveLayersOfType(net, type)
% -------------------------------------------------------------------------
layers = simpleFindLayersOfType(net, type) ;
net.layers(layers) = [] ;

% -------------------------------------------------------------------------
function layers = dagFindLayersWithOutput(net, outVarName)
% -------------------------------------------------------------------------
layers = {} ;
for l = 1:numel(net.layers)
  if any(strcmp(net.layers(l).outputs, outVarName))
    layers{1,end+1} = net.layers(l).name ;
  end
end

% -------------------------------------------------------------------------
function layers = dagFindLayersOfType(net, type)
% -------------------------------------------------------------------------
layers = {} ;
for l = 1:numel(net.layers)
  if isa(net.layers(l).block, type)
    layers{1,end+1} = net.layers(l).name ;
  end
end

% -------------------------------------------------------------------------
function dagRemoveLayersOfType(net, type)
% -------------------------------------------------------------------------
names = dagFindLayersOfType(net, type) ;
for i = 1:numel(names)
  layer = net.layers(net.getLayerIndex(names{i})) ;
  net.removeLayer(names{i}) ;
  net.renameVar(layer.outputs{1}, layer.inputs{1}, 'quiet', true) ;
end

% -------------------------------------------------------------------------
function dagMergeBatchNorm(net)
% -------------------------------------------------------------------------
names = dagFindLayersOfType(net, 'dagnn.BatchNorm') ;
for name = names
  name = char(name) ;
  layer = net.layers(net.getLayerIndex(name)) ;

  % merge into previous conv layer
  playerName = dagFindLayersWithOutput(net, layer.inputs{1}) ;
  playerName = playerName{1} ;
  playerIndex = net.getLayerIndex(playerName) ;
  player = net.layers(playerIndex) ;
  if ~isa(player.block, 'dagnn.Conv')
    error('Batch normalization cannot be merged as it is not preceded by a conv layer.') ;
  end

  % if the convolution layer does not have a bias,
  % recreate it to have one
  if ~player.block.hasBias
    block = player.block ;
    block.hasBias = true ;
    net.renameLayer(playerName, 'tmp') ;
    net.addLayer(playerName, ...
                 block, ...
                 player.inputs, ...
                 player.outputs, ...
                 {player.params{1}, sprintf('%s_b',playerName)}) ;
    net.removeLayer('tmp') ;
    playerIndex = net.getLayerIndex(playerName) ;
    player = net.layers(playerIndex) ;
    biases = net.getParamIndex(player.params{2}) ;
    net.params(biases).value = zeros(block.size(4), 1, 'single') ;
  end

  filters = net.getParamIndex(player.params{1}) ;
  biases = net.getParamIndex(player.params{2}) ;
  multipliers = net.getParamIndex(layer.params{1}) ;
  offsets = net.getParamIndex(layer.params{2}) ;
  moments = net.getParamIndex(layer.params{3}) ;

  [filtersValue, biasesValue] = mergeBatchNorm(...
    net.params(filters).value, ...
    net.params(biases).value, ...
    net.params(multipliers).value, ...
    net.params(offsets).value, ...
    net.params(moments).value) ;

  net.params(filters).value = filtersValue ;
  net.params(biases).value = biasesValue ;
end

% -------------------------------------------------------------------------
function net = simpleMergeBatchNorm(net)
% -------------------------------------------------------------------------

for l = 1:numel(net.layers)
  if strcmp(net.layers{l}.type, 'bnorm')
    if ~strcmp(net.layers{l-1}.type, 'conv')
      error('Batch normalization cannot be merged as it is not preceded by a conv layer.') ;
    end
    [filters, biases] = mergeBatchNorm(...
      net.layers{l-1}.weights{1}, ...
      net.layers{l-1}.weights{2}, ...
      net.layers{l}.weights{1}, ...
      net.layers{l}.weights{2}, ...
      net.layers{l}.weights{3}) ;
    net.layers{l-1}.weights = {filters, biases} ;
  end
end

% -------------------------------------------------------------------------
function [filters, biases] = mergeBatchNorm(filters, biases, multipliers, offsets, moments)
% -------------------------------------------------------------------------
% wk / sqrt(sigmak^2 + eps)
% bk - wk muk / sqrt(sigmak^2 + eps)
a = multipliers(:) ./ moments(:,2) ;
b = offsets(:) - moments(:,1) .* a ;
biases(:) = biases(:) + b(:) ;
sz = size(filters) ;
numFilters = sz(4) ;
filters = reshape(bsxfun(@times, reshape(filters, [], numFilters), a'), sz) ;

Testing accuracy: test_accuracy.m

clc
clear
% Load the deployed model
net1 = dagnn.DagNN.loadobj(load('exp\Sort_1000pics\net-deployed.mat')) ;
net1.mode = 'test' ;
% Load the imdb prepared earlier
imdb = load('exp\Sort_1000pics\imdb.mat') ;
 
opts.dataDir = fullfile('data','Sort_1000pics') ;
opts.expDir  = fullfile('exp', 'Sort_1000pics') ;
% Indices of the training and test (validation) sets
opts.train.train = find(imdb.images.sets==1) ;
opts.train.val = find(imdb.images.sets==3) ;
 
for i = 1:length(opts.train.val)
    
    index = opts.train.val(i);
    label = imdb.images.label(index);
    % Read a test image
    im_ =  imread(fullfile(imdb.imageDir.test,imdb.images.name{index}));
%     figure,imshow(im_);
    im_ = single(im_);
    im_ = imresize(im_, net1.meta.normalization.imageSize(1:2)) ;
    im_ = bsxfun(@minus, im_, net1.meta.normalization.averageImage) ;
    % Run the network
    net1.eval({'input',im_}) ;
    scores = net1.vars(net1.getVarIndex('prob')).value ;
    scores = squeeze(gather(scores)) ;
 
    [bestScore, best] = max(scores) ;
%     i,scores,best
    truth(i) = label;
    pre(i) = best;
end
% Compute the accuracy
accuracy = length(find(pre==truth))/length(truth);
disp(['accuracy = ',num2str(accuracy*100),'%']);

Note: to keep things simple, the validation set is used directly as the test set here.

3. Running the program

The experiment was run on 64-bit Windows 7 with MATLAB 2014a and a K2000 GPU.

To save time, the number of training epochs was set to 5; this can be changed in cnn_train_dag.m, or passed as an option as sketched below.
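cnn_train_dag accepts the epoch count as an option, so instead of editing cnn_train_dag.m the same effect can be obtained from cnn_net.m (a sketch; the exact values are up to you):

% Alongside the other opts.train settings in cnn_net.m
opts.train.numEpochs = 5 ;                     % train for 5 epochs only
opts.train.learningRate = 1e-4 * ones(1,5) ;   % one learning rate per epoch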

(Screenshot of the training run omitted.)

Running test_accuracy.m prints:

accuracy = 90.5%

The full project code, including the dataset, is available here:

Link: http://pan.baidu.com/s/1gfj6QJL   Password: gc07

