bzoj 1036

来源:互联网 发布:niconico网络连接失败 编辑:程序博客网 时间:2024/06/16 05:18

树链剖分入门题

#include <cstdio>
#include <cstring>

// Tree with weighted nodes supporting three operations:
//   C... u t : set the weight of node u to t
//   QMAX u v : maximum node weight on the path u..v
//   QSUM u v : sum of node weights on the path u..v
// Implemented with heavy-light decomposition on top of a segment tree.

const int N = 30005;  // capacity of the per-node arrays (nodes are 1-based)
int n;                // number of nodes read from the input

// Adjacency list stored as linked edge records.
// e[i].y is the edge target, e[i].last is the previous edge of the same
// source node; last[x] is the most recently added edge of x, -1 ends a list.
struct qq {
    int y, last;
} e[N * 2];
int num, last[N];

// Adds the directed edge x -> y to the adjacency list.
void init(int x, int y)
{
    num++;
    e[num].y = y;
    e[num].last = last[x];
    last[x] = num;
}

// Per-node decomposition data.
struct qr {
    int son;  // child with the largest subtree (0 if the node is a leaf)
    int fa;   // parent (0 for the root)
    int top;  // first (shallowest) node of the chain this node belongs to
    int tot;  // subtree size
    int ys;   // position of the node in the segment tree
    int dep;  // depth, root has depth 1
} tr[N];

// First DFS: fills fa, dep, tot and son for the subtree rooted at x.
// tr[x].fa and tr[x].dep must already be set by the caller.
void bt_node(int x)
{
    tr[x].tot = 1;
    tr[x].son = 0;
    for (int u = last[x]; u != -1; u = e[u].last) {
        int y = e[u].y;
        if (y == tr[x].fa) continue;
        tr[y].fa = x;
        tr[y].dep = tr[x].dep + 1;
        bt_node(y);
        // tr[0].tot is 0, so the first visited child always becomes son.
        if (tr[tr[x].son].tot < tr[y].tot) tr[x].son = y;
        tr[x].tot += tr[y].tot;
    }
}

int num1;  // number of segment-tree positions handed out so far

// Second DFS: assigns ys and top. The heavy child is visited first so that
// every chain occupies a contiguous range of positions.
void bt_edge(int x, int tp)
{
    tr[x].ys = ++num1;
    tr[x].top = tp;
    if (tr[x].son != 0) bt_edge(tr[x].son, tp);
    for (int u = last[x]; u != -1; u = e[u].last) {
        int y = e[u].y;
        if (y == tr[x].son || y == tr[x].fa) continue;
        bt_edge(y, y);  // a light child starts a new chain
    }
}

// Segment-tree node with explicitly stored child indices.
struct qa {
    int l, r;    // covered position range [l, r]
    int z, z1;   // z: sum over the range, z1: maximum over the range
    int s1, s2;  // indices of the left / right child (0 for a leaf)
} s[N * 2];
int num2;  // number of segment-tree nodes allocated so far

// Allocates the segment-tree node for [l, r] and, recursively, its children.
void bt1(int l, int r)
{
    int a = ++num2;
    s[a].l = l;
    s[a].r = r;
    s[a].s1 = s[a].s2 = 0;
    s[a].z = 0;
    s[a].z1 = 0;
    if (l >= r) return;
    int mid = (l + r) / 2;
    s[a].s1 = num2 + 1;
    bt1(l, mid);
    s[a].s2 = num2 + 1;
    bt1(mid + 1, r);
}

// Builds the decomposition rooted at node 1 and an empty segment tree.
void bt()
{
    tr[1].dep = 1;
    tr[1].fa = 0;
    bt_node(1);
    num1 = 0;
    bt_edge(1, 1);
    num2 = 0;
    bt1(1, num1);
}

int mymax(int x, int y)
{
    return x > y ? x : y;
}

// Sets position x to value z and refreshes sum/max on the way back up.
void change(int now, int x, int z)
{
    if (s[now].l == s[now].r) {
        s[now].z = s[now].z1 = z;
        return;
    }
    int s1 = s[now].s1, s2 = s[now].s2;
    int mid = (s[now].l + s[now].r) / 2;
    if (x <= mid) change(s1, x, z);
    else change(s2, x, z);
    s[now].z = s[s1].z + s[s2].z;
    s[now].z1 = mymax(s[s1].z1, s[s2].z1);
}

// Maximum over positions [l, r]; the range must lie inside node `now`.
int findmax(int now, int l, int r)
{
    if (s[now].l == l && s[now].r == r) return s[now].z1;
    int mid = (s[now].l + s[now].r) / 2;
    int s1 = s[now].s1, s2 = s[now].s2;
    if (r <= mid) return findmax(s1, l, r);
    if (l > mid) return findmax(s2, l, r);
    return mymax(findmax(s1, l, mid), findmax(s2, mid + 1, r));
}

// Sum over positions [l, r]; the range must lie inside node `now`.
int findnum(int now, int l, int r)
{
    if (s[now].l == l && s[now].r == r) return s[now].z;
    int mid = (s[now].l + s[now].r) / 2;
    int s1 = s[now].s1, s2 = s[now].s2;
    if (r <= mid) return findnum(s1, l, r);
    if (l > mid) return findnum(s2, l, r);
    return findnum(s1, l, mid) + findnum(s2, mid + 1, r);
}

// Answers a path query between nodes x and y.
// tf == true  : maximum weight on the path (uses z1)
// tf == false : sum of weights on the path (uses z)
int solve(int x, int y, bool tf)
{
    int fa_x = tr[x].top, fa_y = tr[y].top;
    int ans = tf ? -1000000000 : 0;

    // Lift the endpoint whose chain head is deeper until both endpoints
    // share a chain.
    while (fa_x != fa_y) {
        if (tr[fa_x].dep > tr[fa_y].dep) {
            int tt = x; x = y; y = tt;
            tt = fa_x; fa_x = fa_y; fa_y = tt;
        }
        if (tf) ans = mymax(ans, findmax(1, tr[fa_y].ys, tr[y].ys));
        else ans = ans + findnum(1, tr[fa_y].ys, tr[y].ys);
        y = tr[fa_y].fa;
        fa_y = tr[y].top;
    }

    // Same chain: the shallower node has the smaller position.
    if (tr[x].dep > tr[y].dep) {
        int tt = x; x = y; y = tt;
    }
    if (tf) ans = mymax(ans, findmax(1, tr[x].ys, tr[y].ys));
    else ans = ans + findnum(1, tr[x].ys, tr[y].ys);
    return ans;
}

int main()
{
    num = 0;
    memset(last, -1, sizeof(last));

    // Reject input that is missing or would overflow the fixed-size arrays.
    if (scanf("%d", &n) != 1 || n < 1 || n >= N) return 0;

    for (int u = 1; u < n; u++) {
        int x, y;
        if (scanf("%d%d", &x, &y) != 2) return 0;
        init(x, y);
        init(y, x);
    }
    bt();

    for (int u = 1; u <= n; u++) {
        int a;
        if (scanf("%d", &a) != 1) return 0;
        change(1, tr[u].ys, a);
    }

    int m;
    if (scanf("%d", &m) != 1) return 0;
    while (m--) {
        char ss[10];
        int x, y;
        // %9s keeps the command word inside ss (9 chars + terminator).
        if (scanf("%9s", ss) != 1) break;
        if (scanf("%d%d", &x, &y) != 2) break;
        if (ss[0] == 'C')
            change(1, tr[x].ys, y);
        else if (ss[1] == 'M')
            printf("%d\n", solve(x, y, true));
        else
            printf("%d\n", solve(x, y, false));
    }
    return 0;
}


1 0
原创粉丝点击