[CSU 1915 John and his farm]树形DP+LCA

来源:互联网 发布:淘宝开放平台人工客服 编辑:程序博客网 时间:2024/06/06 19:10

[CSU 1915 John and his farm]树形DP+LCA

分类:Tree DP LCA

1. 题目链接

[CSU 1915 John and his farm]

2. 题意描述

有一棵N个节点的树,树上每条边长度为1。
现在需要等概率地随机地在两个顶点之间加一条边。
M次询问。
每次查询给定两个顶点u,v。求增加一条边,保证顶点u,v在一个环内的条件下,求环的长度的数学期望。
要求结果保证误差在106以内。
数据范围:(2N,M200000)

3. 解题思路

首先,请见下图。
这里写图片描述
然后,现在的问题就是求sum,siz数组。首先,dfs求出所有节点为根节点的子树的sum,siz。
另外,还需要一个dfs,求出以当前节点为根节点时,整棵树的dep之和,记为all。转移方程是:all[v] = all[u] + (n - 1 - siz[v]) - (siz[v] - 1);

现在就要分两种情况讨论(假设dep[u]<=dep[v]):

  • lca(u, v) !=u, 这个情况比较简单,T(u),T(v) 的sum(T(u)), sum(T(v)), siz(T(u)), siz(T(v)),直接就是sum[u], sum[v], siz[u], siz[v]。
  • lca(u, v)==u, 此时T(v) 的sum(T(v))=sum[v], siz(T(v))=siz[v],但是,siz(T(u)) = n - siz[w], sum(T(u))=all[u] - sum[w] - siz[w]; (顶点w是在从u到v的链上,且是u的儿子)。

看起来比较复杂,但是自己手算理解一下,就很简单了。
这题,还需要注意sum和all 会爆int。

4. 实现代码

#include <queue>#include <stack>#include <ctime>#include <cmath>#include <cctype>#include <cstdio>#include <string>#include <cstring>#include <iomanip>#include <iostream>#include <algorithm>using namespace std;typedef long long LL;typedef long double LB;typedef pair<int, int> PII;typedef pair<LL, LL> PLL;typedef vector<int> VI;const int INF = 0x3f3f3f3f;const LL INFL = 0x3f3f3f3f3f3f3f3fLL;void debug() { cout << endl; }template<typename T, typename ...R> void debug (T f, R ...r) { cout << "[" << f << "]"; debug (r...); }const int MAXN = 1e5 + 5;const int MAXM = 20;int n, m;struct Edge {    int v, next;} edge[MAXN << 1];int head[MAXN], tot;int dep[MAXN], siz[MAXN], fa[MAXN][MAXM];LL all[MAXN], sum[MAXN];int root;void init_edge() {    tot = 0;    memset(head, -1, sizeof(head));}inline void add_edge(int u, int v) {    edge[tot] = Edge {v, head[u]};    head[u] = tot ++;}void dfs(int u, int pre, int d) {    int v;    siz[u] = 1;    dep[u] = d;    sum[u] = 0;    fa[u][0] = pre;    for(int i = head[u]; ~i; i = edge[i].next) {        v = edge[i].v;        if(v == pre) continue;        dfs(v, u, d + 1);        siz[u] += siz[v];        sum[u] += sum[v];        sum[u] += siz[v];    }}void dfs2(int u, int pre) {    int v;    for(int i = head[u]; ~i; i = edge[i].next) {        v = edge[i].v;        if(v == pre) continue;        all[v] = all[u] + (n - 1 - siz[v]) - (siz[v] - 1);        dfs2(v, u);    }}void lca_init() {    for(int j = 1; j < MAXM; ++j) {        for(int i = 1; i <= n; ++i) {            fa[i][j] = fa[fa[i][j - 1]][j - 1];        }    }}int lca(int u, int v) {    while(dep[u] != dep[v]) {        if(dep[u] < dep[v]) swap(u, v);        int d = dep[u] - dep[v];        for(int i = 0; i < MAXM; i++) {            if(d >> i & 1) u = fa[u][i];        }    }    if(u == v) return u;    for(int i = MAXM - 1; i >= 0; i--) {        if(fa[u][i] != fa[v][i]) {            u = fa[u][i];            v = fa[v][i];        }    }    return fa[u][0];}int son(int u, int v) {    while(dep[v] > dep[u] + 1) {        int w = v;        for(int j = 0; j < MAXM; ++j) {            if(dep[fa[v][j]] < dep[u] + 1) break;            w = fa[v][j];        }        v = w;    }    return v;}int main() {#ifdef ___LOCAL_WONZY___    freopen ("input.txt", "r", stdin);#endif // ___LOCAL_WONZY___    int u, v, w;    while(~scanf("%d %d", &n, &m)) {        init_edge();        for(int i = 1; i <= n - 1; ++i) {            scanf("%d %d", &u, &v);            add_edge(u, v);            add_edge(v, u);        }        dfs(root = 1, 0, 0);        all[root] = sum[root];        dfs2(root, 0);        lca_init();//        for(int i = 1; i <= n; ++i) {//            printf("[%d: sum=%d siz=%d all=%d]\n", i, sum[i], siz[i], all[i]);//        }        while(m --) {            scanf("%d %d", &u, &v);            if(dep[u] > dep[v]) swap(u, v);            w = lca(u, v);            int dist, sizu, sizv;            LL sumu, sumv;            double ans;            if(w != u) {                /** 有lca **/                dist = dep[u] + dep[v] - 2 * dep[w];                sizu = siz[u];                sumu = sum[u];                sizv = siz[v];                sumv = sum[v];            } else {                /**一条链**/                dist = dep[v] - dep[u];                w = son(u, v);                sizu = n - siz[w];                sumu = all[u] - sum[w] - siz[w];                sizv = siz[v];                sumv = sum[v];            }            ans = 1.0 + dist + 1.0 * ((LL)sizu * sumv + (LL)sizv * sumu) / ((LL)sizu * sizv);            printf("%.8f\n", ans);        }    }#ifdef ___LOCAL_WONZY___    cout << "Time elapsed: " << 1.0 * clock() / CLOCKS_PER_SEC * 1000 << " ms." << endl;#endif // ___LOCAL_WONZY___    return 0;}
1 0
原创粉丝点击