Matconvnet框架中采用自己的softmaxloss损失函数代码

来源:互联网 发布:凯立德导航端口工具 编辑:程序博客网 时间:2024/06/06 05:46

        Matconvnet框架中采用自己的softmaxloss损失函数代码。主要涉及的模块是vl_nnsoftmaxloss函数和processEpoch函数。vl_nnsoftmaxloss函数中实现了自己的softmaxloss代码,相关的算法推导请见softmaxloss损失函数的算法推导,注意进行softmaxloss的相关计算前要减去神经网络输出的最大值。processEpoch函数中的[im,labels] = params.getBatch(params.imdb, batch)被删除,采用自己编写的图像样本读取代码读取图像样本。

 

1.      vl_nnsoftmaxloss

function y = vl_nnsoftmaxloss(x,c,dzdy)
% VL_NNSOFTMAXLOSS Custom softmax log-loss layer for MatConvNet.
%   Y = VL_NNSOFTMAXLOSS(X, C) computes the total softmax log-loss of the
%   network output X (1 x 1 x numClasses x batchSize) against the
%   ground-truth class labels C (forward pass).
%   Y = VL_NNSOFTMAXLOSS(X, C, DZDY) computes the derivative of the loss
%   with respect to X instead (backward pass).
%   For numerical stability the per-sample maximum of X is subtracted
%   before exponentiation.
%
%   Unterminated (semicolon-free) assignments below are deliberate debug
%   prints left in by the author.

% Pick a cast matching the numeric class of X (single/double, CPU/GPU)
% so the weight mask computed later keeps the same type.
if isa(x, 'gpuArray')
  switch classUnderlying(x) ;
    case 'single', cast = @(z) single(z) ;
    case 'double', cast = @(z) double(z) ;
  end
else
  switch class(x)
    case 'single', cast = @(z) single(z) ;
    case 'double', cast = @(z) double(z) ;
  end
end

% Debug views of the inputs.
v_c=squeeze(c)
v_c=c
v_x=squeeze(x)

% One-hot encoding of the labels (diagnostic only; the gradient actually
% returned below is new_sita, not y_error_out).
numClasses=3;  % NOTE(review): hard-coded class count — must equal size(v_x,1); verify
trainLabels=c;
groundTruth=bsxfun(@eq,repmat(trainLabels,numClasses,1),(1:1:numClasses)')

y_error=groundTruth-v_x;
[sw_error,sh_error]=size(y_error);
y_error_out=reshape(y_error,[1,1,sw_error,sh_error])

% --- Custom softmax gradient ------------------------------------------
% Subtract the per-sample maximum before exp() for numerical stability,
% normalize to softmax probabilities, then subtract 1 at the ground-truth
% class index: this is d(loss)/d(x) for the softmax log-loss.
new_max_x=max(v_x);
new_x=bsxfun(@minus,v_x,new_max_x);
new_exp=exp(new_x);
new_sum_exp=sum(new_exp);
for k=1:numel(new_sum_exp)
     new_sita(:,k)=new_exp(:,k)/new_sum_exp(k);
      new_sita(v_c(k),k)=new_sita(v_c(k),k)-1;
end
new_sita=new_sita
[sw,sh]=size(new_sita);
v_y_my=reshape(new_sita,[1,1,sw,sh])

% --- Stock MatConvNet bookkeeping -------------------------------------
sz = [size(x,1) size(x,2) size(x,3) size(x,4)] ;

if numel(c) == sz(4)
  % one label per image
  c = reshape(c, [1 1 1 sz(4)]) ;
end
if size(c,1) == 1 & size(c,2) == 1
  c = repmat(c, [sz(1) sz(2)]) ;
end

% one label per spatial location
sz_ = [size(c,1) size(c,2) size(c,3) size(c,4)] ;
assert(isequal(sz_, [sz(1) sz(2) sz_(3) sz(4)])) ;
assert(sz_(3)==1 | sz_(3)==2) ;

% class c = 0 skips a spatial location
mass = cast(c(:,:,1,:) > 0) ;
v_mass=squeeze(mass);
if(~isempty(find(v_mass~=1)))
    % Sanity check: this code path expects every location to carry a label.
    disp('******* In the vl_nnsoftmaxloss function! *****')
    pause
end

if sz_(3) == 2
  % the second channel of c (if present) is used as weights
  mass = mass .* c(:,:,2,:) ;
  c(:,:,2,:) = [] ;
  disp('******* unexpected behavior ***********************')
  pause
end

% convert labels to linear indexes into x
c = c - 1 ;
c_ = 0:numel(c)-1 ;
c_ = 1 + ...
  mod(c_, sz(1)*sz(2)) + ...
  (sz(1)*sz(2)) * max(c(:), 0)' + ...
  (sz(1)*sz(2)*sz(3)) * floor(c_/(sz(1)*sz(2)));

% compute the softmax loss (forward) or its derivative (backward),
% again shifting by the maximum for numerical stability
xmax = max(x,[],3) ;
ex = exp(bsxfun(@minus, x, xmax)) ;

if nargin <= 2
  % forward pass: total log-loss over the batch
  t = xmax + log(sum(ex,3)) - reshape(x(c_), [sz(1:2) 1 sz(4)]) ;
  y = sum(sum(sum(mass .* t,1),2),4) ;
  v_y1=y
else
  % backward pass: the stock derivative is computed for comparison, then
  % overridden by the custom gradient v_y_my computed above.
  y = bsxfun(@rdivide, ex, sum(ex,3)) ;
  y(c_) = y(c_) - 1;
  y=v_y_my;
  v_y2=squeeze(y)
  v_y3=y
end

 

2.     processEpoch

%-------------------------------------------------------------------------
function [net, state] = processEpoch(net, state, params, mode)
%-------------------------------------------------------------------------
% PROCESSEPOCH Run one training or evaluation epoch over params.(mode).
%   Modified version of the MatConvNet cnn_train helper: the stock
%   params.getBatch call is replaced by a hard-coded loader that reads
%   every image in the 'sample_imgs' directory and assigns the fixed
%   labels [3 1 2].

% initialise per-weight momentum on the first call
if isempty(state) || isempty(state.momentum)
  for i = 1:numel(net.layers)
    for j = 1:numel(net.layers{i}.weights)
      state.momentum{i}{j} = 0 ;
    end
  end
end

% move CNN to GPU as needed
numGpus = numel(params.gpus) ;
if numGpus >= 1
  net = vl_simplenn_move(net, 'gpu') ;
  for i = 1:numel(state.momentum)
    for j = 1:numel(state.momentum{i})
      state.momentum{i}{j} = gpuArray(state.momentum{i}{j}) ;
    end
  end
end
if numGpus > 1
  parserv = ParameterServer(params.parameterServer) ;
  vl_simplenn_start_parserv(net, parserv) ;
else
  parserv = [] ;
end

% profiling
if params.profile
  if numGpus <= 1
    profile clear ;
    profile on ;
  else
    mpiprofile reset ;
    mpiprofile on ;
  end
end

subset = params.(mode) ;
num = 0 ;
stats.num = 0 ; % return something even if subset = []
stats.time = 0 ;
adjustTime = 0 ;
res = [] ;
error = [] ;

start = tic ;

for t=1:params.batchSize:numel(subset)
  fprintf('%s: epoch %02d: %3d/%3d:', mode, params.epoch, ...
          fix((t-1)/params.batchSize)+1, ceil(numel(subset)/params.batchSize)) ;
  batchSize = min(params.batchSize, numel(subset) - t + 1) ;

  for s=1:params.numSubBatches
    % get this image batch and prefetch the next
    batchStart = t + (labindex-1) + (s-1) * numlabs ;
    batchEnd = min(t+params.batchSize-1, numel(subset)) ;
    batch = subset(batchStart : params.numSubBatches * numlabs : batchEnd) ;
    num = num + numel(batch) ;
    if numel(batch) == 0, continue ; end

    % [im, labels] = params.getBatch(params.imdb, batch) ;
    % Custom sample loader: read every file under sample_imgs/.
    % NOTE(review): starting at i=3 skips the '.' and '..' entries
    % returned by dir; this assumes they are always the first two
    % entries — verify on the target platform.
    d = dir(fullfile('sample_imgs')) ;
    for i=3:numel(d)
      v_d = d(i).name ;
      im(:,:,:,i-2) = single(imread(['sample_imgs/', v_d])) ;
    end
    labels = [3 1 2] ;  % fixed ground-truth labels for the three samples

    if strcmp(mode, 'train')
      dzdy = 1 ;
      evalMode = 'normal' ;
    else
      dzdy = [] ;
      evalMode = 'test' ;
    end
    net.layers{end}.class = labels ;

    % Simplified call: the original passed evalMode/conserveMemory/
    % backPropDepth/sync/cudnn/parameterServer options as well.
    res = vl_simplenn(net, im, dzdy, res) ;

    % accumulate errors
    error = sum([error, [...
      sum(double(gather(res(end).x))) ;
      reshape(params.errorFunction(params, labels, res),[],1) ; ]],2) ;
  end

  % accumulate gradient
  if strcmp(mode, 'train')
    if ~isempty(parserv), parserv.sync() ; end
    [net, res, state] = accumulateGradients(net, res, state, params, batchSize, parserv) ;
  end

  % get statistics
  time = toc(start) + adjustTime ;
  batchTime = time - stats.time ;
  stats = extractStats(net, params, error / num) ;
  stats.num = num ;
  stats.time = time ;
  currentSpeed = batchSize / batchTime ;
  averageSpeed = (t + batchSize - 1) / time ;
  if t == 3*params.batchSize + 1
    % compensate for the first three iterations, which are outliers
    adjustTime = 4*batchTime - time ;
    stats.time = time + adjustTime ;
  end

  fprintf(' %.1f (%.1f) Hz', averageSpeed, currentSpeed) ;
  for f = setdiff(fieldnames(stats)', {'num', 'time'})
    f = char(f) ;
    fprintf(' %s: %.3f', f, stats.(f)) ;
  end
  fprintf('\n') ;

  % collect diagnostic statistics
  if strcmp(mode, 'train') && params.plotDiagnostics
    switchFigure(2) ; clf ;
    diagn = [res.stats] ;
    diagnvar = horzcat(diagn.variation) ;
    diagnpow = horzcat(diagn.power) ;
    subplot(2,2,1) ; barh(diagnvar) ;
    set(gca,'TickLabelInterpreter','none', ...
      'YTick', 1:numel(diagnvar), ...
      'YTickLabel', horzcat(diagn.label), ...
      'YDir', 'reverse', ...
      'XScale', 'log', ...
      'XLim', [1e-5 1], ...
      'XTick', 10.^(-5:1)) ;
    grid on ;
    subplot(2,2,2) ; barh(sqrt(diagnpow)) ;
    set(gca,'TickLabelInterpreter','none', ...
      'YTick', 1:numel(diagnpow), ...
      'YTickLabel', {diagn.powerLabel}, ...
      'YDir', 'reverse', ...
      'XScale', 'log', ...
      'XLim', [1e-5 1e5], ...
      'XTick', 10.^(-5:5)) ;
    grid on ;
    subplot(2,2,3) ; plot(squeeze(res(end-1).x)) ;
    drawnow ;
  end
end

% save back to state
state.stats.(mode) = stats ;
if params.profile
  if numGpus <= 1
    state.prof.(mode) = profile('info') ;
    profile off ;
  else
    state.prof.(mode) = mpiprofile('info') ;
    mpiprofile off ;
  end
end
if ~params.saveMomentum
  state.momentum = [] ;
else
  for i = 1:numel(state.momentum)
    for j = 1:numel(state.momentum{i})
      state.momentum{i}{j} = gather(state.momentum{i}{j}) ;
    end
  end
end

net = vl_simplenn_move(net, 'cpu') ;

 

阅读全文
1 0