prim算法 优化前O(n²) 优化后O(n-k)

来源:互联网 发布:python 短文本相似度 编辑:程序博客网 时间:2024/05/17 04:29

1.算法思想
图采用邻接矩阵存储,贪心找到目前情况下能连上的权值最小的边的另一端点,加入之,直到所有的顶点加入完毕。
2.算法实现步骤
设图G=(V,E),其生成树的顶点集合为U。
(1)把v0放入U。
(2)在所有u∈U,v∈V-U的边(u,v)∈E中找一条最小权值的边,加入生成树。
(3)把(2)找到的边的v加入U集合。如果U集合已有n个元素,则结束,否则继续执行(2)。
最后得到最小生成树U=

#include<cstdio>#include<cstring>using namespace std;#define vmax 200int w[vmax][vmax],i,j,k,v,e;void prim(int v0){    bool flag[vmax];    int min,nextk,prevk;    memset(flag,false,sizeof(flag));    flag[v0]=true;    for (i=1;i<=v-1;i++)      {        min=0x7fffffff;        for (k=1;k<=v;k++)          if (flag[k])            for (j=1;j<=v;j++)              if (!flag[j] && w[k][j]<min && w[k][j]!=0)                {                  min=w[k][j];                  nextk=j;                  prevk=k;                }        if (min!=0x7fffffff)          {            flag[nextk]=true;            printf("%d %d %d",prevk,nextk,min);          }      }}int main(){    memset(w,0,sizeof(w));    scanf("%d %d",&v,&e);    for (k=1;k<=e;k++)      {        scanf("%d %d",&i,&j);        scanf("%d",&w[i][j]);        w[j][i]=w[i][j];      }    prim(1);    return 0;}

3.算法的关键与优化
我们很容易就可以发现prim算法的关键:每次如何从生成树T到T外的所有边中,找出一条最小边。例如,在第k次前,生成树T中已有k个顶点和(k-1)条边,此时,T到T外的所有边数为k*(n-k),当然,包括没有边的两顶点我们记权值为“无穷大”的边在内,从如此多的边中查找最短边,时间复杂度为O(k(n-k)),显然无法满足我们的期望。
我们来看O(n-k)的方法:假定在进行第k次前已经保留着从T中到T外的每一个顶点(共n-k个)的各一条最短边,在进行第k次时,首先从这(n-k)条最短边中,找出一条最最短边(它就是从T到T外的最短边),假设为(vi,vj),此步需要进行(n-k)次比较;然后把边(vi,vj)和顶点vj并入T中的边集TE和顶点集U中,此时,T外只有n-(k+1)个顶点,对于其中的每个顶点vt,若(vj,vt)边上的权值小于原来保存的从T中到vt的最短边的权值,则用(v,vt)修改之,否则,保持原最小边不变。这样就把第k次后T中到T外的每一个顶点vt的各一条最短边都保留下来了,为第(k+1)次做好了准备。这样,prim的总时间复杂度为O(n²)。
【样例输入】
6 10
1 2 10
1 5 19
1 6 21
2 3 5
2 4 6
2 6 11
3 4 6
4 5 18
4 6 14
5 6 33
【样例输出】
50
优化后:

#include<cstdio>using namespace std;#define MXN 1000int map[MXN][MXN],cost[MXN],visit[MXN],i,j,n,m,x,y,v;int prim(){    int i,j,min,mini,ans;    ans=0;    for (i=1;i<=n;i++)      {        visit[i]=false;        cost[i]=0x7fffffff;      }    for (i=2;i<=n;i++)      if (map[1][i]!=0)        cost[i]=map[1][i];    visit[1]=true;    for (i=1;i<=n-1;i++)      {        min=0x7fffffff;        for (j=1;j<=n;j++)          if (!visit[j] && cost[j]<min)            {              min=cost[j];              mini=j;            }        visit[mini]=true;        ans+=min;        for (j=1;j<=n;j++)           if (!visit[j] && map[mini][j]>0 && map[mini][j]<cost[j])            cost[j]=map[mini][j];      }    return ans;}int main(){    scanf("%d %d",&n,&m);    for (i=1;i<=m;i++)      {        scanf("%d %d %d",&x,&y,&v);        if (map[x][y]==0 || map[x][y]>v)          {            map[x][y]=v;            map[y][x]=v;          }      }    printf("%d",prim());    return 0;}
0 0