[JZOJ4850]记忆的轮廓

来源:互联网 发布:淘宝衣服搭配在哪里 编辑:程序博客网 时间:2024/05/16 02:07

题目大意

原题意挺复杂的,我就尽我能力写简化一点吧……
给定一个有m个点的树形结构(1为根),其中保证1n按照编号顺序形成一条链。
然后你要在这棵树上推Gal1号点走到n号点,你走动的规则是从当前点等概率随机选择一个儿子走下去。如果你走进了错误的子树肯定走不到点n嘛,因此我们可以设置最多p个存档点,每当经过一个设置的存档点,你的当前存档点就更新为它。如果走到了一个不是n的叶子,你下一步就可以回到当前存档点。存档点必须设置在1n的链上(否则你就会无限循环)。点1和点n必须设置存档点。
其实推过Gal或者类似游戏或者大概了解这类游戏的很容易理解啦~
现在由你来选择设置存档点的位置,最小化1走到n的期望步数。
本题T组数据。

50pn700,m1500,T5
保证每个编号属于[1,n)的节点至少有两个儿子,至多有三个儿子。


题目分析

栋栋看完Re0之后出题好题。

70%

按照惯例来一波部分分:50pn500
我们将链1...n称为主链。
back(x)表示非主链节点x向下走走回当前存档点的期望步数:

back(x)=1+yson(x)back(y)|son(x)|

s(x)表示主链点x作为存档点,走入所有非主链儿子子树然后回来的期望步数:
s(x)=yson(x)x>nback(y)+1|son(x)|

gox,y表示从主链点x走到主链点y的期望步数:
gox,y=gox,y1+s(y1)+gox,y(|son(y1)|1)|son(y1)|+1|son(y1)|=gox,y1|son(y1)|+s(y1)+|son(y1)|

fi,j表示当前放了i个存档点,最后一个存档点为j,走到n的期望步数。
fi,j=mink=j+1n{fi+1,k+goj,k}

时间复杂度是O(n2p)的。当然想要拿到70分还是要蜜汁卡常。

100%

Algorithm 1

观察goj,k,如果我们将j左移一位,它的增量是多少呢?

goj1,kgoj,k=(goj1,k1|son(k1)|+s(k1)+|son(k1)|+1)(goj,k1|son(k1)|+s(k1)+|son(k1)|+1)=(goj1,k1goj,k1)|son(k1)|=(goj1,k2goj,k2)|son(k1)||son(k2)|...=goj1,ji=jk1|son(i)|=(|son(j1)|+s(j1))i=jk1|son(i)|

可以发现k<l都有goj1,kgoj,k<goj1,lgoj,l
这意味着什么呢?回忆dp式子:
fi,j=mink=j+1n{fi+1,k+goj,k}

显然i这一维可以作为最外层枚举。我们忽略i这一维,看成一个一维的dp
fj=mink=j+1n{gk+goj,k}

gk+goj,k看成一个关于j的函数hk(j),显然hk是一个单调减的函数。
并且k<lhk(j)的增量小于hl(j)。因此在j足够小的时候,会有hk(j)<hl(j)
我们考虑倒着枚举j,然后维护一个决策单调队列。队列内每个相邻两个函数交点依次减小。然后就是很经典的决策单调性套路了。
但是这里的hk(j)是一个很离散的函数,不能直接求出其交点,因此我们可以二分求这个交点。
然后就可以很顺利地转移了。
时间复杂度O(nplog2n)

Algorithm 2

出题人为什么要限制p的下界和儿子个数呢?
如果不做任何限制,稍有常识的人都可以看出来这里的go是一个指数级增长的函数,会导致我们的答案精度不够。
那么出题人在制造了这些限制之后,答案上界显然就会减少。
具体是多少呢?出题人通过构造一个可行解来估计:一种很平均的状况就是主链上每隔np个设置一个存档点,然后将主链上儿子个数都取上界,估计最坏情况。然后就是一个等比数列求和,因为p有下界,所以值不会很大,是个12位数,令其为L,而这还不是最优答案,只是一个可行解。因为我没有打这种算法,这里证明过程从略。
这意味着什么呢?显然go函数的增长是比2的幂要大的,因为儿子数下界就是2,因此如果两个主链上的点距离超过log2L,那么它显然超过了最优解上界L。所以f枚举转移的时候只需要枚举log2L以内的点,大约40多个就好了。
时间复杂度O(nplog2L)

Algorithm 3

更加神奇的算法3,由fanvree在考试时想出来。%%%
这个算法太神奇,我还不是很懂,就在这里口胡一下。听说思路来自2012中国国家集训队命题答辩能量棒。
首先显然我们应该尽量用完所有存档点。
我们二分一个神奇的东西c
这个c有什么用呢?我们用它来限制设置存档点个数。
具体怎么限制呢?就是我们检测当前答案是否合法,依然是对主链做dp,但是不在状态上对存档点个数做出限制,即取消第一维。但是同时,我们每次选取一个存档点,就将f加上c,即希望通过选取存档点需要更多的代价来限制我们尽可能少地选择存档点(不然会超过p)。
dp完后,我们检查最优解使用了多少个存档点,设为tot,如果tot>p那么说明我们限制力度还不够,需要扩大c;如果tot<p,那么说明我们限制过紧,需要减小c。否则tot=p,说明这就是最优解的方案,将答案减去tot×c即可。
但是这样我们有可能找不到会使得tot=pc,这时候我们就选取一个最小的使得tot>pc然后将答案将去tot×c
二分c的上界也是L左右即可。这个太神了,正确性我不会证明,有兴趣的自行查阅资料吧~
时间复杂度O(n2log2L)


代码实现

70%

#include <iostream>#include <cstring>#include <cfloat>#include <cstdio>using namespace std;typedef long double db;const db INF=DBL_MAX/3;const int N=700;const int M=1500;const int P=700;db f[P+5][N+5],go[N+5][N+5],back[M+5];int last[M+5],tov[M+5],next[M+5];bool vis[M+5];int n,m,tot,p,T;inline void insert(int x,int y){tov[++tot]=y,next[tot]=last[x],last[x]=tot;}db dfs(int x){    if (vis[x]) return back[x];    vis[x]=1,back[x]=0;    int cnt=0;    for (int i=last[x];i;i=next[i]) cnt++,back[x]+=dfs(tov[i]);    if (cnt) back[x]/=cnt;    return ++back[x];}void clearall(){    memset(vis,0,sizeof vis);    for (;tot;tot--) tov[tot]=next[tot]=0;    for (int i=1;i<=m;i++) last[i]=0;    memset(f,0,sizeof f);    memset(go,0,sizeof go);    memset(back,0,sizeof back);}int main(){    freopen("memory.in","r",stdin),freopen("memory_brute.out","w",stdout);    for (scanf("%d",&T);T--;clearall())    {        scanf("%d%d%d",&n,&m,&p);        for (int i=2;i<=n;i++) insert(i-1,i);        for (int i=1,x,y;i<=m-n;i++) scanf("%d%d",&x,&y),insert(x,y);        for (int i=n+1;i<=m;i++) if (!vis[i]) dfs(i);        for (int x=1;x<=n;x++)            for (int y=x;y<=n;y++)                if (x==y) go[x][y]=0;                else                {                    int cnt=0;                    go[x][y]=0;                    for (int i=last[y-1],v;i;i=next[i])                        if ((v=tov[i])>n) cnt++,go[x][y]+=back[v]+1;                    go[x][y]+=go[x][y-1]*(cnt+1)+1;                }        for (int i=1;i<=p;i++)            for (int j=1;j<=n;j++)                f[i][j]=INF;        for (int i=2;i<=p;i++) f[i][n]=0;        for (int i=p-1;i>=1;i--)            for (int j=n-1;j>=i;j--)                for (int k=j+1;k<=n;k++) f[i][j]=min(f[i][j],f[i+1][k]+go[j][k]);        printf("%.4lf\n",(double)f[1][1]);    }    fclose(stdin),fclose(stdout);    return 0;}

100

Algorithm 1

单调队列,最快的算法。

#include <iostream>#include <cstring>#include <cfloat>#include <cstdio>using namespace std;typedef long double db;const db INF=DBL_MAX/3;const int N=700;const int M=1500;const int P=700;db f[P+5][N+5],go[N+5][N+5],back[M+5];int last[M+5],tov[M+5],next[M+5];int que[N+5],pt[N+5],head,tail;bool vis[M+5];int n,m,tot,p,T;inline void insert(int x,int y){tov[++tot]=y,next[tot]=last[x],last[x]=tot;}db dfs(int x){    if (vis[x]) return back[x];    vis[x]=1,back[x]=0;    int cnt=0;    for (int i=last[x];i;i=next[i]) cnt++,back[x]+=dfs(tov[i]);    if (cnt) back[x]/=cnt;    return ++back[x];}void clearall(){    memset(vis,0,sizeof vis);    for (;tot;tot--) tov[tot]=next[tot]=0;    for (int i=1;i<=m;i++) last[i]=0;    memset(f,0,sizeof f);    memset(go,0,sizeof go);    memset(back,0,sizeof back);}db calc(int x,int y,int z){return f[x][z]+go[y][z];}int getp(int x,int p,int q){    int ret=0,l=1,r=q,mid;    while (l<=r)    {        mid=l+r>>1;        if (calc(x,mid,p)>=calc(x,mid,q)) l=(ret=mid)+1;        else r=mid-1;    }    return ret;}int main(){    freopen("memory.in","r",stdin),freopen("memory.out","w",stdout);    for (scanf("%d",&T);T--;clearall())    {        scanf("%d%d%d",&n,&m,&p);        for (int i=2;i<=n;i++) insert(i-1,i);        for (int i=1,x,y;i<=m-n;i++) scanf("%d%d",&x,&y),insert(x,y);        for (int i=n+1;i<=m;i++) if (!vis[i]) dfs(i);        for (int x=1;x<=n;x++)            for (int y=x;y<=n;y++)                if (x==y) go[x][y]=0;                else                {                    int cnt=0;                    go[x][y]=0;                    for (int i=last[y-1],v;i;i=next[i])                        if ((v=tov[i])>n) cnt++,go[x][y]+=back[v]+1;                    go[x][y]+=go[x][y-1]*(cnt+1)+1;                }        f[0][n]=f[1][n]=INF;        for (int i=2;i<=p;i++) f[i][n]=0;        for (int i=1;i<n;i++) f[p][i]=INF;        for (int i=p-1;i>=1;i--)        {            head=1,tail=0;            que[++tail]=n;            for (int j=n-1;j>=i;j--)            {                while (head!=tail&&pt[head+1]>=j) head++;                f[i][j]=calc(i+1,j,que[head]);                while (head!=tail&&pt[tail]<=getp(i+1,que[tail],j)) tail--;                que[++tail]=j,pt[tail]=getp(i+1,que[tail-1],que[tail]);            }        }        printf("%.4lf\n",(double)f[1][1]);    }    fclose(stdin),fclose(stdout);    return 0;}

Algorithm 2

上界优化法。随便找个人(@jasonvictoryan)的贴上来。

#include<cstdio> #include<cstring>#include<iostream>#include<algorithm>#define fo(i,a,b) for(int i=a;i<=b;i++)#define fd(i,a,b) for(int i=a;i>=b;i--)#define maxn 705#define maxm 1505#define db double#define mem(a,b) memset(a,b,sizeof(a))#define min(a,b) (((a) < (b)) ? a : b)#define max(a,b) (((a) > (b)) ? a : b)using namespace std;db f[maxn][maxn],g[maxm],s[maxm];int n,m,p,T;int head[maxm],t[maxm],next[maxm],sum;int d[maxm];int tot[maxm];db a[maxn][maxn];void insert(int x,int y){    t[++sum]=y;    next[sum]=head[x];    head[x]=sum;}void dfs(int x){    d[++d[0]]=x;    for(int tmp=head[x];tmp;tmp=next[tmp]) dfs(t[tmp]);}int main(){    freopen("memory.in","r",stdin);    freopen("memory.out","w",stdout);    scanf("%d",&T);    while (T--) {        mem(f,80);        mem(head,0);        sum=0;        d[0]=0;        mem(tot,0);        mem(g,0);        mem(s,0);        ///        scanf("%d%d%d",&n,&m,&p);        fo(i,1,n-1) insert(i,i+1),tot[i]++;        fo(i,1,m-n) {            int x,y;            scanf("%d%d",&x,&y);            insert(x,y);            tot[x]++;        }        dfs(1);        fd(i,m,1) {            int w=d[i];            if (w>n) {                if (head[w]==0) g[w]=1;                else {                    for(int tmp=head[w];tmp;tmp=next[tmp])                        g[w]=g[w]+g[t[tmp]]/tot[w];                    g[w]++;                }            }            else {                if (w==n) continue;                for(int tmp=head[w];tmp;tmp=next[tmp]) {                    if (t[tmp]==w+1) continue;                    s[w]+=g[t[tmp]];                }            }        }        fo(i,1,n) {            a[i][i]=0;            fo(j,i+1,n)                a[i][j]=a[i][j-1] * tot[j-1]+tot[j-1]+s[j-1];        }        f[1][1]=0;        fo(i,1,n) {            fo(j,1,i) {                if (f[i][j]>1e12) continue;                fo(k,i+1,min(i+40,n)) f[k][j+1]=min(f[k][j+1],f[i][j]+a[i][k]);            }        }        printf("%.4lf\n",f[n][p]);    }    return 0;}

Algorithm 3

fanvree的方法&程序。

#include<cstdio>#include<cstring>#include<algorithm>#define fo(i,a,b) for(int i=a;i<=b;i++)#define fd(i,a,b) for(int i=a;i>=b;i--)using namespace std;const int N=1600;const long double eps=0.0000001;const long long inf=10000000000000000ll;int son[N],n,m,p,fa[N];long double f[N],w[N][N],_k[N],_b[N];int pre[N],tmp[N];int check(long double mid){    fo(i,1,n) f[i]=inf,pre[i]=0;    f[1]=0;    fo(i,1,n-1)        fo(j,i+1,n)        if (f[i]+w[i][j]+mid<f[j]) f[j]=f[i]+w[i][j]+mid,pre[j]=i;    int num=0;    for(int now=n;now;now=pre[now]) num++;    return num;}int main(){    freopen("memory.in","r",stdin);freopen("memory.out","w",stdout);    int T;    scanf("%d",&T);    while (T--)    {        scanf("%d%d%d",&n,&m,&p);        memset(son,0,sizeof son);        memset(_k,0,sizeof _k);        memset(_b,0,sizeof _b);        fo(i,1,n) son[i]++;        fo(i,1,m-n)        {            int x,y;            scanf("%d%d",&x,&y);            son[x]++;            fa[y]=x;        }        fd(i,m,n+1)        {            if (son[i]==0) _b[i]=0;            _b[i]++;            _b[fa[i]]+=_b[i]*1.0/son[fa[i]];        }        fd(i,n,1)        {            long double b=0,k=0;            fd(j,i-1,1)            {                long double kk=0,bb=0,p=1.0/son[j];                kk=k*p+(1-p);                bb=p*b+_b[j]+1;                k=kk;                b=bb;                w[j][i]=b/(1-k);                if (w[j][i]>inf) w[j][i]=inf;            }        }        long double l=0,r=0;        r=10000000ll;        long double ans=inf,ans1=inf;        while (l+eps<=r)        {            long double mid=(l+r)/2;            int num=check(mid);            if (num<=p)            {                long double sum=0;                for(int now=n;now!=1;now=pre[now]) sum+=w[pre[now]][now];                if (ans==inf || ans<sum+(num-p)*mid) ans=sum+(num-p)*mid;                if (num==p)                 {                    ans=sum;                    break;                };                r=mid-eps;            } else l=mid+eps;        }        printf("%.4f\n",(double)ans);    }}
0 0