BZOJ2286:消耗战(虚树,树形dp)

来源:互联网 发布:勇士队球员数据 编辑:程序博客网 时间:2024/06/05 04:06

今天本tu生日,学个新算法庆祝一下。学了虚树,碰到每次询问给你一些点点的树就不虚了…

对于一棵树,我们可以在上面用我们学过的算法为所欲为。假设题目有多个询问,每个询问给出了一些点,那我们可以把这些点和及有关系的点拉出来,合并点和边的信息,构出虚树,在虚树上继续为所欲为。

对于与询问有关系的点,就是这些点两两之间的lca。把这些点按欧拉序排序,lca就是区间深度最小的点,所以k个点两两的lca最多k-1个。把这些点拉出来重构一棵树。

具体实现:将点按dfs序排序,用一个栈,维护按深度排序的一条虚树的链,对于新来的点x,找到x与栈顶结点的lca,将lca以下的链的边连好,全部出栈,lca进栈,就搞定了。

题面
题意:给出一棵树,每次给出k个点,让它们全部与1号结点不连通,所删除边的最小边权。

构出虚树后,考虑树形dp即可。对于点x,f[x]为原树x到根路径上的最小边权,g[x]表示虚树中,是x子树内的点都不与1连通的最小代价。
若x是给出的点,g[x]=f[x],否则g[x]=min(f[x],yxg[y])

这里用的树剖lca

#include <iostream>#include <fstream>#include <algorithm>#include <cmath>#include <ctime>#include <cstdio>#include <cstdlib>#include <cstring>using namespace std;#define mmst(a, b) memset(a, b, sizeof(a))#define mmcp(a, b) memcpy(a, b, sizeof(b))typedef long long LL;const int N=1001000;const LL oo=1e18;void read(int &hy){    hy=0;    char cc=getchar();    while(cc<'0'||cc>'9')    cc=getchar();    while(cc>='0'&&cc<='9')    {        hy=(hy<<3)+(hy<<1)+cc-'0';        cc=getchar();    }}int n,m,k,a[N];int val[N],to[N],nex[N],head[N],cnt;int son[N],siz[N],dep[N],fa[N],top[N],tim[N],times;LL mn[N],f[N];int st[N];bool vis[N];void add(int u,int v,int w){    to[++cnt]=v;    val[cnt]=w;    nex[cnt]=head[u];    head[u]=cnt;}void dfs(int x){    siz[x]=1;    tim[x]=++times;    for(int h=head[x];h;h=nex[h])    if(to[h]!=fa[x])    {        fa[to[h]]=x;        dep[to[h]]=dep[x]+1;        mn[to[h]]=min(mn[x],(LL)val[h]);        dfs(to[h]);        siz[x]+=siz[to[h]];        if(siz[to[h]]>siz[son[x]])        son[x]=to[h];    }}void dfs2(int x,int tp){    top[x]=tp;    for(int h=head[x];h;h=nex[h])    if(to[h]!=fa[x]&&to[h]!=son[x])    dfs2(to[h],to[h]);    if(son[x])    dfs2(son[x],tp);}int lca(int x,int y){    while(top[x]!=top[y])    {        if(dep[top[x]]<dep[top[y]])        swap(x,y);        x=fa[top[x]];    }    return dep[x]<dep[y] ? x : y;}bool cmp(int x,int y){    return tim[x]<tim[y];}LL dp(int x){    LL t=0;    for(int h=head[x];h;h=nex[h])    t+=dp(to[h]);    if(vis[x])    f[x]=mn[x];    else    f[x]=min(mn[x],t);    head[x]=0;    return f[x];}void work(){    cnt=0;    read(k);    for(int i=1;i<=k;i++)    read(a[i]),vis[a[i]]=1;    sort(a+1,a+k+1,cmp);    int t=0;    for(int i=1;i<=k;i++)    {        if(!t)        {            st[++t]=a[i];            continue;        }        int lc=lca(st[t],a[i]);        while(lc!=st[t])        {            if(tim[lc]>=tim[st[t-1]])            {                add(lc,st[t--],0);                if(st[t]!=lc)                st[++t]=lc;                break;            }            else            add(st[t-1],st[t],0),t--;        }        st[++t]=a[i];    }    while(t!=1)    add(st[t-1],st[t],0),t--;    dp(st[1]);    for(int i=1;i<=k;i++)    vis[a[i]]=0;    printf("%lld\n",f[st[1]]);}int main(){    cin>>n;    for(int i=1;i<n;i++)    {        int u,v,w;        read(u);        read(v);        read(w);        add(u,v,w);        add(v,u,w);    }    mn[1]=oo;    dfs(1);    dfs2(1,1);    mmst(head,0);    cin>>m;    while(m--)    work();    return 0;}

这里写图片描述