HDOJ 1402. A * B Problem Plus (FFT快速傅里叶变换)

来源:互联网 发布:mac os x10.10 iso 编辑:程序博客网 时间:2024/06/10 10:13

Problem Description
Calculate A * B.

Input
Each line will contain two integers A and B. Process to end of file.
Note: the length of each integer will not exceed 50000.

Output
For each case, output A * B in one line.

这是一道套用FFT模板的题目,因为刚学习了FFT算法知识,就拿来练手。

对于两个多项式相乘问题,a0+a1x1+a2x2+...+an1xn1b0+b1x1+b2x2+...+bn1xn1,FFT可以通过求值和插值的方法,获得O(nlgn)的时间复杂度。

求值:对于函数A(x)=a0+a1x1+a2x2+...+an1xn1,求得n个不同的点值(xi,A(xi))

对函数公式进行变换

A(x)=a0+a1x1+a2x2+...+an1xn1=(a0+a2x2+a4x4+...+an2xn2)+(a1x1+a3x3+a5x5+...+an1xn1)=(a0+a2x2+a4x4+...+an2xn2)+x(a1+a3x2+a5x4+...+an1xn2)(assume that n is even)

试想一下,如果我们所要求的A(x1),A(x2)x21=x22,那么对于上式中的两个括号中的值,只需要计算一次即可,然后分别进行一次乘法和一次加法。对于求n个A(xi)的值,如果我们选取的n的xi,两两之间能够满足x2i=x2j,我们的计算量相当于是进行了折半。但还是不够。

如果这个过程能够递归地进行下去,也就是我们对a0+a2x2+a4x4+...+an2xn2a1+a3x2+a5x4+...+an1xn2也能够用同样的方式求得n/2个点的值(这里进行一个变量替换y=x2,便可以得到类似A(x)的规模更小的多项式,所以我们能进行递归),那么这个递归算法为T(n)=2T(n/2)+O(n),根据Master Theory,可以知道算法的复杂度为O(nlgn)。要使这个算法能够递归地进行下去,要达到的条件是在递归的每一层,我们总能够找到两两成对的x2i=x2j

要达到这种要求,就需要在复数空间对1开n次方根得到n个值w0n,w1n,w2n,...wn1n,到下一层递归时又可以继续下去。这里数学的推导相对复杂,对于复数的知识以及具体递归的处理可以参考《算法导论》FFT章节。

对两个多项式分别进行求值之后,进行点值乘法可以得到结果多项式上的n个点值(xi,C(xi))C(xi)=A(xi)B(xi)

插值:得到n个点值(xi,C(xi))后,求A(x)=c0+c1x1+c2x2+...+cn1xn1中n个系数值c0,c1,c2...cn1

从数学上推导过来,在求值(xk=wkn,A(xk))

A(xk)=j=0n1ajwkjn(k=0,1,...,n1)

在插值时
ck=1nj=0n1C(xj)wkjn(j=0,1,...,n1)

(具体的数学推导请参考其他资料)这样来看,插值和求值变为了类似的过程。

#include <cstdio>#include <cstring>#include <cmath>#include <complex>#include <algorithm>using namespace std;#define PI 3.14159265358979323846#define MAX_N 1 << 17                 // errorchar a[MAX_N], b[MAX_N];complex<double> A[MAX_N], B[MAX_N], temp[MAX_N];int res[MAX_N];void reverse_copy(char* a, complex<double>* A, int n, int k) {    // n >>= 1; if n == 1, then n = 0   // error    for (int i = 0; i < n / 2; i++)        swap(a[i], a[n - 1 - i]);    // n <<= 1; if n == 0, then n = 0    for (int i = 0; i < k; i++)        A[i] = (i < n) ? complex<double>(a[i] - '0') : complex<double>(0);}int rev(int k, int lg_n) {    int r = 0;    for (int i = 0; i < lg_n; i++) {        r <<= 1;        r |= (k & 1);        k >>= 1;    }    return r;}/*void bit_reverse_copy(complex<double>* A, int k) {    if (k == 1) { return; }    int lg_k = 0;    for (int i = 1; i < k; i <<= 1, lg_k++);    for (int i = 0; i < k; i++) temp[i] = A[i];    for (int i = 0; i < k; i++)        A[rev(i, lg_k)] = temp[i];}*/void bit_reverse_swap(complex<double>* A, int n) {    int lg_n = 0;    for (int i = 1; i < n; i <<= 1, lg_n++);    for (int k = 0; k < n; k++)        if (k < rev(k, lg_n))            swap(A[k], A[rev(k, lg_n)]);}void FFT(complex<double>* A, int n, int flag) {    // bit_reverse_copy(A, n);    bit_reverse_swap(A, n);    int s, j, k, t, st, lg_n = 1;    complex<double> w, u, v, w_n;    for (int i = 2; i < n; i <<= 1, lg_n++);    for (s = 0; s < lg_n; s++) {        int l = 1 << (s + 1);        w_n = complex<double>(cos(flag * 2 * PI / l),                              sin(flag * 2 * PI / l));        for (t = 0; t < n / l; t++) {            w = 1;            st = t * l;            for (j = 0; j < (1 << s) ; j++) {                u = A[st + j];                v = A[st + j + (1 << s)];                v *= w;                A[st + j] += v;                A[st + j + (1 << s)] = u - v;                w *= w_n;            }        }    }    // int s, j, k, st, lg_n = 1;    // complex<double> w, u, v, w_n, t;    // for (int i = 2; i < n; i <<= 1, lg_n++);    // for (s = 1; s <= lg_n; s++) {    //     int m = 1 << s;    //     w_n = complex<double>(cos(flag * 2 * PI / m), sin(flag * 2 * PI / m));    //     for (k = 0; k <= n - 1; k += m) {    //         w = 1;    //         for (j = 0; j <= m / 2 - 1; j++) {    //             t = w * A[k + j + m / 2];    //             u = A[k + j];    //             A[k + j] = u + t;    //             A[k + j + m / 2] = u - t;    //             w *= w_n;    //         }    //     }    // }    if (flag == -1) {          for (int i = 0; i < n; i++) {            A[i] /= complex<double>(n);        }    } }int main() {    while (~scanf("%s%s", a, b)) {                int n = strlen(a), m = strlen(b), t = n + m;        int k = 1;        for (; k < t; k <<= 1);        reverse_copy(a, A, n, k);        reverse_copy(b, B, m, k);        FFT(A, k, 1);        FFT(B, k, 1);        for (int i = 0; i < k; i++) {            A[i] *= B[i];        }        FFT(A, k, -1);        for (int i = 0; i < k; i++) {            res[i] = int(A[i].real() + 0.5);        }        for (int i = 0; i < k - 1; i++) {            if (res[i] >= 10) {                // res[i + 1] = res[i] / 10;    // error                res[i + 1] += res[i] / 10;                res[i] %= 10;            }        }        int j;        for (j = k - 1; j >= 0 && res[j] == 0; j--);        if (j < 0) {            printf("0");        } else {            for (; j >= 0; j--) putchar(res[j] + '0');        }        puts("");    }    return 0;}

贴上代码,以上代码参考了FFT 模板,基本上是在用这里的代码来debug,然后逐渐替换成自己的代码。

这道题我一开始认为每个数最大的长度为50000,两数相乘最多是100000,所以任意取了MAX_N 120000。但是因为我们总是要把两数的位数和补成2的次方,所以在计算过程中用的n一定是大于100000,会达到1<<17,所以WA了好久。

另一个是在这道题中,递归算法是无法通过的,只能使用非递归的实现。在非递归的实现中,有一个优化的地方是本来是要A[rev(k)]=ak,这样的话需用用另一个数组来辅助存储。但是因为k与rev(k)互为二进制位的位逆序,所以交换即可。

参考
1. http://www.cnblogs.com/Patt/p/5503322.html
2. 《算法导论》

0 0
原创粉丝点击