BZOJ 1036: [ZJOI2008]树的统计Count 树链剖分+线段树

来源:互联网 发布:xp桌面修复软件 编辑:程序博客网 时间:2024/06/05 00:13

1036: [ZJOI2008]树的统计Count

Time Limit: 10 Sec Memory Limit: 162 MB
Submit: 17419 Solved: 7111
[Submit][Status][Discuss]
Description

  一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

Input

  输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

Output

  对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

Sample Input

4

1 2

2 3

4 1

4 2 1 3

12

QMAX 3 4

QMAX 3 3

QMAX 3 2

QMAX 2 3

QSUM 3 4

QSUM 2 1

CHANGE 1 5

QMAX 3 4

CHANGE 3 6

QMAX 3 4

QMAX 2 4

QSUM 3 4
Sample Output

4

1

2

2

10

6

5

6

5

16

一个像板一样的东西。以前打过板这次可以当个整理了23333之前还傻逼地以为要求LCA
傻逼long long(不用开long long) 但求max的初始值要小一点(负数)

//一定要强迫自己不要用st表 #include<cstdio>#include<cstring>#include<algorithm>#define ms(x,y) memset(x,y,sizeof(x))using namespace std;const int N = 100010;const int INF = 0x73f3f3f;int n,m;int a[N];struct node{    int pre,v;}edge[N];int num=0;int head[N];void addedge(int from,int to){    num++;    edge[num].pre=head[from];    edge[num].v=to;    head[from]=num;}int indx=0;int dep[N];int in[N],out[N],seq[N],son[N],siz[N],fa[N];void dfs1(int u,int f,int d){    siz[u]=1,fa[u]=f,dep[u]=d;    for(int i=head[u];i;i=edge[i].pre){        int v=edge[i].v;        if(v==f) continue;        dfs1(v,u,d+1);        siz[u]+=siz[v];        if(son[u]==-1||siz[v]>siz[son[u]]) son[u]=v;    }}int top[N];int al[N];void dfs2(int u,int tp){    indx++;    in[u]=out[u]=indx,seq[indx]=u;    al[in[u]]=a[u];    top[u]=tp;    if(son[u]==-1) return ;    dfs2(son[u],tp);    for(int i=head[u];i;i=edge[i].pre){        int v=edge[i].v;        if(v==son[u]||v==fa[u]) continue;        dfs2(v,v);    }    out[u]=indx;}inline int Max(int a,int b){    return a>b?a:b;}struct Node{    int sum;    int mmax,flag;    int l,r;}t[N];void build(int root,int l,int r){    t[root].l=l,t[root].r=r;    if(l==r){        t[root].mmax=t[root].sum=al[l];        return ;    }    int mid=(l+r)>>1;    build(root<<1,l,mid);    build(root<<1|1,mid+1,r);    t[root].sum=t[root<<1].sum+t[root<<1|1].sum;    t[root].mmax=Max(t[root<<1].mmax,t[root<<1|1].mmax);}/*int getlca(int u,int v){    while(1){        if(top[u]==top[v]){            if(dep[u]<dep[v]) return u;            return v;        }        else if(dep[top[u]]>=dep[top[v]]) u=fa[top[u]];        else v=fa[top[v]];    }}*/void modify(int root,int pos,int delta){    int l=t[root].l,r=t[root].r;    if(l==r&&l==pos){        t[root].sum=t[root].mmax=delta;        return ;    }    int mid=(l+r)>>1;    if(pos<=mid) modify(root<<1,pos,delta);    else modify(root<<1|1,pos,delta);    t[root].sum=t[root<<1].sum+t[root<<1|1].sum;    t[root].mmax=Max(t[root<<1].mmax,t[root<<1|1].mmax);}int getmax(int root,int pos,int val){    int l=t[root].l,r=t[root].r;    if(pos<=l&&val>=r) return t[root].mmax;    int mid=(l+r)>>1;    int ans=-INF;    if(pos<=mid) ans=Max(ans,getmax(root<<1,pos,val));    if(val>mid) ans=Max(ans,getmax(root<<1|1,pos,val));    t[root].mmax=Max(t[root<<1].mmax,t[root<<1|1].mmax);    return ans;}int getmax(int u,int v){    int f1=top[u],f2=top[v];    int mmax=-INF;    while(f1!=f2){        if(dep[f1]<dep[f2]) swap(f1,f2),swap(u,v);        mmax=Max(mmax,getmax(1,in[f1],in[u]));        u=fa[f1];f1=top[u];    }    if(dep[u]<dep[v]) swap(u,v);    mmax=Max(mmax,getmax(1,in[v],in[u]));    return mmax;}int getsum(int root,int pos,int val){    int l=t[root].l,r=t[root].r;    if(pos<=l&&val>=r) return t[root].sum;    int mid=(l+r)>>1;    int ans=0;    if(pos<=mid) ans+=getsum(root<<1,pos,val);    if(val>mid) ans+=getsum(root<<1|1,pos,val);    t[root].sum=t[root<<1].sum+t[root<<1|1].sum;    return ans;}int getsum(int u,int v){    int f1=top[u],f2=top[v];    int sum=0;    while(f1!=f2){        if(dep[f1]<dep[f2]) swap(f1,f2),swap(u,v);        sum+=getsum(1,in[f1],in[u]);        u=fa[f1],f1=top[u];    }    if(dep[u]>dep[v]) swap(u,v);    sum+=getsum(1,in[u],in[v]);    return sum;}int main(){    ms(son,-1);    scanf("%d",&n);    for(register int i=1;i<n;i++){        int u,v;        scanf("%d%d",&u,&v);        addedge(u,v);addedge(v,u);    }    for(register int i=1;i<=n;i++) scanf("%d",&a[i]);       dfs1(1,-1,0);dfs2(1,1);    build(1,1,n);    scanf("%d",&m);    while(m--){        char s[10];        scanf("%s",s);        if(s[0]=='C'){            int u,delta;            scanf("%d%d",&u,&delta);            modify(1,in[u],delta);        }        else if(s[0]=='Q'&&s[1]=='M'){            int u,v;            scanf("%d%d",&u,&v);            printf("%d\n",getmax(u,v));        }        else{            int u,v;            scanf("%d%d",&u,&v);            int ans=getsum(u,v);            printf("%d\n",ans);        }    }    return 0;}