HDU4609 3-idiots

来源:互联网 发布:淘宝修图师招聘 编辑:程序博客网 时间:2024/05/16 04:18

多校那时候我还年轻,没能理解FFT,只知道它能够进行信号在时域和频域上的转换。

现在学习了FFT在多项式乘法上的应用,感觉对FFT的理解更深了一层。

要搞懂FFT,首先要理解卷积。


先抄一下wikipedia

函数 f 与 g 的卷积记作f * g,它是其中一个函数翻转并平移后与另一个函数的乘积的积分,是一个对平移量的函数。

(f * g )(t) = \int f(\tau) g(t - \tau)\, d\tau

积分区间取决于 f 的定义域与 g 的定义域。

对于定义在离散域的函数,卷积定义为

(f  * g)[m] = \sum_n {f[n] g[m - n]}


理解了卷积之后,我们可以把FFT先简单的理解为求卷积的O(nlogn)算法。

然后进入正题,本题中我们已知每一种长度的棒子各有多少个,换个思路想想,我们就知道了一个多项式的各项系数,把这个多项式与其自身做快速傅里叶变换,就可以得到新多项式的各项系数,对于新系数ki',有k[i]' = sigma( k[j] + k[i-j] ),而这里的系数正是由两根棒子加起来的长度的种数,假设用cnt[i]表示和的长度为i的棒子组合有多少种。


求出数组cnt[]之后,这个题目已经做了一半了,下面是计数问题。


(1) 首先去掉cnt里重复的部分,一根棒子不能与自己组合,所以有cnt[a[i] * 2] --

(2) 任意两个棒子组合的顺序我们不需要考虑,而实际上他们算了两次,所以有cnt[i] /=2

(3) 枚举所有棒子,假设这根棒子a[i]是组成三角形其中最长的那根,我们首先求出另外两根的长度和比它的长的总种数,即sigma(cnt[j])   (a[i]<j<=crest*2,crest是最长的棒子长度)

(4) 另外两根中,这根被枚举的棒子不能再出现了,所以要减去n-1

(5)  另外两根中,可能有一根长度大于a[i],另一根小于a[i],所以要减去(n-i) * (i-1)

(6) 另外两根中,可能两根长度都大于a[i],再减去(n-i) * (n-i-1) / 2


计算下所有可能的种数 tot = n * (n-1) * (n-2) / 6

把统计出的答案除以tot就行了。


注意:cnt 和 tot 都有可能超过int的范围,所以应该用long long 



贴代码:

#include <cstdio>#include <cstring>#include <cctype>#include <cstdlib>#include <ctime>#include <climits>#include <cmath>#include <iostream>#include <string>#include <vector>#include <set>#include <map>#include <list>#include <queue>#include <stack>#include <deque>#include <algorithm>using namespace std;typedef long long ll;const double PI = acos(-1.0);const int maxn = 100010;struct Complex{    double r, i;    Complex(double r=0, double i=0) : r(r), i(i) {}    Complex operator + (const Complex &o) {return Complex(r+o.r, i+o.i);}    Complex operator - (const Complex &o) {return Complex(r-o.r, i-o.i);}    Complex operator * (const Complex &o) {return Complex(r*o.r-i*o.i, r*o.i+o.r*i);}};void brc(Complex *y, int len){    for (int i=1, j=len/2; i < len-1; i++)    {        if (i < j) swap(y[i], y[j]);        int k = len/2;        while (j >= k) { j -= k; k /= 2; }        if (j < k) j += k;    }}void fft(Complex *y, int len, int on){    brc(y, len);    for (int i = 2; i <= len; i <<= 1)    {        Complex wn(cos(on*2.0*PI/i), sin(on*2.0*PI/i));        for (int j = 0; j < len; j += i)        {            Complex w(1.0, 0.0);            for (int k = j; k < j+i/2; k++)            {                Complex u = y[k];                Complex t = w*y[k+i/2];                y[k] = u + t;                y[k+i/2] = u - t;                w = w * wn;            }        }    }    if (on == -1)        for (int i = 0; i < len; i++)            y[i].r /= len;}int T,n;int a[maxn];ll cnt[maxn*4], sum[maxn*4], ans, tot;Complex x1[maxn*4];int main(){    scanf("%d", &T);    while (T--)    {        scanf("%d", &n);        memset(cnt, 0, sizeof(cnt));        int len = 1, crest = 0;        for (int i = 1; i <= n; i++)        {            scanf("%d", &a[i]);            if (a[i] > crest) crest = a[i];            cnt[a[i]]++;        }        crest++;        while (len < crest*2) len <<= 1;        for (int i = 0; i < len; i++) x1[i] = Complex(cnt[i], 0.0);        fft(x1, len, 1);        for (int i = 0; i < len; i++) x1[i] = x1[i] * x1[i];        fft(x1, len, -1);        for (int i = 0; i < len; i++) cnt[i] = (ll)(x1[i].r + 0.5);        len = crest * 2;        for (int i = 1; i <= n; i++) cnt[a[i]*2]--;        for (int i = 1; i <= len; i++) cnt[i] /= 2;        sum[0] = 0;        for (int i = 1; i <= len; i++) sum[i] = sum[i-1] + cnt[i];        ans = 0;        for (int i = 1; i <= n; i++)        {            ans += sum[len] - sum[a[i]];            ans -= n - 1;            ans -= (ll)(i - 1) * (n - i);            ans -= (ll)(n - i) * (n - i - 1) / 2LL;        }        tot = (ll)n * (n - 1) * (n - 2) / 6LL;        printf("%.7lf\n", (double)ans / tot);    }return 0;}


原创粉丝点击