FFT什么的

来源:互联网 发布:weka 聚类算法 编辑:程序博客网 时间:2024/04/25 02:35

  这里只有公式&做法,没有复杂的证明(其实是因为弱鸡yww不会)

  参考自国家集训队论文&各个博客

多项式

​  一个以x为变量的多项式定义在一个代数域F上,将函数A(x)表示为形式和:

A(x)=j=0n1ajxj

我们称a0,a1,,an1为多项式的系数,所有系数都属于数域F,典型的情形是负数集合C

  如果一个多项式的最高次的非零系数是ak,则称A(x)的次数是k。任何严格大于一个多项式次数的整数都是该多项式的次数界。因此,对于次数界为n的多项式C(x),其次数可以是0~n1之间的任何整数,包括0n1

​  我们在多项式上可以定义很多不同的运算。

多项式加法

​  如果A(x)B(x)是次数界为n的多项式,那么他们的和也是一个次数界为n的多项式C(x)。对于所有属于定义域的x,都有C(x)=A(x)+B(x)。也就是说,若

A(x)=j=0n1ajxjB(x)=j=0n1bjxj


C(x)=j=0n1cjxj

其中
cj=aj+bj

​  例如,如果
A(x)=6x3+7x210x+9,B(x)=2x3+4x5


C(x)=4x3+7x26x+4

多项式乘法

​  如果A(x)是次数界为n的多项式,B(x)是次数界为m的多项式,那么他们的乘积是一个次数界为n+m的多项式C(x)。其中

cj=k=0jakbjk

​  例如,如果
A(x)=6x3+7x210x+9,B(x)=2x3+4x5

​  则
C(x)=12x614x5+44x420x375x2+86x45

多项式的表示

系数表达

​  对一个次数界为n的多项式A(x)=n1j=0ajxj而言,其系数表达式一个由系数组成得到向量a=(a0,a1,,an1)

​  我们可以用秦久韶算法在O(n)的时间内求出多项式在给定点x0的值,即求值运算:

A(x0)=a0+x0(a1+a0(a2++x0(an1+x0(an1)))

​  类似的,对于两个分别用系数向量a=(a0,a1,,an1),b=(b0,b1,,bn1)表示的多项式进行相加时,所需的时间是O(n)。我们只用输出系数向量c=(c0,c1,,cn1),其中ci=ai+bi

​  现在来考虑两个用系数形式表达的次数界为n的多项式A(x),B(x)的乘法运算,所需要的时间是O(n2)。系数向量c也称为输入向量a,b的卷积。c=ab

点值表达

​  一个次数界为n的多项式的点值表达就是一个有n个点值对所组成的集合。

{(x0,y0),(x1,y1),,(xn1,yn1)}

使得对k=0,1,,n1,所有xk各不相同且yk=A(xk)

​  一个多项式可以有很多不同的点值表达,因为可以采用n个不同的点构成的集合作为这种表示方法的基。

​  朴素的求值是O(n2)的。

​  求值的逆称为插值。当插值多项式的次数界等于已知的点值对的数目时,插值才是明确的。

​  我们可以在用高斯消元在O(n3)内插值,也可以用拉格朗日插值在O(n2)内插值。

​  以上求值和插值可以将多项式的系数表达和点值表达进行相互转化,上面给出的算法的时间复杂度是O(n2),但我们可以巧妙地选取xk来加速这一过程,使其运行时间变为O(nlogn)

​  对于许多多项式相关的操作,点值表达式很便利的。

​  对于加法,如果C(x)=A(x)+B(x)。给定A的点值表达

{(x0,y0),(x1,y1),,(xn1,yn1)}

B的点值表达
{(x0,y0),(x1,y1),,(xn1,yn1)}

(注意,AB在相同的n个位置求值),则C的点值表达是
{(x0,y0+y0),(x1,y1+y1),,(xn1,yn1+yn1)}

因此,对两个点值形式表示的次数界为n的多项式相加,时间复杂度是O(n)

​  类似的,如果C(x)=A(x)B(x),我们需要2n个点值对才能插出C。给定A的点值表达

{(x0,y0),(x1,y1),,(x2n1,y2n1)}

B的点值表达
{(x0,y0),(x1,y1),,(x2n1,y2n1)}

(注意,AB在相同的2n个位置求值),则C的点值表达是
{(x0,y0y0),(x1,y1y1),,(x2n1,y2n1y2n1)}

因此,对两个点值形式表示的次数界为n的多项式相乘,时间复杂度是O(n)

​  最后,我们考虑一个采用点值表达的多项式,如何求其在某个新点上的值。最简单的方法是把该多项式转成系数形式表达,然后在新点处求值。

系数形式表示的多项式的快速乘法

​  如果我们选n次单位复数根作为求值点,我们可以在O(nlogn)内求值和插值。我们先在对这两个多项式A,B求值之前添加n0,使其次数界加倍为2n。现在我们采用“2n次单位复数根”作为求值点。

DFT&FFT&IDFT

单位复数根

​  n次单位复数根是满足wn=1的复数wn次单位复数根恰好有n个,对于k=0,1,,n1,这些根是e2πiknwn=e2πin称为主n次单位根,所有其他n次单位复数根都是wn的幂次。这nn次单位复数根在乘法意义下形成了一个群,即wjnwkn=w(j+k)mod nn,而且这nn次单位复数根均匀分布在以复平面的原点为圆心的单位半径的圆周上。(图片from zjt)

这里写图片描述

​  消去引理:对任何整数n0,k0,d>0

wdkdn=wkn

DFT

​  回顾一下,我们希望计算次数界为n的多项式A(x)w0n,w1n,,wn1n处的值(即在nn次单位复数根处)。对于k=0,1,,n1,定义结果yk

yk=A(wkn)=j=0n1ajwkjn

向量y=(y0,y1,,yn1)就是系数向量a的离散傅里叶变换(DFT),我们也记为y=DFTn(a)

FFT

​  利用单位复数根的特殊性质,我们可以在O(nlogn)内计算出DFTn(a)。这里假设n2的幂。

  FFT利用了分治策略。

  我们令a=(a0,a1,,an1),a1=(a0,a2,,an2),a2=(a1,a3,,an1)

  对于k<n2有:

yk=A(wkn)=j=0n1ajwkjn=j=0n21a2jw2kjn+j=0n21a2j+1w2kj+kn=j=0n21a2jw2kjn+wknj=0n21a2j+1w2kjn=j=0n21a1jwkjn2+wknj=0n21a2jwkjn2=y1k+wkny2k

  对于kn2有:
yk=A(wkn)=j=0n1ajwkjn=j=0n21a2jw2kjn+j=0n21a2j+1w2kj+kn=j=0n21a2jw2kjn+wknj=0n21a2j+1w2kjn=j=0n21a1jwkjn2+wknj=0n21a2jwkjn2=j=0n21a1jw(kn2)jn2+wknj=0n21a2jw(kn2)jn2=y1kn2+wkny2kn2=y1kn2wkn2ny2kn2

  这样我们把y1,y2合并为y的时间复杂度是O(n)。所以总的时间复杂度是
T(n)=2T(n2)+O(n)=O(nlogn)

IDFT

​  通过推导公式,我们得到:

ak=1nj=0n1yjwkjn

​  所以我们可以用类似FFT的方法在O(nlogn)内求出IDFTn(y)

多项式乘法

​  我们可以在O(n)内补0O(nlogn)内求值,O(n)内点值乘法,O(nlogn)内插值。所以我们可以在O(nlogn)内求出ab

ab=IDFT2n(DFT2n(a)DFT2n(b))

蝶形运算

  我们把由y1k,y2k,wkn得到yk,yk+n2的过程称为蝴蝶操作。

​  我们发现,递归时a是长这样的:

0   1   2   3   4   5   6   70   2   4   6 | 1   3   5   70   4 | 2   6 | 1   5 | 3   70 | 4 | 2 | 6 | 1 | 5 | 3 | 7

  总的蝶形运算是长这样的:
  
  这里写图片描述

​  可以发现,最后ai是原来的arev(i)。所以我们可以交换ai,arev(i),然后一层层来做。这样可以减小常数。

NTT

​  在某些时候,我们需要求模p意义下的卷积。

​  先求出p的原根g,可以发现,gp1nwn的性质类似。所以我们可以用gp1n来代替wn

时间上的优化

​  令tj=(aj+bj)+(ajbj)i,S=T×T

​  sj的实部为

k=0j(ak+bk)2(akbk)2=k=0j4akbk=4k=0jakbk

  这样我们就可以求出S=T×T,然后把sj除以4

  这个方法可以把3次DFT改成2次DFT。

多项式求导

  给定A(x)=i0aixi,定义A(x)的形式导数为

A(x)=i1iaixi1

多项式积分

  给定A(x)=i0aixi,则

A(x)=i1ai1ixi

多项式求逆

​  多项式A(x)存在乘法逆元的充要条件是A(x)的常数项存在乘法逆元。

​  下面介绍一个O(n log n)计算乘法逆元的算法,它的本质是牛顿迭代法

​  首先求出A(x)常数项的逆元b,令B(x)的初始值为b

​  假设已求出满足

A(x)B(x)1 (mod xn)

B(x),则
A(x)B(x)1(A(x)B(x)1)2A(x)(2B(x)B(x)2A(x))0 (mod xn)0 (mod x2n)1 (mod x2n)

​  我们可以用O(n log n)的时间计算出2B(x)B(x)2A(x),并将它赋值给B(x)进行下一次迭代。每迭代一次,B(x)的有效项数n都会增加一倍。于是该算法的时间复杂度为
T(n)=T(n/2)+O(nlogn)=O(nlogn)

多项式开根

  已知A(x),求B(x)使得

B(x)2A(x) (mod xn)

  可能可以用ln&exp来算。

  先求出A(x)常数项的平方根b(可以用二次剩余来算),令B(x)的初始值为b

  假设已求出满足

B(x)2A(x) (mod xn)

B(x),则
B(x)2A(x)(B(x)2A(x))2B(x)42B(x)2A(x)+A(x)2B(x)4+2B(x)2A(x)+A(x)2(B(x)2+A(x))2(B(x)2+A(x)2B(x))20 (mod xn)0 (mod x2n)0 (mod x2n)4B(x)2A(x) (mod x2n)(2B(x))2A(x) (mod x2n)A(x) (mod x2n)

  我们可以在O(nlogn)内算出B(x)2+A(x)2B(x)=B(x)2+A(x)2B(x),并把它赋值给B(x)

  时间复杂度:O(nlogn)

多项式ln

  给定形式幂级数A(x)=i1aixi,定义

ln(1A(x))=i1A(x)ii

  给定多项式A(x)=1+i1aixi,令
B(x)=ln(A(x))


B(x)=A(x)A(x)

  只需要求出A(x)的乘法逆元,就可以求出ln(A(x))

多项式exp

  给定形式幂级数A(x)=i1aixi,定义

exp(A(x))=i0A(x)ii!

  令f(x)=eA(x),可得到一个关于f(x)的方程
g(f(x))=ln(f(x))A(x)=0

  考虑用牛顿迭代解这一方程。首先f(x)的常数项是容易确定的(就是1)。

  设以求得f(x)的前nf0(x),即

f(x)f0(x)   (mod   xn)

  作泰勒展开得
0=g(f(x))=g(f0(x))+g(f0(x))(f(x)f0(x))     (mod   x2n)


f(x)f0(x)g(f0(x))g(f0(x))    (mod   x2n)

  把上面那个式子带入得
f(x)=f0(x)ln(f0(x))A(x)1f0(x)=f0(x)(1ln(f0(x))+A(x))

  时间复杂度:O(nlogn)

多项式除法

​  给你A(x),B(x),求两个多项式D(x),R(x)满足

A(x)=D(x)B(x)+R(x)

​  若A(x)是一个n阶多项式,则
AR(x)=xnA(1x)

  举个例子:比如说
A(x)=x3+2x2+3x+4AR(x)=1+2x+3x2+4x3

​  相当于把A(x)的系数反转。

  我们设A(x)n阶多项式,B(x)m阶多项式,D(x)nm阶多项式,R(x)m1阶多项式。我们把上个式子的x1x,然后全部乘上xn

xnA(1x)=xnmD(1x)xmB(1x)+xnm+1xm1R(1x)AR(x)=DR(x)BR(x)+xnm+1RR(x)

  然后我们把这个式子放在模xnm+1意义下,得到
AR(x)=DR(x)BR(x) (mod xnm+1)DR(x)=AR(x)(BR(x))1 (mod xnm+1)

  因为D(x)的次数是nm,所以不会受模意义的影响。

  然后把D(x)带入到原来的式子中,就可以算出R(x)了。

  时间复杂度:O(nlogn)

多点求值

  给你一个多项式A(x)n个点x0,x1,,xn1,求这个多项式在这n个点处的值,即求A(x0),A(x1),,A(xn1)

  考虑一个简单的做法:构造Bi(x)=xxi,Ci(x)=A(x) mod Bi(x),那么Bi(xi)=0。所以A(xi)=Ci(xi)。但是计算Bi(x)Ci(x)O(n)的,必须加速这个过程。

  设当前求值的点为X={x0,x1,,xn1},我们可以把这n个点分为两半:

X0={x0,x1,,xn21}X1={xn2,xn2+1,,xn1}

  构造多项式
B0=i=0n21(xxi)B1=i=n2n1(xxi)A0=A mod B0A1=A mod B1

  那么当xX0A(x)=A0(x),可以递归计算。当xX1时同理。

  每一层计算B0,B1,A0,A1的时间复杂度都是O(nlogn)

  总的时间复杂度就是

T(n)=2T(n2)+O(nlogn)=O(nlog2n)

快速插值

模板

#include<cstdio>#include<cstring>#include<algorithm>#include<cstdlib>#include<ctime>#include<utility>#include<cmath>#include<functional>using namespace std;typedef long long ll;typedef unsigned long long ull;typedef pair<int,int> pii;typedef pair<ll,ll> pll;void sort(int &a,int &b){    if(a>b)        swap(a,b);}void open(const char *s){#ifndef ONLINE_JUDGE    char str[100];    sprintf(str,"%s.in",s);    freopen(str,"r",stdin);    sprintf(str,"%s.out",s);    freopen(str,"w",stdout);#endif}int rd(){    int s=0,c;    while((c=getchar())<'0'||c>'9');    do    {        s=s*10+c-'0';    }    while((c=getchar())>='0'&&c<='9');    return s;}int upmin(int &a,int b){    if(b<a)    {        a=b;        return 1;    }    return 0;}int upmax(int &a,int b){    if(b>a)    {        a=b;        return 1;    }    return 0;}const ll p=998244353;const ll g=3;ll fp(ll a,ll b){    ll s=1;    while(b)    {        if(b&1)            s=s*a%p;        a=a*a%p;        b>>=1;    }    return s;}const int maxn=600000;ll inv[maxn];namespace ntt{    ll w1[maxn];    ll w2[maxn];    int rev[maxn];    int n;    void init(int m)    {        n=1;        while(n<m)            n<<=1;        int i;        for(i=2;i<=n;i<<=1)        {            w1[i]=fp(g,(p-1)/i);            w2[i]=fp(w1[i],p-2);        }        rev[0]=0;        for(i=1;i<n;i++)            rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));    }    void ntt(ll *a,int t)    {        int i,j,k;        ll u,v,w,wn;        for(i=0;i<n;i++)            if(rev[i]<i)                swap(a[i],a[rev[i]]);        for(i=2;i<=n;i<<=1)        {            wn=(t==1?w1[i]:w2[i]);            for(j=0;j<n;j+=i)            {                w=1;                for(k=j;k<j+i/2;k++)                {                    u=a[k];                    v=a[k+i/2]*w%p;                    a[k]=(u+v)%p;                    a[k+i/2]=(u-v)%p;                    w=w*wn%p;                }            }        }        if(t==-1)        {            u=fp(n,p-2);                for(i=0;i<n;i++)                a[i]=a[i]*u%p;        }    }    ll x[maxn];    ll y[maxn];    ll z[maxn];    void copy_clear(ll *a,ll *b,int m)    {        int i;        for(i=0;i<m;i++)            a[i]=b[i];        for(i=m;i<n;i++)            a[i]=0;    }    void copy(ll *a,ll *b,int m)    {        int i;        for(i=0;i<m;i++)            a[i]=b[i];    }    void mul(ll *a,ll *b,ll *c,int m)    {        init(m<<1);        copy_clear(x,a,m);        copy_clear(y,b,m);        ntt(x,1);        ntt(y,1);        int i;        for(i=0;i<n;i++)            x[i]=x[i]*y[i]%p;        ntt(x,-1);        copy(c,x,m);    }    void inverse(ll *a,ll *b,int m)    {        if(m==1)        {            b[0]=fp(a[0],p-2);            return;        }        inverse(a,b,m>>1);        init(m<<1);        copy_clear(x,a,m);        copy_clear(y,b,m>>1);        ntt(x,1);        ntt(y,1);        int i;        for(i=0;i<n;i++)            x[i]=y[i]*(2-x[i]*y[i]%p)%p;        ntt(x,-1);        copy(b,x,m);    }    ll c[maxn],d[maxn],e[maxn],f[maxn];    void sqrt(ll *a,ll *b,int m)    {        if(m==1)        {            if(a[0]==1)                b[0]=1;            else if(a[0]==0)                b[0]=0;            else                //我也不会                ;            return;        }        sqrt(a,b,m>>1);//      copy_clear(c,b,m>>1);        int i;        for(i=m;i<m<<1;i++)            b[i]=0;        inverse(b,d,m);        init(m<<1);        for(i=m;i<m<<1;i++)            b[i]=d[i]=0;        ll inv2=fp(2,p-2);        copy_clear(x,a,m);        ntt(x,1);        ntt(d,1);        for(i=0;i<n;i++)            x[i]=x[i]*d[i]%p;        ntt(x,-1);        for(i=0;i<m;i++)            b[i]=((b[i]+x[i])%p*inv2)%p;    }    void derivative(ll *a,ll *b,int m)    {        int i;        for(i=0;i<m-1;i++)            b[i]=(i+1)*a[i+1]%p;        b[m-1]=0;    }    void differential(ll *a,ll *b,int m)    {        int i;        for(i=m-1;i>=1;i--)            b[i]=a[i-1]*inv[i]%p;        b[0]=0;    }    void ln(ll *a,ll *b,int m)    {        static ll c[maxn],d[maxn];        derivative(a,c,m);        inverse(a,d,m);        init(m<<1);        int i;        for(i=m;i<n;i++)            c[i]=d[i]=0;        ntt(c,1);        ntt(d,1);        for(i=0;i<n;i++)            c[i]=c[i]*d[i]%p;        ntt(c,-1);        differential(c,b,m);    }    void exp(ll *a,ll *b,int m)    {        if(m==1)        {            b[0]=1;            return;        }        exp(a,b,m>>1);        int i;        for(i=m>>1;i<m;i++)            b[i]=0;        ln(b,y,m);        init(m<<1);        copy_clear(x,a,m);        x[0]++;        for(i=0;i<m;i++)            x[i]=(x[i]-y[i])%p;        copy_clear(y,b,m);        ntt(x,1);        ntt(y,1);        for(i=0;i<n;i++)            x[i]=x[i]*y[i]%p;        ntt(x,-1);        copy(b,x,m);    }    void module(ll *a,ll *b,ll *c,int n1,int n2)    {        int k=1;        while(k<=n1-n2+1)            k<<=1;        int i;        for(i=0;i<=n1;i++)            d[i]=a[i];        for(i=0;i<=n2;i++)            e[i]=b[i];        reverse(d,d+n1+1);        reverse(e,e+n2+1);        for(i=n1-n2+1;i<k<<1;i++)            d[i]=e[i]=0;        inverse(e,f,k);        for(i=n1-n2+1;i<k<<1;i++)            f[i]=0;        init(k<<1);        ntt::ntt(d,1);        ntt::ntt(f,1);        for(i=0;i<n;i++)            e[i]=d[i]*f[i]%p;        ntt::ntt(e,-1);        for(i=0;i<=n1-n2;i++)            c[i]=e[i];        reverse(c,c+n1-n2+1);    }};ll b[maxn];ll a[maxn];ll c[maxn];void get(ll *a,int n){    int i;    for(i=0;i<n;i++)        a[i]=rand();}int main(){//  freopen("fft.txt","w",stdout);//  srand(time(0));//  int n=262144;//  int bg,ed;//  int i;//  int times=100,j;//  double s,s1;//  inv[0]=inv[1]=1;//  for(i=2;i<=n;i++)//      inv[i]=-(p/i)*inv[p%i]%p;//  s=0;//  for(j=1;j<=times;j++)//  {//      get(a,n);//      bg=clock();//      ntt::init(n);//      ntt::ntt(a,1);//      ed=clock();//      s+=double(ed-bg)/CLOCKS_PER_SEC;//  }//  printf("ntt :%.10lf\n",s/times);//  s1=s;//  s=0;//  for(j=1;j<=times;j++)//  {//      get(a,n);//      get(b,n);//      bg=clock();//      ntt::mul(a,b,c,n);//      ed=clock();//      s+=double(ed-bg)/CLOCKS_PER_SEC;//  }//  printf("mul :%.10lf %.10lf\n",s/times,s/s1);//  s=0;//  for(j=1;j<=times;j++)//  {//      get(a,n);//      bg=clock();//      ntt::inverse(a,b,n);//      ed=clock();//      s+=double(ed-bg)/CLOCKS_PER_SEC;//  }//  printf("inv :%.10lf %.10lf\n",s/times,s/s1);//  s=0;//  for(j=1;j<=times;j++)//  {//      get(a,n);//      a[0]=1;//      bg=clock();//      ntt::sqrt(a,b,n);//      ed=clock();//      s+=double(ed-bg)/CLOCKS_PER_SEC;//  }//  printf("sqrt:%.10lf %.10lf\n",s/times,s/s1);//  s=0;//  for(j=1;j<=times;j++)//  {//      get(a,n);//      a[0]=1;//      bg=clock();//      ntt::ln(a,b,n);//      ed=clock();//      s+=double(ed-bg)/CLOCKS_PER_SEC;//  }//  printf("ln  :%.10lf %.10lf\n",s/times,s/s1);//  s=0;//  for(j=1;j<=times;j++)//  {//      get(a,n);//      bg=clock();//      ntt::exp(a,b,n);//      ed=clock();//      s+=double(ed-bg)/CLOCKS_PER_SEC;//  }//  printf("exp :%.10lf %.10lf\n",s/times,s/s1);//  return 0;}