HDU 6035 Colorful Tree 树上统计 联通块

来源:互联网 发布:旋转矩阵计算旋转角度 编辑:程序博客网 时间:2024/06/07 06:19

地址

http://acm.hdu.edu.cn/showproblem.php?pid=6035

题意

树上每个节点有一种颜色ci(1<=ci<=n)2<=n<=2105,每个点对的路径值为这个路径上的颜色种数,求树上所有路径(n(n1)/2条路径)的长度和

思路

直接计算每个颜色对答案的贡献

自己的代码就是这样写的,感觉自己的思路说出来不是很好理解 = =,大家看不懂的话可以看第二种思路。

首先如果树上每个点的颜色都不同的话,那么答案就是每个点经过路径的数量之和,可是在问题中有的点的颜色会相同,所以有的路径不能走。

考虑用dfs的方法解题,dfs中我们遇到的第一个点,因为之前没有遇到其它点,所以显然这个点对答案的贡献就是经过它的所有路径(各个子树大小相乘),接着往下dfs,如果其子树中节点的颜色与第一个点的颜色不同,那么自然再统计一次就好(各个子树大小相乘,其中子树包括其父节点),但是有可能这个节点的颜色和第一个节点的颜色相同,那么这个节点向上只有一个联通块可以走,但是向下仍然可以访问其所有子节点。所以此题的关键是维护好与一个节点颜色相同的父节点之间的联通块大小,这个联通块的定义大致就是,一个父节点的子树大小,减去所有与其颜色相同的子节点子树大小。一父节点下方的联通块大小首先是其子树的大小,然后遇到一个相同颜色的子树,就减去这个子树的大小即可。考虑到某一个颜色的第一个节点没有与其颜色相同的父节点,所以加一个虚根。

联通块
如图,点1是父节点,点2对应的父节点联通块就是粉色部分,点3对应的联通块是橙色部分(因为点4还没有访问,所以访问了这里没有关系)点4是蓝色部分

自己的统计方法是先算父节点联通块与自身代表子树的路径条数,然后算自身子树的路径条数。

计算出答案上限,减去非法值

很多网上的代码和标程都是这样写的。

首先,颜色不会超过n种,那么我们假设每条路径上都有n种颜色,共有n(n1)/2条路径,答案就为ans=nn(n1)/2。我们从中减去非法的,或者说重复计算的值即为答案。

如何统计重复计算的值呢?利用前面说的联通块,一个联通块里面是不会出现与这个父节点颜色相同的节点的,设某个联通块的大小为siz,遍历某个颜色的所有节点,这些节点对应的联通块内部的路径都不会出现这个颜色,所以答案要减去siz(siz1)。统计结束后答案就被计算出了。

标程用到了dfs序的方法来统计联通块的大小(用dfs序来判断节点之间的父子关系)。

恕本人精力有限,不提供代码了QAQ

代码

#include <cstdio>#include <cstring>#include <iostream>#include <algorithm>#include <vector>using namespace std;#define PB push_back#define MS(x, y) memset(x, y, sizeof(x));typedef long long LL;const int MAXN = 2e5 + 5;int n;int val[MAXN], col[MAXN], cnt[MAXN << 1];int siz[MAXN];LL ans;vector<int> edges[MAXN];void dfs1(int u, int fa) {  siz[u] = 1;  for (int v: edges[u]) {    if (v == fa) continue;    dfs1(v, u);    siz[u] += siz[v];  }}void dfs2(int u, int fa) {  LL sum = siz[u] - 1;  int pre = col[val[u]];  col[val[u]] = u;  cnt[pre] -= siz[u];  ans += 1LL * (cnt[pre] + 1) * siz[u] - 1;  for (int v: edges[u]) {    if (v == fa) continue;    sum -= siz[v];    ans += siz[v] * sum;    cnt[u] = siz[v];    dfs2(v, u);  }  col[val[u]] = pre;}int main() {  int kase = 0;  while (~scanf("%d", &n)) {    for (int i = 1; i <= n; ++i) {      scanf("%d", val + i);      col[i] = i + n;      cnt[i] = 0;      cnt[i + n] = n;      edges[i].clear();    }    int u, v;    for (int i = 1; i < n; ++i) {      scanf("%d%d", &u, &v);      edges[u].PB(v);      edges[v].PB(u);    }    ans = 0;    dfs1(1, 0);    siz[0] = n + 1;    dfs2(1, 0);    printf("Case #%d: %I64d\n", ++kase, ans);  }}