HDU 6035 colorful tree

来源:互联网 发布:组策略 禁止软件运行 编辑:程序博客网 时间:2024/05/29 17:03

题目链接

HDU 6035 colorful tree

分析

参考

首先,我们单独考虑每种颜色,对于颜色 c  来说,他对答案的贡献就是包含这种颜色的路径条数,反过来想就是 n(n1)/2,怎么求呢?我们以颜色 c 来对树分块,那么 就是用总路径条数 n(n1)/2 剪掉 每一块的路径数目.不过我们只需要一次dfs就可以求出

代码

#include<bits/stdc++.h>using namespace std;#define pb push_back#define mkp make_pair#define fi first#define se second#define ll long long#define M 1000000007#define all(a) a.begin(), a.end()const int maxn = 200100;int n, ca, c[maxn], dfn[maxn], tot;ll sz[maxn], ans;vector<int> ed[maxn];stack<int> vec[maxn];void dfs(int t, int fa){    dfn[t] = ++tot;    sz[t] = 1;    for(auto v : ed[t]){        if(v == fa) continue;        dfs(v, t);        sz[t] += sz[v];        int tmp = sz[v];        while(!vec[c[t]].empty() && dfn[vec[c[t]].top()] > dfn[t]){            tmp -= sz[vec[c[t]].top()];            vec[c[t]].pop();        }        ans -= (ll)tmp * (tmp - 1) / 2;    }    vec[c[t]].push(t);}int main(){    while(~scanf("%d", &n)){        for(int i = 1; i <= n; ++i) scanf("%d", c + i);        for(int i = 1; i < n; ++i){            int u, v;            scanf("%d%d", &u, &v);            ed[u].pb(v);            ed[v].pb(u);        }        tot = 0;        ans = (ll)n * n * (n - 1) / 2;        dfs(1, 0);        for(int i = 1; i <= n; ++i){            ed[i].clear();            int tmp = n;            while(!vec[i].empty()){                tmp -= sz[vec[i].top()];                vec[i].pop();            }            ans -= (ll)tmp * (tmp - 1) / 2;        }        printf("Case #%d: %lld\n", ++ca, ans);    }    return 0;}

dfs剪掉的是每种颜色分块后第一个顶点到叶子顶点构成的树中的块,具体来说是用颜色栈vec 记录每种颜色的访问顺序,剪掉到访问顶点u 的最近且和u 颜色相同的点,for循环减去的是每种颜色分块后第一个顶点到根上部分.

不过这份代码会超内存,因为颜色栈 vec 每次计算更新内存太大

AC code

g改进一下我们用 sum 来存每一块的累加值,用pre来记录访问 子节点 以前的变化值,再用访问后sun值来剪掉这个pre就是离他最近子块的节点数目,用总的子树节点数剪掉这个值就是当前块的节点数目.

#include<bits/stdc++.h>#define pb push_back#define mp make_pair#define PI acos(-1)#define fi first#define se second#define INF 0x3f3f3f3f#define INF64 0x3f3f3f3f3f3f3f3f#define random(a,b) ((a)+rand()%((b)-(a)+1))#define ms(x,v) memset((x),(v),sizeof(x))using namespace std;const int MOD = 1e9+7;const double eps = 1e-8;typedef long long LL;typedef long double DB;typedef pair<int,int> PII;const int maxn = 2e5+10;const int MAX_V = 1e5+10;int c[maxn];std::vector<int> G[maxn];int sz[maxn];int sum[maxn];LL ans;void dfs(int u,int fa) {    sz[u] = 1;    LL pre = sum[c[u]];    LL add = 0;    for(auto v : G[u]){        if(v == fa)continue;        dfs(v,u);         sz[u] += sz[v];         LL tmp = sz[v];        // while (!vec[c[u]].empty() && dfn[vec[c[u]].top()] > dfn[v]) {        //     tmp -= sz[vec[c[u]].top()];        //     vec[c[u]].pop();        // }        // ans -= tmp*(tmp-1)/2;        tmp -= (sum[c[u]] - pre);        ans -= tmp * (tmp - 1) /2;        add += tmp;        pre  = sum[c[u]];    }    sum[c[u]] += add+1;}int main() {    std::ios::sync_with_stdio(false);    std::cin.tie(0);    int n;    int kase =0;    c[0] = 0;    while (cin>>n) {        for(int i=1 ; i<=n ; ++i)cin>>c[i];        for(int i=1 ; i<n ; ++i){            int u,v;cin>>u>>v;            G[u].pb(v);G[v].pb(u);        }        ans = (LL)n*n*(n-1)/2;        ms(sum,0);        dfs(1,0);        for(int i=1 ; i<=n ; ++i){            LL tmp = n - sum[i];            ans -= tmp*(tmp-1) /2;        }        for(int i=1 ; i<=n ; ++i)G[i].clear();        std::cout << "Case #" << ++kase << ": "<< ans<< '\n';    }    return 0;}