csu1811 Tree Intersection(线段树合并)★ ★ ★

来源:互联网 发布:游戏辅助软件制作 编辑:程序博客网 时间:2024/06/07 15:52

题意:对于每条边,把这条边删了,树分成了两个集合,求这两个集合中共同的颜色数量。

对于节点u,就看u节点的子树中,有多少种颜色没有到达这种颜色的上限,就是对u所对应的边的答案


我们用线段树合并来维护,代码写起来比较麻烦。

之所以可以用线段树合并,是因为有一个结论:如果初始的时候只有一个叶子的线段树有n个,那么最后合并成一个线段树,总复杂度只有O(nlogn)

#include <map>#include <set>#include <cmath>#include <ctime>#include <stack>#include <queue>#include <cstdio>#include <cctype>#include <bitset>#include <string>#include <vector>#include <cstring>#include <iostream>#include <algorithm>#include <functional>#define fuck(x) cout<<"["<<x<<"]";#define FIN freopen("input.txt","r",stdin);#define FOUT freopen("output.txt","w+",stdout);//#pragma comment(linker, "/STACK:102400000,102400000")using namespace std;typedef long long LL;typedef pair<int, int>PII; const int MX = 1e5 + 5;const int MD = 2e6 + 5; struct Edge {    int v, nxt;} E[MX << 1];int Head[MX], erear;void edge_init() {    erear = 0;    memset(Head, -1, sizeof(Head));}void edge_add(int u, int v) {    E[erear].v = v;    E[erear].nxt = Head[u];    Head[u] = erear++;} struct Node {    int l, r;    int val, sum;} A[MD];int n, sz, id[MX], C[MX], root[MX];int ans[MX], cnt[MX]; void push_up(int rt) {    A[rt].sum = A[A[rt].l].sum + A[A[rt].r].sum;}int build(int c, int l, int r) {    int rt = ++sz;    A[rt].l = A[rt].r = 0;    A[rt].sum = 0;    if(l == r) {        A[rt].val = 1;        A[rt].sum = (A[rt].val != cnt[l]);        return rt;    }    int m = (l + r) >> 1;    if(c <= m) A[rt].l = build(c, l, m);    else A[rt].r = build(c, m + 1, r);    push_up(rt);    return rt;}void merge(int &rt1, int rt2, int l, int r) {    if(!rt1 || !rt2) {        if(!rt1) rt1 = rt2;        return;    }    if(l == r) {        A[rt1].val += A[rt2].val;        A[rt1].sum = (A[rt1].val != cnt[l]);        return;    }    int m = (l + r) >> 1;    merge(A[rt1].l, A[rt2].l, l, m);    merge(A[rt1].r, A[rt2].r, m + 1, r);    push_up(rt1);}void DFS(int u, int f, int e) {    int doc = 0;    root[u] = build(C[u], 1, n);    for(int i = Head[u]; ~i; i = E[i].nxt) {        int v = E[i].v;        if(v == f) continue;        DFS(v, u, i);        merge(root[u], root[v], 1, n);    }    if(u != 1) {        int id = e / 2 + 1;        ans[id] = A[root[u]].sum;    }} int main() {    // FIN;    while(~scanf("%d", &n)) {        edge_init(); sz = 0;        memset(cnt, 0, sizeof(cnt));         for(int i = 1; i <= n; i++) {            scanf("%d", &C[i]);            cnt[C[i]]++;        }        for(int i = 1; i <= n - 1; i++) {            int u, v;            scanf("%d%d", &u, &v);            edge_add(u, v);            edge_add(v, u);        }        DFS(1, -1, -1);        for(int i = 1; i <= n - 1; i++) {            printf("%d\n", ans[i]);        }    }    return 0;}


0 0