bzoj 3319: 黑白树 (并查集)

来源:互联网 发布:怎样做淘宝客网站 编辑:程序博客网 时间:2024/05/17 03:46

题目描述

传送门

题目大意:给定一棵树,边的颜色为黑或白,初始时全部为白色。维护两个操作:
1.查询u到根路径上的第一条黑色边的标号。
2.将u到v 路径上的所有边的颜色设为黑色。
Notice:这棵树的根节点为1

题解

先将所有操作正着进行一遍,将所有的黑边相邻的点按照关系合并,就是一个集合中的代表元素一定是深度最小的点。
然后找出所有自始至终都是白色的边,以及每条边变黑的时间。将白边用并查集合并
倒着做所有的操作,对于染黑操作如果我们撤销相当于染白,将是所有在当前操作中变黑的边的两端用并查集合并,可以直接遍历路径,做法与合并黑边时类似。(也可以按照每条边变黑的时间排序,然后直接合并每一条边,这样就不用遍历路径了)
对于每次的查询操作直接找集合的代表元素,代表元素与其父节点之间的边就是答案。

代码

#include<iostream>#include<cstdio>#include<algorithm>#include<cstring>#include<cmath>#define N 1000003using namespace std;int tot,point[N],nxt[N*2],v[N*2],c[N*2];int fa[N],belong[N],size[N],son[N],q[N];int n,m,mark[N],pos[N],sz,deep[N],f[N],pd[N],pd1[N],ans[N];struct data{    int opt,x,y;}e[N],p[N];int read()   {      char ch = getchar();      for ( ; ch > '9' || ch < '0'; ch = getchar());      int tmp = 0;      for ( ; '0' <= ch && ch <= '9'; ch = getchar())        tmp = tmp * 10 + int(ch) - 48;      return tmp;   }   void add(int x,int y,int num){    tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=num;    tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x; c[tot]=num;}void dfs(int x,int f){    deep[x]=deep[f]+1;    for (int i=point[x];i;i=nxt[i]) {        if (v[i]==f) continue;        int t=c[i];        e[t].x=x; e[t].y=v[i];        fa[v[i]]=x;        dfs(v[i],x);        mark[v[i]]=c[i];    }}int find(int x){    if (f[x]==x) return x;    f[x]=find(f[x]);    return f[x];}void change(int x,int y,int opt){    x=find(x); y=find(y);    while (x!=y) {        if (deep[x]<deep[y]) swap(x,y);        if (!pd[x]) f[x]=f[fa[x]],pd[x]=opt;        x=f[x];    }}void solve(int x,int y,int opt){    x=find(x); y=find(y);    while (x!=y) {        if (deep[x]<deep[y]) swap(x,y);        if (pd[x]==opt) f[x]=f[fa[x]];        x=fa[x];    }}int main(){    freopen("a.in","r",stdin);    freopen("my.out","w",stdout);    scanf("%d%d",&n,&m);    for (int i=1;i<n;i++) {        int x,y; x=read(); y=read();        add(x,y,i);    }    dfs(1,0);    for (int i=1;i<=n;i++) f[i]=i;    for (int i=1;i<=m;i++) {        scanf("%d%d",&p[i].opt,&p[i].x);        if (p[i].opt==2) scanf("%d",&p[i].y),change(p[i].x,p[i].y,i);    }    for (int i=1;i<=n;i++) f[i]=i;    for (int i=1;i<=n;i++) pd1[i]=pd[i];    for (int i=2;i<=n;i++)     if (!pd1[i]) {        int t=mark[i];        int r1=find(e[t].x); int r2=find(e[t].y);        f[r2]=r1;     }     int cnt=0;    for (int i=m;i>=1;i--) {        if (p[i].opt==1) {            int r1=find(p[i].x);            ans[++cnt]=mark[r1];        }        else solve(p[i].x,p[i].y,i);    }    for (int i=cnt;i>=1;i--) printf("%d\n",ans[i]);}
原创粉丝点击