km算法(求二分图带权的最大匹配)

来源:互联网 发布:网络高清图片大全 编辑:程序博客网 时间:2024/05/29 19:23

1,如果二分图不是完全二分图,我们通过添加无用路径(最大匹配中,路径权值为0)和顶点使之成为完全二分图;

2,使用KM算法求解,KM算法核心需要理解feasible vertex labeling和equality subgraph概念,在equality subgraph中寻找最大匹配(采用匈牙利算法),如果最大匹配正好为完全匹配,根据KM理论,这个完全匹配就是带权值的最大匹配;如果在当前的equality subgraph获取的最大匹配不是完全匹配,我们通过KM算法中提供的修改label方法,增加新的y点,得到新的equality subgraph,再继续在新的subgraph寻找最大匹配,如此循环。

代码如下:

[cpp] view plaincopy
  1. #include <cstdio>  
  2. #include <memory.h>  
  3. #include <algorithm>    // 使用其中的 min 函数  
  4. using namespace std;  
  5.   
  6. const int MAX = 1024;  
  7.   
  8. int n; // X 的大小  
  9. int weight[MAX][MAX]; // X 到 Y 的映射(权重)  
  10. int lx[MAX], ly[MAX]; // 标号  
  11. bool sx[MAX], sy[MAX]; // 是否被搜索过  
  12. int match[MAX]; // Y(i) 与 X(match [i]) 匹配  
  13.   
  14. // 初始化权重  
  15. void init(int size);  
  16. // 从 X(u) 寻找增广道路,找到则返回 true  
  17. bool path(int u);  
  18. // 参数 maxsum 为 true ,返回最大权匹配,否则最小权匹配  
  19. int bestmatch(bool maxsum = true);  
  20.   
  21. void init(int size)  
  22. {  
  23.     // 根据实际情况,添加代码以初始化  
  24.     n = size;  
  25.     for (int i = 0; i < n; i++)  
  26.         for (int j = 0; j < n; j++)  
  27.             scanf("%d", &weight[i][j]);  
  28. }  
  29. /* 
  30.  * 和二分图的思路类似,在子图中寻找增广路径 
  31.  */  
  32. bool path(int u)  
  33. {  
  34.     sx[u] = true;  
  35.     for (int v = 0; v < n; v++)  
  36.         if (!sy[v] && lx[u] + ly[v] == weight[u][v])  
  37.         {  
  38.             sy[v] = true;  
  39.             if (match[v] == -1 || path(match[v]))  
  40.             {  
  41.                 match[v] = u;  
  42.                 return true;  
  43.             }  
  44.         }  
  45.     return false;  
  46. }  
  47.   
  48. int bestmatch(bool maxsum)  
  49. {  
  50.     int i, j;  
  51.     if (!maxsum)  
  52.     {  
  53.         for (i = 0; i < n; i++)  
  54.             for (j = 0; j < n; j++)  
  55.                 weight[i][j] = -weight[i][j];  
  56.     }  
  57.   
  58.     // 初始化标号  
  59.     for (i = 0; i < n; i++)  
  60.     {  
  61.         lx[i] = -0x1FFFFFFF;  
  62.         ly[i] = 0;  
  63.         for (j = 0; j < n; j++)  
  64.             if (lx[i] < weight[i][j])  
  65.                 lx[i] = weight[i][j];  
  66.     }  
  67.   
  68.     memset(match, -1, sizeof(match));  
  69.     for (int u = 0; u < n; u++)  
  70.         while (1)  
  71.         {  
  72.             memset(sx, 0, sizeof(sx));  
  73.             memset(sy, 0, sizeof(sy));  
  74.             if (path(u))    //一直寻找增广路径,直到子图中没有增广路径,我们通过修改label来增加新的点,增加的点必为y  
  75.                 break;  
  76.   
  77.             // 修改标号  
  78.             int dx = 0x7FFFFFFF;  
  79.             for (i = 0; i < n; i++)  
  80.                 if (sx[i])  
  81.                     for (j = 0; j < n; j++)  
  82.                         if (!sy[j])  
  83.                             dx = min(lx[i] + ly[j] - weight[i][j], dx); //找到松弛变量最小的点  
  84.             for (i = 0; i < n; i++)  
  85.             {  
  86.                 if (sx[i])  
  87.                     lx[i] -= dx;  
  88.                 if (sy[i])  
  89.                     ly[i] += dx;  
  90.             }  
  91.         }  
  92.   
  93.     int sum = 0;  
  94.     for (i = 0; i < n; i++)  
  95.         sum += weight[match[i]][i];  
  96.   
  97.     if (!maxsum)  
  98.     {  
  99.         sum = -sum;  
  100.         for (i = 0; i < n; i++)  
  101.             for (j = 0; j < n; j++)  
  102.                 weight[i][j] = -weight[i][j]; // 如果需要保持 weight [ ] [ ] 原来的值,这里需要将其还原  
  103.     }  
  104.     return sum;  
  105. }  
  106.   
  107. int main()  
  108. {  
  109.     int n;  
  110.     scanf("%d", &n);  
  111.     init(n);  
  112.     int cost = bestmatch(true);  
  113.   
  114.     printf("%d ", cost);  
  115.     for (int i = 0; i < n; i++)  
  116.     {  
  117.         printf("Y %d -> X %d ", i, match[i]);  
  118.     }  
  119.   
  120.     return 0;  
  121. }  

附带poj2195解法:

[cpp] view plaincopy
  1. /* 
  2.  * poj2195.cpp 
  3.  * 
  4.  *  Created on: 2012-5-9 
  5.  *      Author: ict 
  6.  */  
  7.   
  8. #include <cstdio>  
  9. #include <cstdlib>  
  10. #include <string.h>  
  11. #include <algorithm>  
  12. #include <cmath>  
  13. using namespace std;  
  14. #define MAX 200  
  15. typedef struct GRID  
  16. {  
  17.     int x;  
  18.     int y;  
  19. }grid, pgrid;  
  20. grid M[MAX], H[MAX];  
  21.   
  22. int n; // X 的大小  
  23. int weight[MAX][MAX]; // X 到 Y 的映射(权重)  
  24. int lx[MAX], ly[MAX]; // 标号  
  25. bool sx[MAX], sy[MAX]; // 是否被搜索过  
  26. int match[MAX]; // Y(i) 与 X(match [i]) 匹配  
  27.   
  28. /* 
  29.  * 和二分图的思路类似,在子图中寻找增广路径 
  30.  */  
  31. bool path(int u)  
  32. {  
  33.     sx[u] = true;  
  34.     for (int v = 0; v < n; v++)  
  35.         if (!sy[v] && lx[u] + ly[v] == weight[u][v])  
  36.         {  
  37.             sy[v] = true;  
  38.             if (match[v] == -1 || path(match[v]))  
  39.             {  
  40.                 match[v] = u;  
  41.                 return true;  
  42.             }  
  43.         }  
  44.     return false;  
  45. }  
  46.   
  47. int bestmatch(bool maxsum)  
  48. {  
  49.     int i, j;  
  50.     if (!maxsum)  
  51.     {  
  52.         for (i = 0; i < n; i++)  
  53.             for (j = 0; j < n; j++)  
  54.                 weight[i][j] = -weight[i][j];  
  55.     }  
  56.   
  57.     // 初始化标号  
  58.     for (i = 0; i < n; i++)  
  59.     {  
  60.         lx[i] = -0x1FFFFFFF;  
  61.         ly[i] = 0;  
  62.         for (j = 0; j < n; j++)  
  63.             if (lx[i] < weight[i][j])  
  64.                 lx[i] = weight[i][j];  
  65.     }  
  66.   
  67.     memset(match, -1, sizeof(match));  
  68.     for (int u = 0; u < n; u++)  
  69.         while (1)  
  70.         {  
  71.             memset(sx, 0, sizeof(sx));  
  72.             memset(sy, 0, sizeof(sy));  
  73.             if (path(u))    //一直寻找增广路径,直到子图中没有增广路径,我们通过修改label来增加新的点,增加的点必为y  
  74.                 break;  
  75.   
  76.             // 修改标号  
  77.             int dx = 0x7FFFFFFF;  
  78.             for (i = 0; i < n; i++)  
  79.                 if (sx[i])  
  80.                     for (j = 0; j < n; j++)  
  81.                         if (!sy[j])  
  82.                             dx = min(lx[i] + ly[j] - weight[i][j], dx); //找到松弛变量最小的点  
  83.             for (i = 0; i < n; i++)  
  84.             {  
  85.                 if (sx[i])  
  86.                     lx[i] -= dx;  
  87.                 if (sy[i])  
  88.                     ly[i] += dx;  
  89.             }  
  90.         }  
  91.   
  92.     int sum = 0;  
  93.     for (i = 0; i < n; i++)  
  94.         sum += weight[match[i]][i];  
  95.   
  96.     if (!maxsum)  
  97.     {  
  98.         sum = -sum;  
  99.         for (i = 0; i < n; i++)  
  100.             for (j = 0; j < n; j++)  
  101.                 weight[i][j] = -weight[i][j]; // 如果需要保持 weight [ ] [ ] 原来的值,这里需要将其还原  
  102.     }  
  103.     return sum;  
  104. }  
  105.   
  106.   
  107.   
  108. int main()  
  109. {  
  110.     int row, col;  
  111.     int i, j;  
  112.     int ch;  
  113.     int mCount, hCount;  
  114.   
  115.     while(1)  
  116.     {  
  117.         scanf("%d %d", &row, &col);  
  118.         getchar();  
  119.         if(row == 0 && col == 0)  
  120.             break;  
  121.         mCount = 0;  
  122.         hCount = 0;  
  123.         for(i = 0; i < row ; i++)  
  124.         {  
  125.             for(j = 0; j < col; j++)  
  126.             {  
  127.                 ch = getchar();  
  128.                 if(ch == 'm')  
  129.                 {  
  130.                     M[mCount].x = i;  
  131.                     M[mCount].y = j;  
  132.                     mCount++;  
  133.                 }  
  134.                 else  
  135.                     if(ch == 'H')  
  136.                     {  
  137.                         H[hCount].x = i;  
  138.                         H[hCount].y = j;  
  139.                         hCount++;  
  140.                     }  
  141.             }  
  142.             getchar();  
  143.         }  
  144.   
  145.         n = mCount;  
  146.   
  147.         for(i = 0 ; i < n; i++)  
  148.             for(j = 0; j < n; j++)  
  149.             {  
  150.                 weight[i][j] = abs(M[i].x - H[j].x) + abs(M[i].y - H[j].y);  
  151.             }  
  152.   
  153.         printf("%d\n", bestmatch(false));  
  154.     }  
  155.   
  156.     return 0;  
  157.   
  158. }  
0 0
原创粉丝点击