caffe学习笔记4-matcaffe训练与测试

来源:互联网 发布:虚拟局域网 软件 推荐 编辑:程序博客网 时间:2024/05/16 17:15
.m文件流程(训练或者测试)
1.
   添加路径 caffe/matlab 使得 Matlab 可以使用 matcaffe, +caffe文件夹下都是matcaffe的.m接口,可用matlab操作caffe网络
   if exist('../+caffe', 'dir')
      addpath('..');

2.设置caffe cpu/gpu 模式(在测试或者训练之前。.m文件中)
   if exist('use_gpu', 'var') && use_gpu
      caffe.set_mode_gpu();
      gpu_id = 0;  % we will use the first gpu in this demo
      caffe.set_device(gpu_id);
   else
      caffe.set_mode_cpu(); 
   end

3.后面就是初始化网络,进行训练或者测试。
   用已有模型进行测试流程(测试以分类为例)
  model_dir = '../../models/bvlc_reference_caffenet/';  //实际文件路径model = './models/bvlc_reference_caffenet/deploy.prototxt';
  net_model = [model_dir 'deploy.prototxt'];         
  net_weights = [model_dir 'bvlc_reference_caffenet.caffemodel']; 
  phase = 'test'   
  net = caffe.Net(net_model, net_weights, phase); //创建网络并加载权值 
  或者:net = caffe.Net(model, 'test'); % 创建网络,但不加载权值
              net.copy_from(weights); % 加载权值
  prepare_image() //数据预处理(格式+冗余),自己定义
  input_data = {prepare_image(im)};  //装载数据,等价net.blobs('data').set_data(prepare_image(im));用法
  scores = net.forward(input_data);  //前向计算
  //提取出最大的score(概率)以及对应的标签号
  scores = scores{1};        //等价prob = net.blobs('prob').get_data();用法//计算之后再读取原块的数据,最后一层为prob
  scores = mean(scores, 2);  //取所有分类结果的平均值 

  [~, maxlabel] = max(scores); //找到最大概率对应的标签号




0 0
原创粉丝点击