HUD 6035 Colorful Tree dfs序||树形dp

来源:互联网 发布:公路基础数据库系统 编辑:程序博客网 时间:2024/06/10 19:14

传送门:HDU6035

题意:给出一颗树,每个节点有不同的颜色,定义树上路径的长度为一条路径上不同颜色的数量,问所有路径的总长度是多少。

思路:这是17年多校第一场的1003,当时看完题还以为是分治什么的,没什么清晰地思路,赛后给的官方题解是:


单独考虑每一种颜色,答案就是对于每种颜色至少经过一次这种的路径条数之和。反过来思考只需要求有多少条路径没有经过这种颜色即可。直接做可以采用虚树的思想(不用真正建出来),对每种颜色的点按照 dfs 序列排个序,就能求出这些点把原来的树划分成的块的大小。这个过程实际上可以直接一次 dfs 求出。


大致是用dfs序瞎搞,然而还是不明白,怼着标程看了半天才稍微有点眉目,然后去网上一搜,发现dalao的博客都说是树形dp。总的来说大致有以下两种解法:

不管哪种解法,首先都是先假设每种颜色都被每条路径经过,然后在慢慢处理减掉那些不经过某种颜色的路径。

解法一:

单独考虑每种颜色,假设将i颜色的点全部去掉后,原树就被分成了很多颗子树,这些子树内部的任何一条路径都是不经过颜色i的,因此要减去,设sum[i]表示以颜色i为根节点的子树的大小(节点数),在dfs过程中动态维护sum[i],假设当前处理到点v,v的颜色为col[v],v的子树大小为son[v],那么有:

  • 不经过col[v]的路径的两端x,y一定在v的同侧
  • 对于v的祖先u,若col[v]==col[u],则在u上计算时整棵子树v都要被排除。
还有就是要注意往sum[col[u]]上加son[u]的时候,如果该子树在递归过程中内部已经有一部分被加到sum[col[u]]中了,那么要注意去重。
代码:
#include<bits/stdc++.h>#define ll long long#define pb push_back#define fi first#define se second#define pi acos(-1)#define inf 0x3f3f3f3f#define lson l,mid,rt<<1#define rson mid+1,r,rt<<1|1#define rep(i,x,n) for(int i=x;i<n;i++)#define per(i,n,x) for(int i=n;i>=x;i--)using namespace std;typedef pair<int,int>P;const int MAXN=200010;int gcd(int a,int b){return b?gcd(b,a%b):a;}int col[MAXN], son[MAXN], sum[MAXN];vector<int> g[MAXN];ll ans;void dfs(int u, int fa){son[u] = 1;int tmp = sum[col[u]], tot = 0;for(int i = 0; i < g[u].size(); i++){int v = g[u][i];if(v == fa) continue;dfs(v, u);son[u] += son[v];ll t = sum[col[u]] - tmp;tot += t;t = son[v] - t;ans -= t * (t - 1) / 2;tmp = sum[col[u]];}sum[col[u]] += son[u] - tot;// tot 是在子树递归中已经加到sum[col[u]]里的子树的大小,这里减去是为了避免重复 }int main(){int n, u, v, kase = 1;while(cin >> n){for(int i = 1; i <= n; i++)scanf("%d", col + i), g[i].clear(), sum[i] = 0;for(int i = 1; i < n; i++){scanf("%d %d", &u, &v);g[u].pb(v);g[v].pb(u);}ans = 1ll * n * n * (n - 1) / 2; dfs(1, -1);for(int i = 1; i <= n; i++){ll t = n - sum[i];ans -= t * (t - 1) / 2;}printf("Case #%d: %lld\n", kase++, ans);} return 0;}


解法二:
记录每种颜色对应的的顶点,然后将顶点按dfs序标号,然后再对于每一种颜色处理其对应的顶点对应的子树,处理过程大体还是按照解法一的两个要点去做。
代码(标程):
#include <bits/stdc++.h>using namespace std;typedef long long LL;const int N = 200005;int n , ca;vector<int> e[N] , c[N];int L[N] , R[N] , s[N] , f[N];void dfs(int x , int fa , int &&ncnt) {    L[x] = ++ ncnt;    s[x] = 1 , f[x] = fa;    for (auto &y : e[x]) {        if (y != fa) {            dfs(y , x , move(ncnt));            s[x] += s[y];        }    }    R[x] = ncnt;}bool cmp(const int& x , const int& y) {    return L[x] < L[y];}void work() {    for (int i = 0 ; i <= n ; ++ i) {        c[i].clear();        e[i].clear();    }    for (int i = 1 ; i <= n ; ++ i) {        int x;        scanf("%d" , &x);        c[x].push_back(i);    }    for (int i = 1 ; i < n ; ++ i) {        int x , y;        scanf("%d%d" , &x , &y);        e[x].push_back(y);        e[y].push_back(x);    }    e[0].push_back(1);    dfs(0 , 0 , 0);    LL res = (LL)n * n * (n - 1) / 2;    for (int i = 1 ; i <= n ; ++ i) {        if (c[i].empty()) {            res -= (LL)n * (n - 1) / 2;            continue;        }        c[i].push_back(0);        sort(c[i].begin() , c[i].end() , cmp);        for (auto &x : c[i]) {            for (auto &y : e[x]) {                if (y == f[x])                    continue;                int size = s[y];                int k = L[y];                while (1) {                    L[n + 1] = k;                    auto it = lower_bound(c[i].begin() , c[i].end() , n + 1 , cmp);//二分是找出c[i]中使得L[*it]                    if (it == c[i].end() || L[*it] > R[y]) {//大于等于L[n+1]的第一个*it                        break;                    }                    size -= s[*it];                    k = R[*it] + 1;                }                //printf("%d] %d %d : %d\n" , i , x , y , size);                res -= (LL)size * (size - 1) / 2;            }        }    }    printf("Case #%d: %lld\n" , ++ ca , res);}int main() {    while (~scanf("%d" , &n)) {        work();    }    return 0;}