Codechef Prime Distance On Tree(点分治+FFT)

来源:互联网 发布:java 成员信息管理 编辑:程序博客网 时间:2024/05/18 01:01

题外话


最近做题发现自己非常SB,总是查一个SB错误查N久,简直绝望啊。。。弱逼为何而战

这次是忘记加long long查了N久。。蛋碎无比

不过好歹是又做出一道cc hard的题了呢,感人肺腑

Description


题意很简单:

一棵树,问多少个二元组u,v,满足u 到 v的路径长度为素数的概率为多少。所有边长度为1

Solution

自从重温了下 楼教的男人八题后,这种关于路径长度的题一看就是个点分治嘛

显然我们可以点分治统计路径长度,但是显然无法用two pointers或者数据结构搞

怎么办呢?还是个很裸的问题,统计完长度直接FFT就可以了

50000O(nlog2n)

裸裸的水题= =

Code

#include <bits/stdc++.h>using namespace std;const int N = 50005;typedef complex<double> CP;const int M = 1 << 20;const double Pi = acos(-1.0);CP a[M], c[M], w[M], temp[M];long long ans, bb[M], sum[N];int n, root, size, tot, cnt, cnt2, po, pr[N], to[N << 1], nxt[N << 1], head[N], sz[N], dis[N], f[N];bool vis[N], check[N];inline int read(int &t) {    int f = 1;char c;    while (c = getchar(), c < '0' || c > '9') if (c == '-') f = -1;    t = c - '0';    while (c = getchar(), c >= '0' && c <= '9') t = t * 10 + c - '0';    t *= f;}void add(int u, int v) {    to[tot] = v, nxt[tot] = head[u], head[u] = tot++;    to[tot] = u, nxt[tot] = head[v], head[v] = tot++;}void fft(CP* p, int deep, int flag) {    if (deep == po) return;    int step = 1 << deep;    fft(p, deep + 1, flag);    fft(p + step, deep + 1, flag);    int num= 1 << (po - deep);    int ss = 0, half = num / 2;    CP a,b;    for (int i = 0; i < half; ++i) {        a = p[ss];        b = p[ss + step];        if (!flag)  b *= w[i << deep];        else b /= w[i << deep];        temp[i] = a + b;        temp[i + half] = a - b;        ss += 2 * step;    }    for (int i = 0; i < num; ++i)   p[i * step] = temp[i];    return;}void getroot(int u, int fa) {    f[u] = 0, sz[u] = 1;    for (int i = head[u], v; ~i; i = nxt[i]) {        v = to[i];        if (v != fa && !vis[v]) {            getroot(v, u);            sz[u] += sz[v];            f[u] = max(f[u], sz[v]);        }    }    f[u] = max(f[u], size - sz[u]);    if (f[u] < f[root]) root = u;}void dfs(int u, int fa, int d) {    dis[cnt++] = d;    for (int i = head[u], v; ~i; i = nxt[i]) {        v = to[i];        if (v != fa && !vis[v]) dfs(v, u, d + 1);    }}long long calc(int u, int d) {    int mx = 0;    long long t = 0;    cnt = 0;    dfs(u, 0, d);    for (int i = 0; i < cnt; ++i)   ++sum[dis[i]], mx = max(mx, dis[i]);    int l = 1;    po = 0;    while ((mx + 1) * 2 > l)    l <<= 1, ++po;    for (int i = 0; i < l; ++i) w[i] = CP(cos(2 * Pi * i / l), sin(2 * Pi * i / l));    for (int i = 0; i <= mx; ++i)   a[i] = CP(sum[i], 0.0);    for (int i = mx + 1; i < l; ++i)    a[i] = CP(0.0, 0.0);    fft(a, 0, 0);    for (int i = 0; i < l; ++i) c[i] = a[i] * a[i];    fft(c, 0, 1);    for (int i = 0; i < l; ++i) bb[i] = (long long)round(c[i].real() / l);    l = 2 * mx;    for (int i = 0; i < cnt; ++i)   --bb[dis[i] << 1];    for (int i = 0; i <= l; ++i)    bb[i] >>= 1;    for (int i = 1; i <= cnt2; ++i) t += bb[pr[i]];    for (int i = 0; i <= mx; ++i)   sum[i] = 0;    for (int i = 0; i <= l; ++i)    bb[i] = 0;    return t;}void gao(int u) {    f[root = 0] = size;    getroot(u, 0);    ans += calc(root, 0);    vis[root] = 1;    for (int i = head[root], v; ~i; i = nxt[i]) {        v = to[i];        if (!vis[v]) {            ans -= calc(v, 1);            size = sz[v];            gao(v);        }    }}void init() {    read(n);    memset(head, -1, sizeof(head));    for (int i = 2; i <= n; ++i) {        if (!check[i])  pr[++cnt2] = i;        for (int j = 1; j <= cnt2 && i * pr[j] <= n; ++j) {            check[i * pr[j]] = 1;            if (i % pr[j] == 0) break;        }    }    for (int i = 1, x, y; i < n; ++i) {        read(x), read(y);        add(x, y);    }    size = n;}int main() {    init();    gao(1);    printf("%.8lf\n", 2.0 * ans / n / (n - 1));    return 0;}
0 0