BZOJ 1036 [ZJOI2008]树的统计Count (树链剖分)

来源:互联网 发布:it女神 编辑:程序博客网 时间:2024/05/02 04:53

传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=1036
        BZOJ太可怕了,第一页按照AC数排序,除了A+B,第一个就是这题了。。。。看着题目懵逼了很久,不知道怎么做,搜了一发题解,才发现是树链剖分。然而树链剖分这个名词也只是停留在听说过的程度,于是跟着一个大神的博客看了一下午才看懂,然后照着博客打了一次,但是一直是WA。索性重写了一发,然后一发过了,检查了一下才发现是之前的线段树有个地方写错了。
        以前在写题的时候,感觉在有思路以后,写起来还是挺容易的,按照思路一路写下来一般都没什么问题。但是今天碰到这题,代码又长,思路也是有难度的,写的时候明显感觉力不从心。但是这更让我坚定了继续做BZOJ的想法。
        我感觉树链剖分,就是把一个树的结构按照一个一个链的结构分开,然后按照这个规则来编号,那么每一段重链都是连续的编号,就变成了在一个直线上的处理,就可以套线段树等数据结构来搞。这题属于树链剖分入门题,让我认识到了树链剖分的强大,以及BZOJ的可怕。。。接下来几天找找别的树剖的题来做做,消化一下这个算法,加深理解。

#include <cstdio>#include <cstring>#include <algorithm>#include <cmath>#include <cstdlib>#include <cctype>#include <string>#include <iostream>#include <vector>#include <map>#include <queue>#include <ctime>using namespace std;typedef long long LL;typedef pair<int,int> pii;#define PB push_back#define lson l,m,rt<<1#define rson m+1,r,rt<<1|1#define calm (l+r)>>1const int INF=1e9+7;const int maxn=30010;struct EE{    int next,to;    EE(){}    EE(int to,int next):to(to),next(next){}}edge[maxn*2];int n,Ecnt,tot,head[maxn],val[maxn];int top[maxn],fa[maxn],rev[maxn],id[maxn],num[maxn],son[maxn],deep[maxn];inline void addedge(int a,int b){    edge[Ecnt]=EE(b,head[a]);    head[a]=Ecnt++;}void dfs1(int s,int pre,int d){//求解fa(父节点),num(子树节点数量),deep(深度),son(重儿子)    deep[s]=d;fa[s]=pre;num[s]=1;son[s]=0;    for(int i=head[s];~i;i=edge[i].next){        int t=edge[i].to;        if(t==pre)continue;        dfs1(t,s,d+1);num[s]+=num[t];        if(son[s]==0||num[t]>num[son[s]]){            son[s]=t;        }    }}void dfs2(int s,int rt){//求解top(重链起点),id(在线段树中的编号),rev(与id相对)    top[s]=rt;id[s]=++tot;rev[id[s]]=s;    if(son[s]==0)return;//叶子节点    dfs2(son[s],rt);//递归重儿子--重链的编号一定是连续的    for(int i=head[s];~i;i=edge[i].next){        int t=edge[i].to;        if(t==fa[s]||t==son[s])continue;        dfs2(t,t);//递归轻儿子    }}//SegmentTreestruct node{    int MAX,sum;}tree[maxn<<2];inline void pushup(int rt){    tree[rt].sum=tree[rt<<1].sum+tree[rt<<1|1].sum;    tree[rt].MAX=max(tree[rt<<1].MAX,tree[rt<<1|1].MAX);}void build(int l,int r,int rt){    if(l==r){        tree[rt].MAX=tree[rt].sum=val[rev[l]];        return;    }    int m=calm;    build(lson);build(rson);    pushup(rt);}void update(int x,int v,int l,int r,int rt){    if(l==r){        tree[rt].MAX=tree[rt].sum=v;        return;    }    int m=calm;    if(x<=m)update(x,v,lson);    if(x>m)update(x,v,rson);    pushup(rt);}int querymax(int L,int R,int l,int r,int rt){    if(L<=l&&r<=R){        return tree[rt].MAX;    }    int m=calm,ans=-INF;    if(L<=m)ans=max(ans,querymax(L,R,lson));    if(R>m)ans=max(ans,querymax(L,R,rson));    return ans;}int querysum(int L,int R,int l,int r,int rt){    if(L<=l&&r<=R){        return tree[rt].sum;    }    int m=calm,ans=0;    if(L<=m)ans+=querysum(L,R,lson);    if(R>m)ans+=querysum(L,R,rson);    return ans;}int findmax(int x,int y){    int f1=top[x],f2=top[y],ans=-INF;    while(f1!=f2){//每次选深度大的往上走,直到x和y在同一条重链中        if(deep[f1]<deep[f2]){            swap(f1,f2);swap(x,y);        }        ans=max(ans,querymax(id[f1],id[x],1,n,1));//因为重链的编号一定是连续的        x=fa[f1];//这一段查询完了,x往上移动至重链起点        f1=top[x];//f1一直保持为x节点的top    }    if(deep[x]>deep[y]){        ans=max(ans,querymax(id[y],id[x],1,n,1));    }    else{        ans=max(ans,querymax(id[x],id[y],1,n,1));    }    return ans;}int findsum(int x,int y){//同上    int f1=top[x],f2=top[y],ans=0;    while(f1!=f2){        if(deep[f1]<deep[f2]){            swap(f1,f2);swap(x,y);        }        ans+=querysum(id[f1],id[x],1,n,1);        x=fa[f1];f1=top[x];    }    if(deep[x]>deep[y]){        ans+=querysum(id[y],id[x],1,n,1);    }    else{        ans+=querysum(id[x],id[y],1,n,1);    }    return ans;}int main(){    //freopen("/home/xt/code/acm/input.txt","r",stdin);    scanf("%d",&n);    memset(head,-1,sizeof head);Ecnt=0;    for(int i=1;i<n;i++){        int a,b;scanf("%d%d",&a,&b);        addedge(a,b);addedge(b,a);    }    for(int i=1;i<=n;i++){        scanf("%d",&val[i]);    }    tot=0;    dfs1(1,0,1);dfs2(1,1);    build(1,n,1);    char op[10];int q,a,b;    scanf("%d",&q);    while(q--){        scanf("%s%d%d",op,&a,&b);        if(op[0]=='C'){            update(id[a],b,1,n,1);            val[a]=b;        }        else if(op[1]=='M'){            printf("%d\n",findmax(a,b));        }        else{            printf("%d\n",findsum(a,b));        }    }    //printf("[Run in %.1fs]\n",(double)clock()/CLOCKS_PER_SEC);    return 0;}
0 0
原创粉丝点击