Caffe实战Day5-使用opencv调用caffe模型进行分类

来源:互联网 发布:阿里云服务器怎么租赁 编辑:程序博客网 时间:2024/06/07 09:33
通过前面的文章,我们已经使用caffe训练了一个模型,下面我们在opencv中使用模型进行预测吧!
环境:OpenCV 3.3+VS2017
准备好三个文件:deploy.prototxt、caffemdel文件、标签文件labels.txt,建议大家按照前面的文章生成相应的文件,因为格式不同,可能程序运行会有错误。
1、修改deploy.prototxt文件
只需将输入层的格式修改一下:

name: "CaffeNet"layer {  name: "data"  type: "Input"  top: "data"  input_param { shape: { dim: 10 dim: 3 dim: 227 dim: 227 } }}
修改为:
name: "CaffeNet"input: "data"input_dim: 10 input_dim: 3 input_dim: 227 input_dim: 227
即可。
2、我将三个文件命名为:caffenet.prototxt、caffenet.caffemodel、labels.txt
3、VS新建工程,复制下面代码:

#include <opencv2/dnn.hpp>#include <opencv2/imgproc.hpp>#include <opencv2/highgui.hpp>#include <opencv2/core/utils/trace.hpp>using namespace cv;using namespace cv::dnn;#include <fstream>#include <iostream>#include <cstdlib>using namespace std;//寻找出概率最高的一类static void getMaxClass(const Mat &probBlob, int *classId, double *classProb){    Mat probMat = probBlob.reshape(1, 1);    Point classNumber;    minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);    *classId = classNumber.x;}//从标签文件读取分类 空格为标志static std::vector<String> readClassNames(const char *filename = "labels.txt"){    std::vector<String> classNames;    std::ifstream fp(filename);    if (!fp.is_open())    {        std::cerr << "File with classes labels not found: " << filename << std::endl;        exit(-1);    }    std::string name;while (!fp.eof()){std::getline(fp, name);if (name.length())classNames.push_back(name.substr(name.find(' ') + 1));}    fp.close();    return classNames;}//主程序int main(int argc, char **argv){//初始化    CV_TRACE_FUNCTION();//读取模型参数和模型结构文件    String modelTxt = "caffenet.prototxt";    String modelBin = "caffenet.caffemodel";//读取图片    String imageFile = (argc > 1) ? argv[1] : "test0.jpg";    //合成网络    Net net = dnn::readNetFromCaffe(modelTxt, modelBin);//判断网络是否生成成功    if (net.empty())    {        std::cerr << "Can't load network by using the following files: " << std::endl;        exit(-1);    }cerr << "net read successfully" << endl;//读取图片    Mat img = imread(imageFile);imshow("image", img);    if (img.empty())    {        std::cerr << "Can't read image from the file: " << imageFile << std::endl;        exit(-1);    }cerr << "image read sucessfully" << endl;/*Mat inputBlob = blobFromImage(img, 1, Size(224, 224),Scalar(104, 117, 123)); */ //构造blob,为传入网络做准备,图片不能直接进入网络Mat inputBlob = blobFromImage(img, 1, Size(227, 227));    Mat prob;    cv::TickMeter t;    for (int i = 0; i < 10; i++)    {        CV_TRACE_REGION("forward");        //将构建的blob传入网络data层        net.setInput(inputBlob,"data");         //计时        t.start();        //前向预测        prob = net.forward("prob");          //停止计时        t.stop();    }    int classId;    double classProb;//找出最高的概率ID存储在classId,对应的标签在classProb中    getMaxClass(prob, &classId, &classProb);    //打印出结果    std::vector<String> classNames = readClassNames();    std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;    std::cout << "Probability: " << classProb * 100 << "%" << std::endl;    //打印出花费时间    std::cout << "Time: " << (double)t.getTimeMilli() / t.getCounter() << " ms (average from " << t.getCounter() << " iterations)" << std::endl;//便于观察结果waitKey(0);    return 0;} 
4、运行程序得到下面结果:


OK,输出上述信息,就说明你的模型在opencv中调用没问题了,接下来就是深入学习网络结构,理解各种参数,学习各种框架,从中总结。希望大家戒骄戒躁,都能达到自己的目标,谢谢大家。

阅读全文
2 0
原创粉丝点击