300W数据集测试MTCNN的landmark效果代码

来源:互联网 发布:相机滤镜软件下载 编辑:程序博客网 时间:2024/06/04 18:02

300W数据集测试MTCNN的landmark效果,用提取其中afw数据集337张图片的预测关键点并写入到txt中,再用测试程序和标注landmark做对比。

处理得到的预测landmark格式如下:
1051618982(图片名)
1(landmark个数)
543 267 643 268 594 322 542 359 643 360
111076519
2
1095 624 1161 635 1125 668 1084 696 1146 706
1172 764 1238 767 1211 806 1168 830 1233 833
1130084326

程序如下:

[python] view plain copy
  1. #include "network.h"  
  2. #include "mtcnn.h"  
  3. #include <time.h>  
  4. #include <fstream>  
  5. #include <opencv2/opencv.hpp>  
  6. #pragma comment(lib, "libopenblas.dll.a")  
  7.   
  8. using namespace cv;  
  9.   
  10. std::vector<std::string> split(std::string& str, std::string& pattern);  
  11. void str2int(int &int_temp, const string &string_temp);  
  12.   
  13. int main()  
  14. {  
  15.     //因为执行目录被设置到openblas/x64下了,保证dll能正常载入,这时候图片路径就相对要提上去2级  
  16.     //Mat im = imread("../../test10.jpg");  
  17.     int i = 0;  
  18.     string img_dir = "E:/face_alignment/data/300W_test_5points/afw/";  
  19.     string name, s="_";  
  20.     ifstream infile;  
  21.     ofstream outfile;  
  22.     infile.open("E:/face_alignment/data/300W_test_5points/afw_mtcnn_test_V2_2.txt");  
  23.     outfile.open("E:/face_alignment/data/300W_test_5points/MTCNN_V2_2_test/afw_test.txt");  
  24.   
  25.     while (infile)  
  26.     {  
  27.         infile >> name;  
  28.         vector<string> result = split(name, s);  
  29.         cout << i << endl;  
  30.         i++;  
  31.         //cout << "name: " << result[0] << " s: " << result[1] << endl;  
  32.           
  33.         string shot_name = result[0];  
  34.         int decide;  
  35.         str2int(decide, result[1]);  
  36.   
  37.         if (decide == 1)  
  38.         {  
  39.             string image_name = name + ".jpg";  
  40.             outfile << shot_name  << endl;  
  41.             string image_name_dir = img_dir + image_name;  
  42.             cout << image_name_dir << endl;  
  43.   
  44.             Mat im = imread(image_name_dir);  
  45.             vector<vector<Point2f>> key_points;  
  46.             mtcnn find(im.cols, im.rows);  
  47.             vector<Rect> objs = find.detectObject(im, key_points);  
  48.             //outfile << objs.size() << endl;  
  49.             for (int i = 0; i < objs.size(); ++i)  
  50.             {  
  51.                 rectangle(im, objs[i], Scalar(0255), 2);  
  52.                 //outfile << objs[i].x << " " << objs[i].y << " " << objs[i].width << " " << objs[i].height << endl;  
  53.             }  
  54.                   
  55.   
  56.             //cout << "num of key_points: " << key_points.size() << endl;  
  57.             outfile << key_points.size() << endl;  
  58.             for (int i = 0; i < key_points.size(); i++)  
  59.             {  
  60.                 for (int j = 0; j < key_points[i].size(); j++)  
  61.                 {  
  62.                     cv::circle(im, key_points[i][j], 1, cv::Scalar(2552550), 2);  
  63.                     outfile << key_points[i][j].x << " " << key_points[i][j].y << " ";  
  64.                 }  
  65.                 outfile << endl;  
  66.             }  
  67.               
  68.             string outdir = "E:/face_alignment/data/300W_test_5points/MTCNN_V2_2_test/afw/";  
  69.             string out_image = outdir + shot_name + ".jpg";  
  70.             imwrite(out_image, im);  
  71.             //imshow("demo", im);  
  72.             //waitKey(0);  
  73.         }  
  74.   
  75.           
  76.     }  
  77.   
  78.     infile.close();  
  79.     outfile.close();  
  80.   
  81.       
  82.       
  83.     return 0;  
  84. }  
  85.   
  86.   
  87.   
  88.   
  89.   
  90. //字符串分割函数     
  91. std::vector<std::string> split(std::string& str, std::string& pattern)  
  92. {  
  93.     std::string::size_type pos;  
  94.     std::vector<std::string> result;  
  95.     //str += pattern;//扩展字符串以方便操作       
  96.     int size = str.size();  
  97.       
  98.     pos = str.find(pattern, 0);  
  99.     if (pos<size) {  
  100.         std::string s1 = str.substr(0, pos);  
  101.         std::string s2 = str.substr(pos + 1, size - 1);  
  102.         result.push_back(s1);  
  103.         result.push_back(s2);  
  104.     }  
  105.       
  106.     return result;  
  107. }  
  108.   
  109.   
  110. void str2int(int &int_temp, const string &string_temp)  
  111. {  
  112.     stringstream stream(string_temp);  
  113.     stream >> int_temp;  
  114. }  

用我们自己得到的预测txt,与作者提供的标注pts文件进行计算。  算5个landmark的欧式距离之和,除以左上角和右下角欧式距离,除以5。

[python] view plain copy
  1. #include <iostream>    
  2. #include <stdlib.h>    
  3. #include <fstream>    
  4. #include <sstream>    
  5. #include <string>    
  6. #include <vector>    
  7. #include <opencv2/opencv.hpp>      
  8. using namespace cv;  
  9. using namespace std;  
  10.   
  11. std::vector<std::string> split(std::string& str, std::string& pattern);  
  12. void str2int(int &int_temp, const string &string_temp);  
  13. float computer_error(vector<float> pts_gt, vector<vector<float>> pts_pre);  
  14.   
  15. int main()  
  16. {  
  17.     int count = 0, pos = 0;  
  18.     float acc, thread=0.1;  
  19.     string name_list, s = "_";  
  20.     string pts_dir = "E:/face_alignment/data/300W_test_5points/afw/";  
  21.     ifstream infile_list, infile_pre;  
  22.     infile_list.open("E:/face_alignment/data/300W_test_5points/afw_mtcnn_test_V1.txt");  
  23.     infile_pre.open("E:/face_alignment/data/300W_test_5points/MTCNN_V1_test/afw_test.txt");  
  24.     infile_list >> name_list;  
  25.   
  26.     while (infile_pre)  
  27.     {  
  28.         string name_pre;  
  29.         int num_pre;  
  30.         vector<vector<float> > pts_pre;  
  31.         infile_pre >> name_pre;  
  32.         infile_pre >> num_pre;  
  33.         pts_pre.resize(num_pre);  
  34.   
  35.         for (int i = 0; i < num_pre; i++)  
  36.         {  
  37.             pts_pre[i].resize(10);  
  38.         }  
  39.   
  40.         for (int j = 0; j < num_pre; j++)  
  41.         {  
  42.             infile_pre >> pts_pre[j][0] >> pts_pre[j][1] >> pts_pre[j][2] >> pts_pre[j][3] >> pts_pre[j][4] >> pts_pre[j][5] >> pts_pre[j][6] >> pts_pre[j][7] >> pts_pre[j][8] >> pts_pre[j][9];  
  43.               
  44.         }  
  45.   
  46.   
  47.         //for (int i = 0; i < num_pre; i++)  
  48.         //{  
  49.         //  for (int j = 0; j < 10; j++)  
  50.         //  {  
  51.         //      cout << pts_pre[i][j] << " ";  
  52.         //  }  
  53.         //  cout << endl;  
  54.         //}  
  55.   
  56.   
  57.         // read gt file  
  58.         while (infile_list)  
  59.         {  
  60.               
  61.             vector<string> result = split(name_list, s);  
  62.             string name_gt = result[0];  
  63.   
  64.             if (name_gt.compare(name_pre) == 0)  
  65.             {  
  66.                 count++;  
  67.                 cout << count << endl;  
  68.   
  69.                 vector<float> pts_gt;  
  70.                 pts_gt.resize(10);  
  71.                 string pts_dir_name = pts_dir + name_list + ".pts";  
  72.                 ifstream infile_pts;  
  73.                 infile_pts.open(pts_dir_name);  
  74.   
  75.                 string ss;  
  76.                 int yy;  
  77.                 infile_pts >> ss >> yy;  
  78.                 infile_pts >> ss >> yy;  
  79.                 infile_pts >> ss;  
  80.                 for (int i = 0; i < 5; i++)  
  81.                 {  
  82.                     infile_pts >> pts_gt[i*2] >> pts_gt[i*2+1];  
  83.                 }  
  84.                   
  85.                 infile_pts.close();  
  86.                   
  87.                   
  88.                 float error = computer_error(pts_gt, pts_pre);  
  89.                 error = error / 5.0;  
  90.                 if (error <= thread)  
  91.                     pos++;  
  92.                   
  93.   
  94.                 cout << error << " " << endl;  
  95.                 infile_list >> name_list;  
  96.                 //cout << name_list << endl;  
  97.   
  98.             }  
  99.   
  100.             else  
  101.                 break;  
  102.   
  103.         }     
  104.   
  105.   
  106.     }  
  107.   
  108.     acc = float(pos) / float(count);  
  109.     cout << "accury: " << acc << endl;  
  110.   
  111.     infile_list.close();  
  112.     infile_pre.close();  
  113.   
  114. }  
  115.   
  116.   
  117.   
  118. // computer alinment loss  
  119. float computer_error(vector<float> pts_gt, vector<vector<float>> pts_pre)  
  120. {  
  121.     if (pts_pre.size() == 0)  
  122.         return 10;  
  123.   
  124.     float RMSE, d_outer, align_loss;  
  125.     d_outer = sqrt((pts_gt[0] - pts_gt[8])*(pts_gt[0] - pts_gt[8]) + (pts_gt[1] - pts_gt[9])*(pts_gt[1] - pts_gt[9]));  
  126.     for (int i = 0; i < pts_pre.size(); i++)  
  127.     {  
  128.         RMSE = 0;  
  129.         for (int j = 0; j < 5; j++)  
  130.         {  
  131.             RMSE += sqrt((pts_gt[2 * j] - pts_pre[i][2 * j])*(pts_gt[2 * j] - pts_pre[i][2 * j]) + (pts_gt[2 * j + 1] - pts_pre[i][2 * j + 1])*(pts_gt[2 * j + 1] - pts_pre[i][2 * j + 1]));  
  132.         }  
  133.         RMSE = RMSE / d_outer;  
  134.   
  135.         if (i == 0)  
  136.         {  
  137.             align_loss = RMSE;  
  138.         }  
  139.         else  
  140.         {  
  141.             if (align_loss > RMSE)  
  142.                 align_loss = RMSE;  
  143.         }  
  144.     }  
  145.     return align_loss;  
  146. }  
  147.   
  148.   
  149.   
  150. //字符串分割函数     
  151. std::vector<std::string> split(std::string& str, std::string& pattern)  
  152. {  
  153.     std::string::size_type pos;  
  154.     std::vector<std::string> result;  
  155.     //str += pattern;//扩展字符串以方便操作       
  156.     int size = str.size();  
  157.   
  158.     pos = str.find(pattern, 0);  
  159.     if (pos<size) {  
  160.         std::string s1 = str.substr(0, pos);  
  161.         std::string s2 = str.substr(pos + 1, size - 1);  
  162.         result.push_back(s1);  
  163.         result.push_back(s2);  
  164.     }  
  165.   
  166.     return result;  
  167. }  
  168.   
  169.   
  170. void str2int(int &int_temp, const string &string_temp)  
  171. {  
  172.     stringstream stream(string_temp);  
  173.     stream >> int_temp;  
  174. }