【BZOJ】3992 [SDOI2015]序列统计 【离散对数下的NTT】

来源:互联网 发布:matlab 2017b mac 编辑:程序博客网 时间:2024/05/02 01:14

题目链接:【BZOJ】3992 [SDOI2015]序列统计

题目大意:给定质数 $M$ 和一个集合 $S$,$S$ 中的元素均为小于 $M$ 的非负整数。现有一个长度为 $N$ 的数列,数列中的每个数均属于集合 $S$。给定整数 $x$,求合法数列的个数,满足数列中所有数的乘积 $\bmod M$ 的值为 $x$。两个数列 $\{A_i\}$ 和 $\{B_i\}$ 不同,当且仅当至少存在一个整数 $i$,满足 $A_i \neq B_i$。方案数对 $1004535809$ 取模。
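题目分析:利用原根的性质对下标做变换。设 $g$ 为 $M$ 的原根,用下标 $i$ 表示数值 $g^i \bmod M$(即取离散对数),这样乘法就变成了指数上的加法(在模 $M-1$ 意义下),之后便可以用 NTT 做长度为 $M-1$ 的循环卷积,再对 $N$ 做快速幂。由于 $0$ 没有离散对数,且乘积中一旦出现 $0$ 结果就为 $0$,当 $x \neq 0$ 时可以直接忽略集合中的 $0$。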

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
#define clr(a, x) memset(a, x, sizeof a)

// Approach: with g a primitive root of the prime M, the value g^i is stored at
// index i, so multiplying values becomes adding exponents modulo M-1. The answer
// is then the N-th power of S's indicator polynomial under cyclic convolution of
// length p = M-1, computed with NTT + binary exponentiation.

// Capacity of every work array. The NTT length is the smallest power of two
// >= 2*M, so this assumes M is small enough for that to fit (presumably the
// original problem bounds M by a few thousand -- confirm).
const int MAXN = 50000;
// NTT-friendly prime 479 * 2^21 + 1; the transform below uses 3 as its root.
const int Mod = 1004535809;

LL res[MAXN], x[MAXN];     // res: accumulated answer polynomial; x: repeated-squaring base
int mod, p, g, S[10], top; // mod = M (assumed prime), p = M - 1, g = primitive root, S = prime factors of M-1
int n, m, X;               // N (sequence length), |S| as read, target residue x
int vis[MAXN];             // vis[v] == 1 iff v is an element of S
int gp[MAXN];              // gp[i] = g^i mod M

// Returns x^n mod `mod` by binary exponentiation.
int pm(LL x, int n, int mod) {
    LL r = 1;
    for (x %= mod; n; n >>= 1) {
        if (n & 1) r = r * x % mod;
        x = x * x % mod;
    }
    return (int)r;
}

// Sets g to the smallest primitive root of the prime `mod`: a candidate is
// accepted when g^((mod-1)/q) != 1 for every prime factor q of mod-1.
void preprocess() {
    int rest = mod - 1;
    top = 0;
    for (int i = 2; i * i <= rest; ++i) {
        if (rest % i == 0) {
            S[top++] = i;
            while (rest % i == 0) rest /= i;
        }
    }
    if (rest > 1) S[top++] = rest;  // remaining large prime factor
    for (g = 2;; ++g) {
        bool ok = true;
        for (int j = 0; j < top && ok; ++j)
            if (pm(g, (mod - 1) / S[j], mod) == 1) ok = false;
        if (ok) return;
    }
}

// In-place iterative NTT of y[0..len) modulo Mod; len must be a power of two.
// f != 0: forward transform. f == 0: inverse transform WITHOUT the 1/len scaling.
void FFT(LL y[], int len, int f) {
    // Bit-reversal permutation.
    for (int i = 1; i < len; ++i) {
        int j = 0;
        for (int k = len >> 1, t = i; k; k >>= 1, t >>= 1) j = j << 1 | (t & 1);
        if (i < j) swap(y[i], y[j]);
    }
    for (int s = 2, ds = 1; s <= len; ds = s, s <<= 1) {
        LL wn = pm(3, (Mod - 1) / s, Mod);   // s-th root of unity
        if (!f) wn = pm(wn, Mod - 2, Mod);   // its inverse for the inverse transform
        for (int k = 0; k < len; k += s) {
            LL w = 1;
            for (int i = k; i < k + ds; ++i) {
                LL t = w * y[i + ds] % Mod;
                y[i + ds] = (y[i] - t + Mod) % Mod;
                y[i] = (y[i] + t) % Mod;
                w = w * wn % Mod;
            }
        }
    }
}

// Turns a pointwise product back into coefficients: inverse NTT, scale by
// 1/len, then wrap every index >= p onto index % p (exponents taken mod M-1).
void backToCyclic(LL a[], int len, LL inv_len) {
    FFT(a, len, 0);
    for (int i = 0; i < len; ++i) a[i] = a[i] * inv_len % Mod;
    for (int i = p; i < len; ++i) {
        a[i % p] = (a[i % p] + a[i]) % Mod;
        a[i] = 0;
    }
}

// res := res * x^k, products taken cyclically with length p.
// On return x may be left in NTT form; it must be reset before reuse.
void cyclicPow(int k) {
    int len = 1;
    while (len < mod + mod) len <<= 1;  // room for a full product of two length-p polynomials
    const LL inv_len = pm(len, Mod - 2, Mod);
    while (k) {
        FFT(x, len, 1);
        if (k & 1) {
            FFT(res, len, 1);
            for (int i = 0; i < len; ++i) res[i] = res[i] * x[i] % Mod;
            backToCyclic(res, len, inv_len);
        }
        k >>= 1;
        if (!k) break;  // last bit consumed: squaring x again would be wasted work
        for (int i = 0; i < len; ++i) x[i] = x[i] * x[i] % Mod;
        backToCyclic(x, len, inv_len);
    }
}

// Handles one test case: reads the |S| elements and prints how many length-n
// sequences over S have product congruent to X modulo mod, counted mod Mod.
void solve() {
    clr(vis, 0);
    clr(res, 0);
    clr(x, 0);
    for (int i = 0, v; i < m; ++i)
        if (scanf("%d", &v) == 1 && v >= 0 && v < mod) vis[v] = 1;

    // A product reduced mod M always lies in [0, mod): any other X is unreachable.
    if (X < 0 || X >= mod) {
        puts("0");
        return;
    }
    // 0 has no discrete log. For prime M the product is 0 iff some term is 0,
    // so the count is |S|^n - |S \ {0}|^n.
    if (X == 0) {
        int cnt = 0;
        for (int v = 0; v < mod; ++v) cnt += vis[v];
        printf("%lld\n", ((LL)pm(cnt, n, Mod) - pm(cnt - vis[0], n, Mod) + Mod) % Mod);
        return;
    }

    preprocess();
    p = mod - 1;
    res[0] = 1;  // index 0 <-> g^0 = 1, the empty product
    gp[0] = 1;
    for (int i = 1; i < mod; ++i) gp[i] = (int)(1LL * gp[i - 1] * g % mod);
    for (int i = 0; i < p; ++i) x[i] = vis[gp[i]];  // element 0 is skipped here
    cyclicPow(n);
    for (int i = 0; i < p; ++i)
        if (gp[i] == X) {
            printf("%lld\n", res[i]);
            return;
        }
    puts("0");  // only reachable if mod is not prime (g then does not generate X)
}

int main() {
    // Each test case: N, M, x, |S|, followed by the |S| elements; until EOF.
    while (scanf("%d%d%d%d", &n, &mod, &X, &m) == 4) solve();
    return 0;
}

压缩后代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
#define clr(a,x) memset(a,x,sizeof a)
// Compact version: index i stands for g^i (g = primitive root of prime M), so the
// answer is the N-th power of S's indicator under cyclic NTT convolution of length M-1.
// Mod = 479*2^21+1 with root 3. Arrays assume 2*M fits in MAXN after rounding to 2^k.
const int MAXN=50000,Mod=1004535809;
LL res[MAXN],x[MAXN];
int mod,p,g,S[10],top,n,m,X,vis[MAXN],gp[MAXN];
// x^n mod `mod`.
int pm(LL x,int n,int mod,LL res=1){
    for(x%=mod;n;x=x*x%mod,n>>=1)if(n&1)res=res*x%mod;
    return res;
}
// g := smallest primitive root of the prime `mod`.
void preprocess(){
    int n=mod-1;top=0;
    for(int i=2;i*i<=n;++i)if(n%i==0){S[top++]=i;while(n%i==0)n/=i;}
    if(n>1)S[top++]=n;
    for(g=2;;++g){
        bool ok=1;
        for(int j=0;j<top;++j)if(pm(g,(mod-1)/S[j],mod)==1)ok=0;
        if(ok)return;
    }
}
// In-place NTT; f==0 is the inverse transform without 1/n scaling.
// Must be void: an int function that falls off its end is undefined behavior.
void FFT(LL y[],int n,int f){
    for(int i=1,j,k,t;i<n;++i){
        for(j=0,k=n>>1,t=i;k;k>>=1,t>>=1)j=j<<1|(t&1);
        if(i<j)swap(y[i],y[j]);
    }
    for(int s=2,ds=1;s<=n;ds=s,s<<=1){
        LL wn=pm(3,(Mod-1)/s,Mod),w,t;
        if(!f)wn=pm(wn,Mod-2,Mod);
        for(int k=0;k<n;k+=s){
            w=1;
            for(int i=k;i<k+ds;++i,w=w*wn%Mod){
                t=w*y[i+ds]%Mod;
                y[i+ds]=(y[i]-t+Mod)%Mod;
                y[i]=(y[i]+t)%Mod;
            }
        }
    }
}
// x := coefficients of (x .* y) (both in NTT form), folded onto indices [0,p).
void calc(LL x[],LL y[],int nv,int n){
    for(int i=0;i<n;++i)x[i]=x[i]*y[i]%Mod;
    FFT(x,n,0);
    for(int i=0;i<n;++i)x[i]=x[i]*nv%Mod;
    for(int i=p;i<n;++i)x[i%p]=(x[i%p]+x[i])%Mod,x[i]=0;
}
// res := res * x^k cyclically; x may be left in NTT form afterwards.
void cpow(int k){
    int n=1;
    while(n<mod+mod)n<<=1;
    int nv=pm(n,Mod-2,Mod);
    for(;k;k>>=1){
        FFT(x,n,1);
        if(k&1)FFT(res,n,1),calc(res,x,nv,n);
        if(k>1)calc(x,x,nv,n);  // skip the useless square after the last bit
    }
}
// One test case: read S, print the count of sequences with product == X (mod M).
void solve(){
    clr(vis,0),clr(res,0),clr(x,0);
    for(int i=0,v;i<m;++i)if(scanf("%d",&v)==1&&v>=0&&v<mod)vis[v]=1;
    if(X<0||X>=mod){puts("0");return;}          // unreachable residue
    if(!X){                                      // prime M: product is 0 iff some term is 0
        int c=0;for(int v=0;v<mod;++v)c+=vis[v];
        printf("%lld\n",((LL)pm(c,n,Mod)-pm(c-vis[0],n,Mod)+Mod)%Mod);
        return;
    }
    preprocess();
    p=mod-1;
    res[0]=gp[0]=1;
    for(int i=1;i<mod;++i)gp[i]=1LL*gp[i-1]*g%mod;
    for(int i=0;i<p;++i)x[i]=vis[gp[i]];
    cpow(n);
    for(int i=0;i<p;++i)if(gp[i]==X){printf("%lld\n",res[i]);return;}
    puts("0");                                   // only if mod is not prime
}
int main(){
    while(scanf("%d%d%d%d",&n,&mod,&X,&m)==4)solve();
    return 0;
}
0 0