51nod 1348 乘积之和 分治+NTT+中国剩余定理

来源:互联网 发布:淘宝橙色cmyk 编辑:程序博客网 时间:2024/05/16 17:32

题意

给出由N个正整数组成的数组A,有Q次查询,每个查询包含一个整数K,从数组A中任选K个(K <= N)把他们乘在一起得到一个乘积。求所有不同的方案得到的乘积之和,由于结果巨大,输出Mod 100003的结果即可。例如:1 2 3,从中任选1个共3种方法,{1} {2} {3},和为6。从中任选2个共3种方法,{1 2} {1 3} {2 3},和为2 + 3 + 6 = 11。
1 <= N, Q <= 50000,1 <= A[i] <= 10^9

分析

没有想到分治。。。
可以把整个序列分成左右两边单独处理,然后合并的话就是一个卷积的形式,可以用NTT来搞。
但这里的模数并不是NTT模数。注意到每次卷积后的最大值在1016左右,不知道能不能用FFT来搞。这里可以用两个模数分别NTT,然后用天朝剩余定理合并就好了。
注意这里是每次NTT完之后都要合并一下,而不是全部搞完之后再合并。

代码

#include<iostream>#include<cstdio>#include<cstdlib>#include<cstring>#include<algorithm>using namespace std;typedef long long LL;const int N=50005;const int mod1=998244353;const int mod2=1004535809;const int MOD=100003;int n,q,val[N],a[21][N*4],rev[N*4],stack[21],top,mod,b1[21][N*4],b2[21][N*4],f1[N*4],f2[N*4];int read(){    int x=0,f=1;char ch=getchar();    while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}    while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}    return x*f;}int ksm(int x,int y,int p){    int ans=1;    while (y)    {        if (y&1) ans=(LL)ans*x%p;        x=(LL)x*x%p;y>>=1;    }    return ans;}void NTT(int *a,int L,int f){    for (int i=0;i<L;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);    for (int i=1;i<L;i<<=1)    {        int wn=ksm(3,f==1?(mod-1)/i/2:mod-1-(mod-1)/i/2,mod);        for (int j=0;j<L;j+=(i<<1))        {            int w=1;            for (int k=0;k<i;k++)            {                int u=a[j+k],v=(LL)a[j+k+i]*w%mod;                a[j+k]=(u+v)%mod;a[j+k+i]=(u-v+mod)%mod;                w=(LL)w*wn%mod;            }        }    }    int ny=ksm(L,mod-2,mod);    if (f==-1) for (int i=0;i<L;i++) a[i]=(LL)a[i]*ny%mod;}LL mul(LL x,LL y,LL mo){    LL tmp=(x*y-(LL)((double)x*y/mo+0.1)*mo)%mo;    if (tmp<0) tmp+=mo;    return tmp;}LL merge(LL m1,LL m2){    LL m=(LL)mod1*mod2;    return (mul((LL)mod2*ksm(mod2,mod1-2,mod1),m1,m)+mul((LL)mod1*ksm(mod1,mod2-2,mod2),m2,m))%m;}void solve(int l,int r,int id){    if (l==r) {a[id][0]=1;a[id][1]=val[l];return;}    int lg=0,mid=(l+r)/2,L,len=r-l+1;    for (L=1;L<=len*2;L<<=1,lg++);    solve(l,mid,id+1);    for (int i=0;i<=mid-l+1;i++) b1[id][i]=a[id+1][i],a[id+1][i]=0;    solve(mid+1,r,id+1);    for (int i=0;i<=r-mid;i++) b2[id][i]=a[id+1][i],a[id+1][i]=0;    for (int i=0;i<L;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));    mod=mod1;    for (int i=0;i<=mid-l+1;i++) f1[i]=b1[id][i];    for (int i=0;i<=r-mid;i++) f2[i]=b2[id][i];    NTT(f1,L,1);NTT(f2,L,1);    for (int i=0;i<L;i++) f1[i]=(LL)f1[i]*f2[i]%mod,f2[i]=0;    NTT(f1,L,-1);    for (int i=0;i<=len;i++) a[id][i]=f1[i],f1[i]=0;    mod=mod2;    for (int i=0;i<=mid-l+1;i++) f1[i]=b1[id][i];    for (int i=0;i<=r-mid;i++) f2[i]=b2[id][i];    NTT(f1,L,1);NTT(f2,L,1);    for (int i=0;i<L;i++) f1[i]=(LL)f1[i]*f2[i]%mod,f2[i]=0;    NTT(f1,L,-1);    for (int i=0;i<=len;i++) a[id][i]=merge(a[id][i],f1[i])%MOD,f1[i]=0;}int main(){    n=read();q=read();    for (int i=1;i<=n;i++) val[i]=read()%MOD;    solve(1,n,0);    while (q--)    {        int x=read();        printf("%d\n",(a[0][x]+MOD)%MOD);    }    return 0;}
原创粉丝点击