bzoj1036: [ZJOI2008]树的统计Count

来源:互联网 发布:linux sigpipe 编辑:程序博客网 时间:2024/06/06 04:07

传送门

啊,树剖。

本人第一道树剖。

其实就是树链剖分后跑线段树。时间复杂度 $O(n\log n + q\log^2 n)$:每次路径查询最多经过 $O(\log n)$ 条重链,每条链上做一次 $O(\log n)$ 的线段树区间查询。

#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <utility>

// BZOJ 1036 [ZJOI2008] tree path statistics.
// Input: n, then n-1 undirected edges, then the n node weights, then q
// queries "op x y":
//   op[0] == 'C'  -> set the weight of node x to y          (CHANGE)
//   op[1] == 'M'  -> print the max weight on the path x..y  (QMAX)
//   otherwise     -> print the weight sum on the path x..y  (QSUM)
// Path endpoints are included. Nodes are numbered 1..n and node 1 is the root.
//
// Method: heavy-light decomposition. Each heavy chain gets a contiguous range
// of positions in a segment tree. A path query visits O(log n) chains and does
// one O(log n) range query per chain.

namespace {

constexpr int kMaxN = 30005;  // arrays are indexed 1..n, so n must be < kMaxN

struct Edge {
  int to;    // neighbouring node
  int next;  // index of the next edge leaving the same node (0 = end of list)
};

struct SegNode {
  int l, r;       // inclusive range of HLD positions covered by this node
  int ma;         // maximum weight in [l, r]
  long long sum;  // sum of weights in [l, r]; 64-bit to rule out overflow
};

Edge e[kMaxN * 2];     // every undirected edge is stored in both directions
SegNode t[kMaxN * 4];  // segment tree over positions 1..n, 1-based heap layout
int head[kMaxN];       // first edge index leaving each node (0 = none)
int v[kMaxN];          // current node weights
int fa[kMaxN];         // parent of each node (0 for the root)
int dep[kMaxN];        // depth of each node (root = 0)
// Subtree sizes. Intentionally not named `size`: under `using namespace std`
// that name is ambiguous with std::size in C++17 and fails to compile.
int siz[kMaxN];
int top[kMaxN];  // head (shallowest node) of the heavy chain containing each node
int pos[kMaxN];  // segment-tree position assigned to each node
int tot;         // number of edge slots used in e[]
int n;           // number of nodes
int sz;          // number of positions handed out so far by dfs2

// Appends the directed edge x -> y to x's adjacency list.
void ins(int x, int y) {
  e[++tot].to = y;
  e[tot].next = head[x];
  head[x] = tot;
}

// Reads the tree and the initial weights.
// Returns false on malformed input: a missing number, n outside [1, kMaxN),
// or an edge endpoint outside [1, n].
bool init() {
  if (std::scanf("%d", &n) != 1 || n < 1 || n >= kMaxN) return false;
  for (int i = 1; i < n; i++) {
    int x, y;
    if (std::scanf("%d%d", &x, &y) != 2) return false;
    if (x < 1 || x > n || y < 1 || y > n) return false;
    ins(x, y);
    ins(y, x);
  }
  for (int i = 1; i <= n; i++)
    if (std::scanf("%d", &v[i]) != 1) return false;
  return true;
}

// First pass: records the parent, depth and subtree size of every node in
// x's subtree. Recursion depth equals the tree height (up to n).
// NOTE(review): this assumes the default stack can hold that many frames.
void dfs1(int x, int f, int depth) {
  dep[x] = depth;
  fa[x] = f;
  siz[x] = 1;
  for (int i = head[x]; i; i = e[i].next) {
    const int y = e[i].to;
    if (y == f) continue;
    dfs1(y, x, depth + 1);
    siz[x] += siz[y];
  }
}

// Second pass: hands out consecutive positions, always descending into the
// heavy child (the one with the largest subtree) first. As a result, each
// heavy chain occupies the contiguous range [pos[top], pos[bottom]].
// chain_top is the head of the chain that x belongs to.
void dfs2(int x, int chain_top) {
  pos[x] = ++sz;
  top[x] = chain_top;
  int heavy = 0;  // siz[0] is always 0, so any real child beats it
  for (int i = head[x]; i; i = e[i].next) {
    const int y = e[i].to;
    if (y != fa[x] && siz[y] > siz[heavy]) heavy = y;
  }
  if (!heavy) return;  // x is a leaf
  dfs2(heavy, chain_top);
  for (int i = head[x]; i; i = e[i].next) {
    const int y = e[i].to;
    if (y != fa[x] && y != heavy) dfs2(y, y);  // a light child starts a new chain
  }
}

// Sets up the range bounds of segment-tree node k, which covers [l, r].
void build(int k, int l, int r) {
  t[k].l = l;
  t[k].r = r;
  if (l == r) return;
  const int mid = (l + r) / 2;
  build(k * 2, l, mid);
  build(k * 2 + 1, mid + 1, r);
}

// Point assignment: sets position p to val, then refreshes every ancestor.
void change(int k, int p, int val) {
  if (t[k].l == t[k].r) {
    t[k].sum = val;
    t[k].ma = val;
    return;
  }
  const int mid = (t[k].l + t[k].r) / 2;
  if (p <= mid)
    change(k * 2, p, val);
  else
    change(k * 2 + 1, p, val);
  t[k].sum = t[k * 2].sum + t[k * 2 + 1].sum;
  t[k].ma = std::max(t[k * 2].ma, t[k * 2 + 1].ma);
}

// Sum over positions [x, y]. Requires t[k].l <= x <= y <= t[k].r.
long long asksum(int k, int x, int y) {
  const int l = t[k].l, r = t[k].r;
  if (l == x && r == y) return t[k].sum;
  const int mid = (l + r) / 2;
  if (y <= mid) return asksum(k * 2, x, y);
  if (x > mid) return asksum(k * 2 + 1, x, y);
  return asksum(k * 2, x, mid) + asksum(k * 2 + 1, mid + 1, y);
}

// Maximum over positions [x, y]. Requires t[k].l <= x <= y <= t[k].r.
int askmax(int k, int x, int y) {
  const int l = t[k].l, r = t[k].r;
  if (l == x && r == y) return t[k].ma;
  const int mid = (l + r) / 2;
  if (y <= mid) return askmax(k * 2, x, y);
  if (x > mid) return askmax(k * 2 + 1, x, y);
  return std::max(askmax(k * 2, x, mid), askmax(k * 2 + 1, mid + 1, y));
}

// Sum of weights on the path x..y, endpoints included.
// Each step lifts whichever endpoint has the deeper chain head past that head.
// Once both endpoints share a chain, the remaining stretch is one range.
long long solvesum(int x, int y) {
  long long sum = 0;
  while (top[x] != top[y]) {
    if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
    sum += asksum(1, pos[top[x]], pos[x]);
    x = fa[top[x]];
  }
  if (pos[x] > pos[y]) std::swap(x, y);
  return sum + asksum(1, pos[x], pos[y]);
}

// Maximum weight on the path x..y, endpoints included.
// Same chain climbing as solvesum.
int solvemax(int x, int y) {
  int ma = INT_MIN;  // neutral element for max; the final query always runs
  while (top[x] != top[y]) {
    if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
    ma = std::max(ma, askmax(1, pos[top[x]], pos[x]));
    x = fa[top[x]];
  }
  if (pos[x] > pos[y]) std::swap(x, y);
  return std::max(ma, askmax(1, pos[x], pos[y]));
}

// Loads the initial weights into the segment tree, then answers queries.
// Stops early if the query count is missing or a query line cannot be read.
void solve() {
  build(1, 1, n);
  for (int i = 1; i <= n; i++) change(1, pos[i], v[i]);
  int q;
  if (std::scanf("%d", &q) != 1) return;
  char op[10];
  int x, y;
  // %9s leaves room for the terminating NUL in op[10].
  while (q-- > 0 && std::scanf("%9s%d%d", op, &x, &y) == 3) {
    if (op[0] == 'C') {
      v[x] = y;
      change(1, pos[x], y);
    } else if (op[1] == 'M') {
      std::printf("%d\n", solvemax(x, y));
    } else {
      std::printf("%lld\n", solvesum(x, y));
    }
  }
}

}  // namespace

int main() {
  if (!init()) return 1;
  dfs1(1, 0, 0);  // root is node 1; parent 0 means "none"
  dfs2(1, 1);
  solve();
  return 0;
}


2 0
原创粉丝点击