POJ 1741 Tree——点分治

来源:互联网 发布:收购淘宝店铺的网站 编辑:程序博客网 时间:2024/05/16 14:09

题意:给定一棵树,求合法点对的个数,合法点对定义为:两点之间的距离不超过k

思路:

1求树的重心

2以树的重心为根,求此次划分中其余的点到根节点的距离

3计算此次划分的合法点对

4减去在同一子树中的点对(因为这些点对在后面的划分中会重复统计),然后断开根节点,对子树重复1的过程

复杂度为n(logn)^2

点分治的更多介绍参考09年国家集训队论文

#include <cstdio>#include <cstring>#include <iostream>#include <algorithm>using namespace std;const int maxn = 1e4+5;const int INF = 0x3f3f3f3f;bool vis[maxn];int n, k, ans, tot, head[maxn], sz[maxn], dp[maxn], root, dis[maxn], cnt, SIZE;struct Edge { int to, val, next; }edge[maxn<<1];void init() {    ans = tot = 0;    memset(head, -1, sizeof(head));    memset(vis, false, sizeof(vis));}void addedge(int u, int v, int val) {    edge[tot].to = v; edge[tot].val = val; edge[tot].next = head[u];    head[u] = tot++;}void getroot(int u, int pre) {    sz[u] = 1, dp[u] = 0;    for (int i = head[u]; ~i; i = edge[i].next) {        int v = edge[i].to;        if (vis[v] || v == pre) continue;        getroot(v, u);        sz[u] += sz[v];        dp[u] = max(dp[u], sz[v]);    }    dp[u] = max(dp[u], SIZE-sz[u]);    if (dp[u] < dp[root]) root = u;}void getdis(int u, int pre, int d) {    dis[++cnt] = d;    sz[u] = 1;    for (int i = head[u]; ~i; i = edge[i].next) {        int v = edge[i].to, val = edge[i].val;        if (vis[v] || v == pre) continue;        getdis(v, u, d + val);        sz[u] += sz[v];    }}int getans(int u, int d) {    cnt = 0;    int sum = 0;    getdis(u, -1, d);    sort(dis+1, dis+1+cnt);    int i = 1, j = cnt;    while (i < j) {        while (dis[i]+dis[j] > k && i < j) j--;        sum += j - i; i++;    }    return sum;}void solve(int u) {    dp[0] = INF, root = 0;    getroot(u, -1);    ans += getans(root, 0);    vis[root] = true;    for (int i = head[root]; ~i; i = edge[i].next) {        int v = edge[i].to, val = edge[i].val;        if (vis[v]) continue;        ans -= getans(v, val);        SIZE = sz[v];        solve(v);    }}int main() {    while (~scanf("%d %d", &n, &k) && (n+k)) {        int u, v, val;        init();        for (int i = 1; i <= n-1; i++) {            scanf("%d %d %d", &u, &v, &val);            addedge(u, v, val); addedge(v, u, val);        }        SIZE = n;        solve(1);        printf("%d\n", ans);    }    return 0;}