hdu4126树形dp + 错误分析

来源:互联网 发布:中国网络自由 编辑:程序博客网 时间:2024/06/06 21:47

先贴第一次写的代码

#include <cstdio>#include <iostream>#include <vector>#include <cstring>using namespace std;//结构体//常量const int N_MAX = 3005;const int M_MAX = 9000005;const int INF = 1000000000;//变量int N, M, Q; //N个村庄M条路vector<int> V[N_MAX]; //邻接表来表示图vector<int> D[N_MAX]; //邻接表表示树long long cost[N_MAX][N_MAX]; // 两点间的距离int dist[N_MAX]; // 表示n这个节点距离根节点中距离最小的int Prev[N_MAX];int used[N_MAX];long long mincost[N_MAX];//函数long long prim(){    long long sum = 0;    memset(Prev,-1,sizeof(Prev));    fill(mincost,mincost+N,INF);    fill(dist,dist+N, INF);    fill(used,used+N, 0);    mincost[0] = 0;    while(true)    {        int v = -1;        for(int i=0; i<N; i++)        {            if(!used[i]&&(v==-1||mincost[i]<mincost[v]))v = i;        }        if(v == -1)break;        sum += mincost[v];        used[v] = 1;        if(Prev[v]!=-1)            D[Prev[v]].push_back(v);        for(int i=0; i<N; i++)        {            if(!used[i] && mincost[i]>cost[v][i])            {                mincost[i] = cost[v][i];                Prev[i] = v;            }        }    }    return sum;}int dfs2(int m, int n)//计算m这课树与n这课树上的最短距离{    int MIN = cost[m][n];    for(int i=0; i<D[m].size(); i++)    {        if(D[m][i] == n)continue;        MIN = min(MIN, dfs2(D[m][i], n));    }    return MIN;}void dfs(int n) // 访问第n个节点{    long long temp = cost[n][Prev[n]];    cost[n][Prev[n]] = cost[Prev[n]][n] = INF;    int MIN = dfs2(0,n);        //求n这个节点到0这颗树上的最短距离    cost[n][Prev[n]] = cost[Prev[n]][n] = temp;    for(int i=0; i<D[n].size(); i++)    {        int v = D[n][i];        dfs(v);        MIN = min(MIN,dist[v]);    }    dist[n] = MIN;}int main(){    freopen("1.txt","r",stdin);    freopen("mytext.txt","w",stdout);    long long sum;//最小路径    while(scanf("%d%d", &N,&M)!=EOF)    {        double ans = 0;        for(int i=0; i<N; i++)fill(cost[i],cost[i]+N, INF);        for(int i=0; i<N; i++){D[i].clear();V[i].clear();}        int x, y, c;        if(N==0&&M==0)            break;        for(int i=0; i<M; i++)        {            scanf("%d%d%d", &x, &y, &c);            V[x].push_back(y);            V[y].push_back(x);            cost[x][y] = cost[y][x] = c;        }        sum = prim();        cout << "sum:" << sum << endl;        dfs(0);        scanf("%d", &Q);        for(int i=0; i<Q; i++)        {            scanf("%d%d%d", &x, &y, &c);            if(Prev[x] == y)                ans += sum*1.0 - cost[x][y] + min(dist[x],c);            else if(Prev[y] == x)                ans += sum*1.0 - cost[x][y] + min(dist[y],c);            else                ans += sum*1.0;        }            printf("%.4lf\n",ans/Q);    }    return 0;}

分析错误原因:

自己写的代码的思路是这样 的,先用prim算出最小生成树,并且在该过程,将这个最小生成树的那颗树存储起来


然后定义dist[i] 为以i这个为根节点的树距离,以0为节点的树排除以i为节点的树,之间的最大距离,当然在写的时候,要把i和i的父亲之间的距离设置为最大值。


然后在遍历之后,再算i的子节点中距离以0为根节点的树的最小距离,并用临时变量MIN来储存,最后和刚开始计算的dist[i]比较,如果MIN小于dist[i]则将其赋值给dist[i]。



最终就计算了所有的dist[i],然而提交后WA,WAWAWAWAWA,自己测试了一天的数据也没找到错误,最终请教我们专业12级的大神帮忙,他使用了随机数生成的方法,很快就测试出了一组数据有问题,

然后我用这组数据,找到了我的问题所在。



果然是年轻啊。。



比如下图




在计算到dist[1]的时候,是1到2 的距离最小,那么计算3的时候,dist[3]和子节点dist[2]比较后也很小,所以dist[3] = dist[2],即3这颗子树上距离0这颗树上的最短距离为1到2的距离。再接着,计算dist[2]也是等于dist[1] ,这样就说不通了,因为2这颗树和0这颗树的最短距离,不可能是1到2的距离。这就是问题所在。

所以大概做了2天,测试数据测试1天,最终把问题找到了,当然有效时间不可能是这么久。因为个人比较能拖。。。这是病得治。


接着贴出大神帮忙写的随机数测试数据的代码。以便以后学习


同时里面含有  #include<bits/stdc++.h>,然后今天才知道这是万能头文件。。。

OMG

#include<bits/stdc++.h>using namespace std;//结构体//常量const int N_MAX = 3005;const int M_MAX = 9000005;const int INF = 1000000000;//变量int N, M, Q; //N个村庄M条路vector<int> V[N_MAX]; //邻接表来表示图vector<int> D[N_MAX]; //邻接表表示树long long cost[N_MAX][N_MAX]; // 两点间的距离int dist[N_MAX]; // 表示n这个节点距离根节点中距离最小的int Prev[N_MAX];int used[N_MAX];long long mincost[N_MAX];//函数long long prim(){    long long sum = 0;    memset(Prev,-1,sizeof(Prev));    fill(mincost,mincost+N,INF);    fill(dist,dist+N, INF);    fill(used,used+N, 0);    mincost[0] = 0;    while(true)    {        int v = -1;        for(int i=0; i<N; i++)        {            if(!used[i]&&(v==-1||mincost[i]<mincost[v]))v = i;        }        if(v == -1)break;        sum += mincost[v];        used[v] = 1;        if(Prev[v]!=-1)            D[Prev[v]].push_back(v);        for(int i=0; i<N; i++)        {            if(!used[i] && mincost[i]>cost[v][i])            {                mincost[i] = cost[v][i];                Prev[i] = v;            }        }    }    return sum;}int dfs2(int m, int n)//计算m这课树与n这课树上的最短距离{    int MIN = cost[m][n];    for(int i=0; i<D[m].size(); i++)    {        if(D[m][i] == n)continue;        MIN = min(MIN, dfs2(D[m][i], n));    }    return MIN;}void dfs(int n) // 访问第n个节点{    long long temp = cost[n][Prev[n]];    cost[n][Prev[n]] = cost[Prev[n]][n] = INF;    int MIN = dfs2(0,n);        //求n这个节点到0这颗树上的最短距离    cost[n][Prev[n]] = cost[Prev[n]][n] = temp;    for(int i=0; i<D[n].size(); i++)    {        int v = D[n][i];        dfs(v);        MIN = min(MIN,dist[v]);    }    dist[n] = MIN;}struct Edge{    int u,v,c;    bool operator<(Edge e2)const{        return c < e2.c;    }};int pa[N_MAX];int Find(int i){    return i==pa[i] ? i : pa[i]=Find(pa[i]);}long long klu(){    vector<Edge> es;    for(int i=0;i<N;++i)for(int j=i+1;j<N;++j){        Edge e;e.u=i;e.v=j;e.c=cost[i][j];        es.push_back(e);    }    for(int i=0;i<N;++i)pa[i]=i;    sort(es.begin(),es.end());    long long sum = 0;    for(int i=0;i<(int)es.size();++i){        if(Find(es[i].u) != Find(es[i].v)){            pa[Find(es[i].u)] = Find(es[i].v);            sum+=es[i].c;        }    }    return sum;}long long bf(int x,int y,int c){    int ori = cost[x][y];    cost[x][y]=cost[y][x] = c;    long long res = klu();    cost[x][y]=cost[y][x] = ori;    return res;}int main(){    //freopen("4126.in","r",stdin);    //freopen("mytext.txt","w",stdout);    srand(11);    long long sum;//最小路径    int T = 10000;    while(T--)    {        N = 4,M = N*(N-1)/2;        double ans = 0;        for(int i=0; i<N; i++)fill(cost[i],cost[i]+N, INF);        for(int i=0; i<N; i++){D[i].clear();V[i].clear();}        int x, y, c;        if(N==0&&M==0)            break;        /*        for(int i=0; i<M; i++)        {            scanf("%d%d%d", &x, &y, &c);            V[x].push_back(y);            V[y].push_back(x);            cost[x][y] = cost[y][x] = c;        }        */        for(int i = 0; i < N; ++i)for(int j = i+1; j < N; ++j){            cost[i][j] = cost[j][i] = rand()%10+3;            V[i].push_back(j);            V[j].push_back(i);        }        sum = prim();        //cout<<"prim  "<<sum<<endl;        dfs(0);        //scanf("%d", &Q);        //cout<<N<<" "<<M<<endl;        //for(int i=0;i<N;++i)for(int j=i+1;j<N;++j){ cout<<i<<" "<<j<<" "<<cost[i][j]<<endl; }        Q = 30;        //cout<<Q<<endl;        for(int i=0; i<Q; i++)        {            //scanf("%d%d%d", &x, &y, &c);            while(x=rand()%N, y = rand()%N, x==y);            c = rand()%3+cost[x][y]+1;            //cout<<x<<" "<<y<<" "<<c<<endl;            long long tmp1,tmp2;            if(Prev[x] == y)                tmp1= sum*1.0 - cost[x][y] + min(dist[x],c);            else if(Prev[y] == x)                tmp1= sum*1.0 - cost[x][y] + min(dist[y],c);            else                tmp1= sum*1.0;            ans += tmp1;            //bf            tmp2 = bf(x,y,c);            if(tmp1 != tmp2){                cout<<"error   "<<tmp1<<" "<<tmp2<<endl;                cout<<N<<" "<<M<<endl;                for(int i=0;i<N;++i)for(int j=i+1;j<N;++j){ cout<<i<<" "<<j<<" "<<cost[i][j]<<endl; }                cout<<x<<" "<<y<<" "<<c<<endl;                return 0;            }            //assert(tmp1 == tmp2);        }        //printf("%.4lf\n",ans/Q);    }    return 0;}




OK ,待续,正在奋斗该题。


0 0
原创粉丝点击