CSU 1811 Tree Intersection(莫队算法)

来源:互联网 发布:嘀哩嘀哩 知乎 编辑:程序博客网 时间:2024/05/22 11:59

题意:给出一棵n个节点的树,每个节点有一种颜色,用Ci表示,问如果删除一条边(x,y),那么剩余两棵树的颜色交集的大小是多少。

 

思路:首先对原来的颜色dfs一遍生成有向树得出颜色序列(可参照 挑战程序竞赛 P330),记录每个节点访问时最先出现和最后出现的位置,那么对于删除(x,y)这条边时,假设x是y子节点,很显然,x最先出现和最后出现的位置形成的闭区间则是以x为根节点的子树的所有节点颜色序列,假设这个区间为L,R,要计算两颗子树交集大小,发现交集大小为 = 以x为根节点的子树的所有颜色总数(A) - 以x为根节点的子树只有的颜色总数(B),要计算A的值,等价于在区间[L,R]寻找不同的数的个数,莫队算法直接过,同样B也可以在求A的同时计算出来。


#include<cstdio>#include<cstring>#include<vector>#include<set>#include<queue>#include<cmath>#include<stack>#include<iostream>#include<algorithm>typedef long long ll;const int INF = 1e9 + 7;const int maxn = 1e5 + 10;using namespace std;struct edge {    int from, to;    void input() { scanf("%d %d", &from, &to); }} e[maxn];struct P {    int bol, l, r, num;    P() {}    P(int b, int l, int r, int n) :        bol(b), l(l), r(r), num(n) {}    bool operator < (P p) const {        if(bol != p.bol) return bol < p.bol;        return r < p.r;    }} qry[maxn];int n, q, num, k, INIT;int id[maxn], vs[maxn * 10]; ///每个节点第一次出现的下标,  颜色的遍历序列int col[maxn], sum[maxn]; ///每个点的颜色,dfs序列点的总数int blk[maxn]; ///每个端点所在的块vector<int> G[maxn];int res[maxn]; ///每个查询的结果int now[maxn]; ///子树现有的颜色总数int idx[maxn], dep[maxn]; ///每个点最后出现的下标 深度void init() {    int b = 1, x = 0;    for(int i = 0; i < maxn; i++) {        G[i].clear();        sum[i] = now[i] = 0;    }    if(INIT) return ;    INIT = 1;    while(x < maxn) {        for(int j = 0; j < 300; j++) {            if(x >= maxn) break;            blk[x] = b; x++;        }        b++;    }}void dfs(int v, int fa, int &k, int d) {    id[v] = idx[v] = k; dep[v] = d;    vs[k++] = col[v];    for(int i = 0; i < G[v].size(); i++) {        int to = G[v][i];        if(to == fa) continue;        dfs(to, v, k, d + 1);        idx[v] = k;        vs[k++] = col[v];    }}int tal, only;///一棵子树中 总的颜色数量-只有的颜色数量=交集的数量void solve(int l, int r, int L, int R, int x) {    while(l > L) {        l--;        int c = vs[l];        now[c]++;        if(now[c] == 1) tal++;        if(now[c] == sum[c]) only++;    }    while(l < L) {        int c = vs[l]; now[c]--;        if(!now[c]) tal--;        if(now[c] == sum[c] - 1) only--;        l++;    }    while(r > R) {        int c = vs[r]; now[c]--;        if(!now[c]) tal--;        if(now[c] == sum[c] - 1) only--;        r--;    }    while(r < R) {        r++;        int c = vs[r];        now[c]++;        if(now[c] == 1) tal++;        if(now[c] == sum[c]) only++;    }    res[x] = tal - only;}int main() {    INIT = 0;    while(scanf("%d", &n) != EOF) {        init();        for(int i = 1; i <= n; i++)            scanf("%d", &col[i]);        for(int i = 0; i < n - 1; i++) {            e[i].input();            G[e[i].from].push_back(e[i].to);            G[e[i].to].push_back(e[i].from);        }        k = 0;        dfs(1, 0, k, 1);        for(int i = 0; i < n - 1; i++) {            int u = e[i].from, v = e[i].to, num = i;            if(dep[u] < dep[v]) swap(u, v);            int L = id[u], R = idx[u];            qry[i] = P(blk[L], L, R, i);        }        for(int i = 0; i < k; i++) sum[vs[i]]++;        sort(qry, qry + n - 1);        int c = vs[0]; now[c]++;        int lasl = 0, lasr = 0;        int nowl, nowr;        tal = 1; only = sum[c] == 1 ? 1 : 0;        for(int i = 0; i < n - 1; i++) {            nowl = qry[i].l; nowr = qry[i].r;            int x = qry[i].num;            solve(lasl, lasr, nowl, nowr, x);            lasl = nowl; lasr = nowr;        }        for(int i = 0; i < n - 1; i++) printf("%d\n", res[i]);    }    return 0;}


阅读全文
0 0
原创粉丝点击