CodeChef Arithmetic Progressions (分块FFT)

来源:互联网 发布:淘宝无锡电动车怎么样 编辑:程序博客网 时间:2024/06/07 17:35
题意:给出 A 1 ,A 2 ,...,A N 统计满足:
▶ i < j < k
▶ A i + A k = 2A j
的 (i,j,k) 数量。

(N ≤ 30000,A i ≤ 10 5 )

分块,把长度为n的序列分成sqrt(n)块长度为sqrt(n)的序列,然后遍历每一个块分三种情况:

1)三个都在同一个块里面:

暴力枚举后两个,每次维护前面的数的个数,复杂度O(sqrt(n)*n)

2)两个在同一块里面:

暴力枚举块中的两个,维护块前的数的个数和块后的数的个数,复杂度O(sqrt(n)*n)

3)一个在块中,一个在块前,一个在块后

块前块后做fft,枚举块中的数,复杂度O(nlgn*sqrt(n)).

#include <iostream>#include <cstdio>#include <cstring>#include <algorithm>#include <cmath>using namespace std;#define pi acos (-1)#define maxn 151111struct plex {    double x, y;    plex (double _x = 0.0, double _y = 0.0) : x (_x), y (_y) {}    plex operator + (const plex &a) const {        return plex (x+a.x, y+a.y);    }    plex operator - (const plex &a) const {        return plex (x-a.x, y-a.y);    }    plex operator * (const plex &a) const {        return plex (x*a.x-y*a.y, x*a.y+y*a.x);    }};void change (plex y[] , int len) {    for (int i = 1 , j = len / 2 ; i < len -1 ; i ++) {        if (i < j) swap(y[i] , y[j]);        int k = len / 2;        while (j >= k) {            j -= k;            k /= 2;        }        if(j < k) j += k;    }}void fft(plex y[],int len,int on){    change(y,len);    for(int h = 2; h <= len; h <<= 1)    {        plex wn(cos(on*2*pi/h),sin(on*2*pi/h));        for(int j = 0;j < len;j+=h)        {            plex w(1,0);            for(int k = j;k < j+h/2;k++)            {                plex u = y[k];                plex t = w*y[k+h/2];                y[k] = u+t;                y[k+h/2] = u-t;                w = w*wn;            }        }    }    if(on == -1)        for(int i = 0;i < len;i++)            y[i].x /= len;}int num[maxn], num2[maxn];long long sum[maxn];int a[maxn];plex x[maxn], y[maxn];int n;int main () {    //freopen ("in.txt", "r", stdin);    scanf ("%d", &n);    for (int i = 0; i < n; i++) {        scanf ("%d", &a[i]);    }    long long ans = 0;    int block = sqrt (n), len = n/block;    if (len*block != n)        block++;    //1 都在当前块    for (int t = 0; t < block; t++) {        memset (num, 0, sizeof num);        for (int i = min ((t+1)*len-1, n-1); i >= t*len; i--) {            for (int j = t*len; j < i; j++)                num[a[j]]++;            for (int j = i-1; j >= t*len; j--) {                num[a[j]]--;                int cur = 2*a[j]-a[i];                if (cur > 0)                    ans += num[cur];            }        }    }    //两个在当前块    for (int t = 0; t < block; t++) {        memset (num, 0, sizeof num);        //第三个在前面的块中        for (int i = 0; i < t*len; i++)            num[a[i]]++;        for (int i = t*len; i < n && i < (t+1)*len; i++) {            for (int j = i+1; j < n && j < (t+1)*len; j++) {                int cur = a[i]*2-a[j];                if (cur > 0)                    ans += num[cur];            }        }        //第三个在后面的块中        memset (num, 0, sizeof num);        for (int i = (t+1)*len; i < n; i++)            num[a[i]]++;        for (int i = t*len; i < n && i < (t+1)*len; i++) {            for (int j = i+1; j < n && j < (t+1)*len; j++) {                int cur = a[j]*2-a[i];                if (cur > 0)                    ans += num[cur];            }        }    }    //只有一个在当前块中 一个在前面 一个在后面    for (int t = 0; t < block; t++) {        int cnt1 = 0, cnt2 = 0, Max = 0;        memset (num, 0, sizeof num);        for (int i = 0; i < t*len; i++) {            num[a[i]]++;            Max = max (Max, a[i]);        }        memset (num2, 0, sizeof num2);        for (int i = (t+1)*len; i < n; i++) {            num2[a[i]]++;            Max = max (Max, a[i]);        }        int l = 1;        Max++;        while (l < 2*Max)            l <<= 1;        for (int i = 0; i < l; i++) {            x[i] = plex (num[i], 0);        }        for (int i = 0; i < l; i++) {            y[i] = plex (num2[i], 0);        }        fft (x, l, 1);        fft (y, l, 1);        for (int i = 0; i < l; i++)            x[i] = x[i]*y[i];        fft (x, l, -1);        memset (sum, 0, sizeof sum);        for (int i = 1; i < l; i++) {            sum[i] = (long long) (x[i].x+0.5);        }        for (int i = t*len; i < (t+1)*len && i < n; i++) {            ans += sum[2*a[i]];        }    }    printf ("%lld\n", ans);    return 0;}


0 0
原创粉丝点击