THUPC2017 I题 Sum

来源:互联网 发布:c语言标识符由什么组成 编辑:程序博客网 时间:2024/06/04 18:17

前言

今年跟着Jason和栋栋去了thupc,做完g题之后我就一直一边吃东西一边看着他们玩2333,最后还莫名有奖金(赚大了; )哈哈哈)

题目大意

给出长度为n的数组a[]

fk=i=1naik

求f[1..n]
1n4×105

解法

这题竟然是生成函数,我竟然没有去写!!(虽然好像很少写多项式求逆来着)
设g[i]表示这n个a[]中选出i个不同的x,所有方案中a[x]的乘积的和。
举个例子,如果n=3
那么g[1]=a[1]+a[2]+a[3]g[2]=a[1]a[2]+a[1]a[3]+a[2]a[3]g[3]=a[1]a[2]a[3]
然后可以发现这样一个规律:

f[1]=g[1]f[2]=f[1]×g[1]2×g[2]f[3]=f[2]×g[1]f[1]×g[2]+3×g[3]f[4]=f[3]×g[1]f[2]×g[2]+f[1]×g[3]4×g[4]...

设F(x)为f的生成函数,即
F(x)=i=1nf[i]xi

设生成函数G(x)为
G(x)=i=1nd[i]g[i]xi

其中d[i]满足:若i为奇数那么d[i]=1否则d[i]=-1
设生成函数T(x)为
T(x)=i=1nd[i]g[i]i

其中d[i]与G(x)中的一样
那么有:
F(x)=F(x)G(x)+T(x)

即:
F(x)=T(x)1G(x)

其中G(x)可以用分治FFT解决,要用到多项式求逆(第一次打啊)

代码

#include<iostream>#include<cstring>#include<cstdio>#include<algorithm>#include<cmath>#include<set>#include<map>#define fo(i,a,b) for(int i=a;i<=b;i++)#define fd(i,a,b) for(int i=a;i>=b;i--)using namespace std;typedef long long LL;typedef double db;int get(){    char ch;    while(ch=getchar(),(ch<'0'||ch>'9')&&ch!='-');    if (ch=='-'){        int s=0;        while(ch=getchar(),ch>='0'&&ch<='9')s=s*10+ch-'0';        return -s;    }    int s=ch-'0';    while(ch=getchar(),ch>='0'&&ch<='9')s=s*10+ch-'0';    return s;}const int MAXN = 1100010;const int mo = 998244353;const int p = 3;int a[MAXN],b[MAXN],c[MAXN];int g[MAXN],f[MAXN],g_nv[MAXN],t[MAXN];int bitr[MAXN];int mi[25];int N,L;int T,n;int A[MAXN],B[MAXN];int G,pm[MAXN];LL quickmi(LL x,LL tim){    LL ans=1;    while(tim){        if (tim%2)ans=ans*x%mo;        x=x*x%mo;        tim/=2;    }    return ans;}void prepare(){    fo(i,0,N-1){        bitr[i]=0;        fo(j,0,L-1)        if ((i&mi[j])>0)bitr[i]+=mi[L-1-j];    }    G=quickmi(p,(mo-1)/N);    pm[0]=1;    fo(i,1,N)pm[i]=1ll*pm[i-1]*G%mo;}inline int add(int a,int b){    return a+b>=mo?a+b-mo:a+b;}inline int dec(int a,int b){    return a<b?a+mo-b:a-b;}void DFT(int *a){    fo(i,0,N-1)if (i<bitr[i])swap(a[i],a[bitr[i]]);    for(int now=2;now<=N;now<<=1){        int half=now/2;        fo(i,0,half-1){            LL w=pm[N/now*i];            for(int j=i;j<N;j+=now){                LL l=a[j],r=w*a[j+half]%mo;                a[j]=add(l,r);                a[j+half]=dec(l,r);            }        }    }}void IDFT(int *a){    fo(i,0,N-1)if (i<bitr[i])swap(a[i],a[bitr[i]]);    for(int now=2;now<=N;now<<=1){        int half=now/2;        fo(i,0,half-1){            LL w=pm[N-N/now*i];            for(int j=i;j<N;j+=now){                LL l=a[j],r=w*a[j+half]%mo;                a[j]=add(l,r);                a[j+half]=dec(l,r);            }        }    }    LL nv=quickmi(N,mo-2);    fo(i,0,N-1)a[i]=nv*a[i]%mo;}bool sig;void ntt(int *c,int *a,int *b,int an,int bn){    if (!sig){        N=1;L=0;        while(N<=an+bn){N<<=1;L++;}        prepare();    }    fo(i,0,N-1)A[i]=B[i]=0;    fo(i,0,an)A[i]=a[i];    fo(i,0,bn)B[i]=b[i];    DFT(A);DFT(B);    fo(i,0,N-1)A[i]=(1ll*A[i]*B[i]%mo+mo)%mo;    IDFT(A);    fo(i,0,an+bn)c[i]=A[i];}void solve(){    sig=1;    for(N=2,L=1;N/2<n;N<<=1,L++){        prepare();        int half=N/2;        for(int x=1;x<=n;x+=N){            int y=x+half;            if (y>n)break;            fo(i,1,half)a[i]=g[x+i-1];            int rs=min(n-y+1,half);            fo(i,1,rs)b[i]=g[y+i-1];            ntt(c,a,b,half-1,rs-1);            fo(i,1,half-1)c[i+rs]=(c[i+rs]+1ll*b[rs]*a[i]%mo)%mo;            fo(i,1,rs)c[i+half]=(c[i+half]+1ll*b[i]*a[half]%mo)%mo;            fo(i,1,half){c[i]=(c[i]+a[i])%mo;a[i]=0;}            fo(i,1,rs){c[i]=(c[i]+b[i])%mo;b[i]=0;}            fo(i,1,half+rs){g[x+i-1]=c[i];c[i]=0;}        }    }    sig=0;}//以下部分是多项式求逆的部分void get_nv(int *f,int *f_,int n){    if (n==1){f_[0]=1;return;}    //其实这里更准确应该是f_[0]=quickmi(f[0],mo-2);由于本题求逆的多项式f[0]一定是1所以f_[0]一定是1    int n_=(n+1)/2;    get_nv(f,f_,n_);    N=1;L=0;    while(N<=n*2){N<<=1;L++;}    prepare();    fo(i,0,N-1)A[i]=B[i]=0;    fo(i,0,n-1)A[i]=f[i]%mo;    fo(i,0,n_-1)B[i]=f_[i]%mo;    DFT(A);DFT(B);    fo(i,0,N-1)A[i]=1ll*B[i]*((2ll+mo-1ll*A[i]*B[i]%mo)%mo)%mo;    IDFT(A);    fo(i,0,n-1)f_[i]=A[i];}int main(){    freopen("sum.in","r",stdin);    freopen("sum.out","w",stdout);    T=get();    mi[0]=1;    fo(i,1,21)mi[i]=mi[i-1]<<1;    while(T--){        n=get();        fo(i,1,n)g[i]=get();        solve();        fo(i,1,n)        if (i%2)t[i]=1ll*g[i]*i%mo;        else t[i]=(mo-1ll*g[i]*i%mo)%mo;        fo(i,1,n)        if (i%2)g[i]=(mo-g[i])%mo;        g[0]=1;        get_nv(g,g_nv,n+1);        ntt(f,g_nv,t,n,n);        int ans=0;        fo(i,1,n)ans=ans^f[i];        printf("%d\n",ans);        fo(i,0,N-1)a[i]=b[i]=g[i]=c[i]=0;    }    fclose(stdin);    fclose(stdout);    return 0;}
原创粉丝点击