BZOJ2286 消耗战

来源:互联网 发布:淘宝砗磲 编辑:程序博客网 时间:2024/05/22 02:30

题目大意

有一棵n个点的树,每条边有边权。有m次询问,每次给定k个关键点,问能切断根(1号点)到所有关键点的最小代价是多少?
n<=250000,m<=500000,k<=500000

Solution

可以发现,每次询问时只有关键点和关键点之间的LCA是有用的,知道了这些点,就能计算出答案。而且可以证明,这些点的总点数小于2k,总的复杂度可以变成O(k)。
用一个栈,就能建出这棵树。代码如下:

int top=1;stack[top]=a[1];for (int i=2;i<=cnt;i++)//cnt为要建树的节点数{    while (top&&deep[stack[top]]>deep[lca(stack[top],a[i])])        top--;    if (stack[top]) add(stack[top],a[i],0);    stack[++top]=a[i];}//a为要建树的所有节点。

建树的复杂度是klogn的。
建完这棵树后,在树上跑一次DP就可以。这个DP应该很显然了。

代码

#include<cstdio>#include<algorithm>#include<cstring>using namespace std;typedef long long ll;const ll INF=1LL<<60;int head[1000010],num,ti=0,dfn[1000010],next[2000010],vet[2000010],vel[2000010],flag[1000010],stack[1000010];ll dp[1000010],dis[1000010];int fa[1000010][22],deep[1000010],a[1000010],x[1000010];bool cmp(int x,int y){    return dfn[x]<dfn[y];}void add(int u,int v,int s){    next[++num]=head[u];    head[u]=num;    vet[num]=v;    vel[num]=s;}void dfs(int u){    dfn[u]=++ti;    for (int i=head[u];i;i=next[i])    {        int v=vet[i];        if (v!=fa[u][0])        {            fa[v][0]=u,deep[v]=deep[u]+1,dis[v]=min((ll)vel[i],dis[u]);            dfs(v);        }    }}int lca(int u,int v){    if (deep[u]<deep[v]) swap(u,v);    for (int i=20;i>=0;i--)        if (fa[u][i]&&deep[fa[u][i]]>=deep[v]) u=fa[u][i];    if (u==v) return u;    for (int i=20;i>=0;i--)        if (fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];    return fa[u][0];}void DP(int u,int fa){    if (flag[u])    {        dp[u]=dis[u];        return;    }    dp[u]=dis[u];    ll sum=0;    int Flag=0;    for (int i=head[u];i;i=next[i])    {        int v=vet[i];        if (v!=fa)        {            DP(v,u);            sum+=dp[v];            Flag=1;        }    }    if (Flag) dp[u]=min(dp[u],sum);}int main(){    int n;    scanf("%d",&n);    for (int i=1;i<n;i++)    {        int u,v,c;        scanf("%d%d%d",&u,&v,&c);        add(u,v,c);        add(v,u,c);    }    deep[1]=1;    dis[1]=INF;    dfs(1);    for (int i=1;i<=20;i++)        for (int j=1;j<=n;j++)            fa[j][i]=fa[fa[j][i-1]][i-1];    int m;    memset(head,0,sizeof(head));    scanf("%d",&m);    while (m--)    {        int K;        scanf("%d",&K);        num=0;        for (int i=1;i<=K;i++)        {            scanf("%d",&a[i]);            x[i]=a[i];            flag[x[i]]=1;        }        int cnt=K;        a[++cnt]=1;        sort(a+1,a+1+cnt,cmp);        cnt=unique(a+1,a+1+cnt)-a-1;        int tmp=cnt;        for (int i=1;i<tmp;i++)            a[++cnt]=lca(a[i],a[i+1]);        sort(a+1,a+1+cnt,cmp);        cnt=unique(a+1,a+1+cnt)-a-1;        int top=1;        stack[top]=a[1];        for (int i=2;i<=cnt;i++)        {            while (top&&deep[stack[top]]>deep[lca(stack[top],a[i])])                top--;            if (stack[top]) add(stack[top],a[i],0);            stack[++top]=a[i];        }        DP(1,0);        printf("%lld\n",dp[1]);        for (int i=1;i<=K;i++) flag[x[i]]=0;        for (int i=1;i<=cnt;i++) head[a[i]]=dp[a[i]]=0;    }     return 0;}