bzoj 2243 染色 树链剖分 解题报告

来源:互联网 发布:买家淘宝怎么发买家秀 编辑:程序博客网 时间:2024/06/08 19:50

Description

给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。

Input

第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。

Output

对于每个询问操作,输出一行答案。

Sample Input

6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5

Sample Output

3
1
2

思路

树链剖分+线段树就好了。
第一个dfs找size、deep;第二个dfs递归找重儿子,然后连在一起建立种链。lca用来算两个节点之间的路径,然后就是线段树上的事情了。

代码

好像WA了。。。

#include<iostream>#include<cstdlib>#include<cstdio>#include<cstring>using namespace std;const int N=100000+5;int n,m,num,sz,head[N],deep[N],son[N];int belong[N],pl[N],v[N],ft[N][18];bool vis[N];struct seg{    int l,r,lc,rc,s,tag;};seg tree[4*N];struct edge{    int u,next;};edge ed[2*N];void insert(int u,int v){    ed[++num].u=v;    ed[num].next=head[u];    head[u]=num;    ed[++num].u=u;    ed[num].next=head[v];    head[v]=num;}void dfs1(int x){    vis[x]=son[x]=1;    for (int i=1;i<=17;i++)    {        if (deep[x]<(1<<i)) break;        ft[x][i]=ft[ft[x][i-1]][i-1];    }    for (int i=head[x];i;i=ed[i].next)    {        if (vis[ed[i].u]) continue;        deep[ed[i].u]=deep[x]+1;        ft[ed[i].u][0]=x;        dfs1(ed[i].u);        son[x]+=son[ed[i].u];    }}void dfs2(int x,int chain){    int k=0;    pl[x]=++sz;belong[x]=chain;    for (int i=head[x];i;i=ed[i].next)    if (deep[ed[i].u]>deep[x]&&son[k]<son[ed[i].u]) k=ed[i].u;    if (!k) return ;    dfs2(k,chain);    for (int i=head[x];i;i=ed[i].next)    if (deep[ed[i].u]>deep[x]&&k!=ed[i].u)    dfs2(ed[i].u,ed[i].u);}int lca(int x,int y){    if (deep[x]<deep[y]) swap(x,y);    int t=deep[x]-deep[y];    for (int i=0;i<=17;i++)    if (t&(1<<i)) x=ft[x][i];    for (int i=17;i>=0;i--)    if (ft[x][i]!=ft[y][i]) {x=ft[x][i];y=ft[y][i];}    if (x==y) return x;    return ft[x][0];}void build(int k,int lf,int rt){    tree[k].l=lf;    tree[k].r=rt;    tree[k].s=1;    tree[k].tag=-1;    if (lf==rt) return ;    int mid=(lf+rt)/2;    build(k*2,lf,mid);    build(k*2|1,mid+1,rt);}void pushup(int k){    tree[k].lc=tree[k*2].lc;    tree[k].rc=tree[k*2+1].rc;    if (tree[k*2].rc^tree[k*2+1].lc)    tree[k].s=tree[k*2].s+tree[k*2+1].s-1;}void pushdown(int k){    int tmp=tree[k].tag;    tree[k].tag=-1;    if (tmp==-1||tree[k].l==tree[k].r) return ;    tree[k*2].s=tree[k*2+1].s=1;    tree[k*2].tag=tree[k*2+1].tag=tmp;    tree[k*2].lc=tree[k*2+1].rc=tmp;    tree[k*2+1].lc=tree[k*2+1].rc=tmp;}void change(int k,int x,int y,int c){    pushdown(k);    int lf=tree[k].l,rt=tree[k].r;    if (lf==x&&rt==y)    {        tree[k].lc=tree[k].rc=c;        tree[k].s=1;        tree[k].tag=c;        return ;    }    int mid=(lf+rt)/2;    if (mid>=y) change(k<<1,x,y,c);    else if (mid<x) change(k<<1|1,x,y,c);    else    {        change(k*2,x,mid,c);        change(k*2+1,mid+1,y,c);    }    pushup(k);}int ask(int k,int x,int y){    pushdown(k);    int l=tree[k].l,r=tree[k].r;    if(l==x&&r==y)return tree[k].s;    int mid=(l+r)>>1;    if(mid>=y)return ask(k<<1,x,y);    else if(mid<x)return ask(k<<1|1,x,y);    else    {        int tmp=1;        if(tree[k<<1].rc^tree[k<<1|1].lc)tmp=0;        return ask(k<<1,x,mid)+ask(k<<1|1,mid+1,y)-tmp;    }}int getc(int k,int x){    pushdown(k);    int l=tree[k].l,r=tree[k].r;    if(l==r)return tree[k].lc;    int mid=(l+r)>>1;    if(x<=mid)return getc(k<<1,x);    else return getc(k<<1|1,x);}int solvesum(int x,int f){    int sum=0;    while(belong[x]!=belong[f])    {        sum+=ask(1,pl[belong[x]],pl[x]);        if (getc(1,pl[belong[x]])==getc(1,pl[ft[belong[x]][0]]))        sum--;        x=ft[belong[x]][0];     }    sum+=ask(1,pl[f],pl[x]);    return sum;}void solvechange(int x,int f,int c){    while(belong[x]!=belong[f])    {        change(1,pl[belong[x]],pl[x],c);        x=ft[belong[x]][0];     }    change(1,pl[f],pl[x],c);}void solve(){    int a,b,c;    dfs1(1);    dfs2(1,1);    build(1,1,n);    for(int i=1;i<=n;i++)        change(1,pl[i],pl[i],v[i]);    for(int i=1;i<=m;i++)    {        char ch[10];        scanf("%s",ch);        if(ch[0]=='Q')        {            scanf("%d%d",&a,&b);            int t=lca(a,b);            printf("%d\n",solvesum(a,t)+solvesum(b,t)-1);        }        else        {            scanf("%d%d%d",&a,&b,&c);            int t=lca(a,b);            solvechange(a,t,c);solvechange(b,t,c);        }    }}int main(){    scanf("%d%d",&n,&m);    for(int i=1;i<=n;i++)    scanf("%d",&v[i]);    for(int i=1;i<n;i++)    {        int x,y;        scanf("%d%d",&x,&y);        insert(x,y);    }    solve();    return 0;}