[BZOJ3991][SDOI2015]寻宝游戏(dfs序+lca+set)

来源:互联网 发布:电脑发短信软件 编辑:程序博客网 时间:2024/03/29 19:34

题目描述

传送门

题解

答案其实就是将所有的点按照dfs序排序然后相邻求lca以及长度加和
奥还有第一个和最后一个求lca以及长度加和
用set维护一下。。。

代码

#include<algorithm>#include<iostream>#include<cstring>#include<cstdio>#include<cmath>#include<set>using namespace std;#define N 200005#define sz 17int n,m,dfs_clock;int tot,point[N],nxt[N],v[N];long long c[N];int deep[N],dfn[N],pt[N],f[N][sz+3],flag[N];set <int> s;long long h[N],ans;void add(int x,int y,long long z){    ++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=z;}void dfs(int x,int fa){    dfn[x]=++dfs_clock;pt[dfs_clock]=x;deep[x]=deep[fa]+1;    for (int i=1;i<sz;++i) f[x][i]=f[f[x][i-1]][i-1];    for (int i=point[x];i;i=nxt[i])        if (v[i]!=fa)        {            h[v[i]]=h[x]+c[i];            f[v[i]][0]=x;            dfs(v[i],x);        }}int lca(int x,int y){    if (deep[x]<deep[y]) swap(x,y);    int k=deep[x]-deep[y];    for (int i=0;i<sz;++i)        if ((k>>i)&1) x=f[x][i];    if (x==y) return x;    for (int i=sz-1;i>=0;--i)        if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];    return f[x][0];}long long calc(int x,int y){    int r=lca(x,y);    return h[x]+h[y]-2*h[r];}int main(){    scanf("%d%d",&n,&m);    for (int i=1;i<n;++i)    {        int x,y;long long z;        scanf("%d%d%lld",&x,&y,&z);        add(x,y,z),add(y,x,z);    }    dfs(1,0);    s.insert(0);s.insert(n+1);    for (int i=1;i<=m;++i)    {        int x;scanf("%d",&x);        if (!flag[x])        {            s.insert(dfn[x]);            int pre=*--s.find(dfn[x]);            int nxt=*++s.find(dfn[x]);            if (pre>=1) ans+=calc(pt[pre],x);            if (nxt<=n) ans+=calc(pt[nxt],x);            if (pre>=1&&nxt<=n) ans-=calc(pt[pre],pt[nxt]);        }        else        {            int pre=*--s.find(dfn[x]);            int nxt=*++s.find(dfn[x]);            if (pre>=1) ans-=calc(pt[pre],x);            if (nxt<=n) ans-=calc(pt[nxt],x);            if (pre>=1&&nxt<=n) ans+=calc(pt[pre],pt[nxt]);            s.erase(dfn[x]);        }        int fir=*++s.find(0);        int last=*--s.find(n+1);        long long add;        if (fir<1||last>n) add=0;        else        {            int r=lca(pt[fir],pt[last]);            add=h[pt[fir]]+h[pt[last]]-2*h[r];        }        printf("%lld\n",ans+add);        flag[x]^=1;    }}
0 0
原创粉丝点击