bzoj3992: [SDOI2015]序列统计

来源:互联网 发布:手机淘宝网触屏版官网 编辑:程序博客网 时间:2024/05/01 00:23

传送门:http://www.lydsy.com:808/JudgeOnline/problem.php?id=3992

思路:M是一个质数,问题又是求乘积,于是我们就可以想到利用M的原根g把问题变成求和(我怎么想不到啊。。。)

根据原根的性质,我们可以把1到M-1中的数i表示为(g^b[i])%M,且指数互不相同

那么X就可以表示成(g^b[x])%M

问题就转化为:然后问题转化成了在序列b中,选出n个数(一个数可以取多次),且它们的和s满足s=b[x]

这个问题我们可以用母函数+NTT解决,答案就是多项式的b[x]次项的系数

因为n很大,所以再套一个快速幂即可

#include<cmath>#include<cstdio>#include<cstring>#include<algorithm>using namespace std;const int mod=(479<<21)+1,G=3,maxn=17000;int n,k,sum,m,inv_G,inv_N,T[maxn],vis[maxn],tim,N,rev[maxn],pos[maxn],root;int qpow(int a,int b){int res=1;for (;b;b>>=1){if (b&1) res=1ll*res*a%mod;a=1ll*a*a%mod;}return res;}bool check(int x){int now=1;++tim;for (int i=1;i<n;i++,now=now*x%n){if (vis[now]==tim) return 0;vis[now]=tim;}return 1;}int rever(int x){int res=0,len=N;while (len--) res<<=1,res^=(x&1),x>>=1;return res;}int findroot(){for (int i=2;i<=n;i++) if (check(i)) return i;}struct DFT{int a[maxn];void ntt(int op){for (int i=0;i<N;i++) if (rev[i]>i) swap(a[rev[i]],a[i]);int g=op==1?G:inv_G;for (int sz=2;sz<=N;sz<<=1){int t=qpow(g,(mod-1)/sz);for (int bg=0;bg<N;bg+=sz)for (int po=bg,w=1;po<bg+(sz>>1);po++){int x=a[po],y=1ll*a[po+(sz>>1)]*w%mod;a[po]=(x+y)%mod,a[po+(sz>>1)]=(x-y+mod)%mod;w=1ll*w*t%mod;}}if (op==-1) for (int i=0;i<N;i++) a[i]=1LL*a[i]*inv_N%mod;}}a,b;void qpow(){b.a[0]=1;for (;k;k>>=1){a.ntt(1);if (k&1){b.ntt(1);for (int i=0;i<N;i++) b.a[i]=1ll*b.a[i]*a.a[i]%mod;b.ntt(-1);for (int i=N-1;i>=n-1;i--) b.a[i-n+1]=(b.a[i-n+1]+b.a[i])%mod,b.a[i]=0;//因为是在mod m(代码里的n)的条件下进行的,所以要把和超过m的后半部分的答案加到前半部分去,而不是简单的清空}for (int i=0;i<N;i++) a.a[i]=1ll*a.a[i]*a.a[i]%mod;a.ntt(-1);for (int i=N-1;i>=n-1;i--) a.a[i-n+1]=(a.a[i-n+1]+a.a[i])%mod,a.a[i]=0;}}int main(){inv_G=qpow(G,mod-2);scanf("%d%d%d%d",&k,&n,&sum,&m);for (int i=1;i<=m;i++) scanf("%d",&T[i]);N=(int)ceil(log2(n))+1;for (int i=0;i<(1<<N);i++) rev[i]=rever(i);N=1<<N,inv_N=qpow(N,mod-2),root=findroot();for (int i=0,res=1;i<n-1;i++) pos[res]=i,res=res*root%n;for (int i=1;i<=m;i++) if (T[i]) a.a[pos[T[i]]]++;qpow(),printf("%d\n",b.a[pos[sum]]);return 0;}


0 0