HDU 5977 树的点分治 + 状态压缩 + 枚举子集

来源:互联网 发布:如何做淘宝海外买手 编辑:程序博客网 时间:2024/06/05 14:10

题意:

给一棵节点数为n,节点种类为k的无根树,问其中有多少种不同的简单路径,可以满足路径上经过所有k种类型的点?(a->b与b->a算作两条路径,起点与终点也可以相同)

思路:

现场赛的时候k的大小是7,当时看到这题也没多想就树形dp水过了。现在重现赛k改成了10,这时候用树形dp,无论是时间还是空间复杂度都很爆炸。后来听说这题的正解是树分治,于是就学习了一波,然后重新来做这道题,关于树分治的内容在我上一篇博客中详细介绍了,链接:http://blog.csdn.net/bahuia/article/details/53066373

这题运用树的点分治算法,与POJ-1741的区别就在于后者是求长度小于等于k的路径数目,而这道题是求经过所有种类点的路径,状压一下,也就是求状态为(1<<k)-1的路径数目,其实本质上是一样的,只是从路径权值的加和变成了路径状态的或运算。

这题的难点在于cal()函数,也就是将问题转化成了已知x个数a1,a2,...ax,求其中有多少点对的或运算的和为(1<<k)-1,因为这些都是二进制状态,并没有直接的大小关系,所以POJ-1741那题排序的算法就不能用了,这里我们必须另外想一个O(nlogn)级别的算法。

我们枚举每一个其中的每一个数x,想找到数组中有多少数和x的或运算的和为(1<<k)-1,也就是找到可以包含((1<<k)-1)^x的数,这时候可以反向考虑,先枚举x的子集,然后再与(1<<k)-1进行异或运算,就可以找到了所有的情况。
具体细节看代码。

代码:

#include <bits/stdc++.h>using namespace std;typedef long long ll;const int MAXN = 5e5 + 10;int n, k, Max, root;ll ans;vector <int> tree[MAXN];vector <int> sta;int sz[MAXN], maxv[MAXN], a[MAXN];ll Hash[1200];bool vis[MAXN];void init() {    memset(vis, false, sizeof(vis));    for (int i = 1; i <= n; i++) tree[i].clear();}void dfs_size(int u, int pre) {    sz[u] = 1; maxv[u] = 0;    int cnt = tree[u].size();    for (int i = 0; i < cnt; i++) {        int v = tree[u][i];        if (v == pre || vis[v]) continue;        dfs_size(v, u);        sz[u] += sz[v];        maxv[u] = max(maxv[u], sz[v]);    }}void dfs_root(int r, int u, int pre) {    maxv[u] = max(maxv[u], sz[r] - sz[u]);    if (Max > maxv[u]) {        Max = maxv[u];        root = u;    }    int cnt = tree[u].size();    for (int i = 0; i < cnt; i++) {        int v = tree[u][i];        if (v == pre || vis[v]) continue;        dfs_root(r, v, u);    }}void dfs_sta(int u, int pre, int s) {    sta.push_back(s);    int cnt = tree[u].size();    for (int i = 0; i < cnt; i++) {        int v = tree[u][i];        if (v == pre || vis[v]) continue;        dfs_sta(v, u, s | (1 << a[v]));    }}ll cal(int u, int s) {    ll res = 0;    sta.clear(); dfs_sta(u, -1, s);    memset(Hash, 0, sizeof(Hash));    int cnt = sta.size();    for (int i = 0; i < cnt; i++) Hash[sta[i]]++;    for (int i = 0; i < cnt; i++) {        Hash[sta[i]]--;        res += Hash[(1 << k) - 1];        for (int s0 = sta[i]; s0; s0 = (s0 - 1) & sta[i])            res += Hash[((1 << k) - 1) ^ s0];        Hash[sta[i]]++;    }    return res;}void dfs(int u) {    Max = n;    dfs_size(u, -1); dfs_root(u, u, -1);    ans += cal(root, (1 << a[root]));    vis[root] = true;    int cnt = tree[root].size(), rt = root;    for (int i = 0; i < cnt; i++) {        int v = tree[rt][i];        if (vis[v]) continue;        ans -= cal(v, (1 << a[rt]) | (1 << a[v]));        dfs(v);    }}int main() {    //freopen("in", "r", stdin);    while (scanf("%d%d", &n, &k) == 2) {        init();        for (int i = 1; i <= n; i++) {            scanf("%d", &a[i]);            --a[i];        }        for (int i = 1; i < n; i++) {            int u, v;            scanf("%d%d", &u, &v);            tree[u].push_back(v);            tree[v].push_back(u);        }        if (k == 1) {            printf("%d\n", n * n);            continue;        }        ans = 0;        dfs(1);        printf("%lld\n", ans);    }    return 0;}


0 0