UOJ86 mx的组合数

来源:互联网 发布:ubuntu tensorflow安装 编辑:程序博客网 时间:2024/06/07 14:30

大概看到题的时候就会做了。好厉害的题。

组合数模质数p等于某值的方案数,很容易想到利用卢卡斯定理。然后要使p进制下每一位对出的结果的乘积在模p意义下为某值,数位dp一波就好。

暴力转移是每位p^2的,但是转移的形式是c[i*j]+=a[i]*b[j],可以考虑找原根R,这就变成了c[logRi+logRj]+=a[logRi]*b[logRj],这里NTT就好了,模数还刚好是998244353。

然后你让我写。就很麻烦了。

先是一波高精度处理。找原根并预处理阶。在阶的基础上转换卷积形式以及NTT板子。哦还有预处理一波阶乘来求组合数。最后就是数位dp。

写还有调搞得我心力憔悴。

代码:

#include <cstdio>#include <cstring>#include <iostream>#define ll long long#define AwD 998244353int p;struct bigint{int v[105],len;}n,l,r;void read(bigint&a){char s[35];scanf("%s",s);a.len=strlen(s);for(int i=0;i<a.len;i++) a.v[a.len-i]=s[i]-'0';}int operator%(bigint a,int b){int res=0;for(int i=a.len;i;i--) res=(res*10+a.v[i])%b;return res;}bigint operator/(bigint a,int b){for(int i=a.len;i;i--){a.v[i-1]+=a.v[i]%b*10;a.v[i]/=b;}while(a.len>1&&!a.v[a.len]) a.len--;return a;}bigint incr(bigint a){a.v[1]++;for(int i=1;i<a.len;i++) if(a.v[i]>=10){a.v[i]-=10;a.v[i+1]++;}if(a.v[a.len]>=10){a.v[a.len]-=10;a.v[++a.len]=1;}return a;}bool zero(bigint a){return a.len==1&&!a.v[1];}void trs(bigint&a){int b[105],n=0;while(!zero(a)){b[++n]=a%p;a=a/p;}for(int i=1;i<=n;i++) a.v[i]=b[i];a.len=n;}void exp0(bigint&a,int L){for(int i=a.len+1;i<=L;i++) a.v[i]=0;}int kth[30005],rk[30005],R;void findR(){R=1;while(1){for(int i=1;i<p;i++) rk[i]=0;bool flag=kth[0]=1;for(int i=1;i<p;i++){kth[i]=kth[i-1]*R%p;if(rk[kth[i]]){flag=0;break;}rk[kth[i]]=i;}if(!flag){R++;continue;}rk[1]=0;return;}}ll pw(ll x,ll y){if(y<0) y+=AwD-1;if(!y) return 1;ll res=pw(x,y>>1);(res*=res)%=AwD;if(y&1) (res*=x)%=AwD;return res;}ll ntt(ll*a,int n,int d){int i,j,k;ll w,t,u,v;for(i=(n>>1),j=1;j<n;j++){if(i<j) t=a[i],a[i]=a[j],a[j]=t;for(k=(n>>1);i&k;i^=k,k>>=1);i^=k;}for(k=2;k<=n;k<<=1){w=pw(3,(AwD-1)/k*d);for(i=0;i<n;i+=k){t=1;for(j=i;j<i+(k>>1);j++){u=a[j];v=t*a[j+(k>>1)]%AwD;a[j]=(u+v)%AwD;a[j+(k>>1)]=(u-v+AwD)%AwD;t=t*w%AwD;}}}}ll t1[65555],t2[65555];void multi(int*a,int*b,int*res){int res0=0;for(int i=0;i<p;i++){res0=(res0+1ll*a[i]*b[0])%AwD;if(i) res0=(res0+1ll*a[0]*b[i])%AwD;}//for(int i=0;i<p;i++) printf("%d ",a[i]);printf("!!\n");//for(int i=0;i<p;i++) printf("%d ",b[i]);printf("!!\n");for(int i=1;i<p;i++) t1[rk[i]]=a[i],t2[rk[i]]=b[i];int l=1,invl;while(l<p-1) l<<=1;invl=pw(l<<=1,-1);for(int i=p-1;i<l;i++) t1[i]=t2[i]=0;//for(int i=0;i<l;i++) printf("%lld ",t1[i]);printf("!!\n");//for(int i=0;i<l;i++) printf("%lld ",t2[i]);printf("!!\n");ntt(t1,l,1);ntt(t2,l,1);for(int i=0;i<l;i++) (t1[i]*=t2[i])%=AwD;ntt(t1,l,-1);for(int i=0;i<l;i++) t1[i]=t1[i]*invl%AwD;//for(int i=0;i<l;i++) printf("%lld ",t1[i]);printf("!!\n");for(int i=1;i<p;i++) res[i]=0;for(int i=0;i<l;i++) (res[kth[i%(p-1)]]+=t1[i])%=AwD;res[0]=res0;//for(int i=0;i<p;i++) printf("%d ",res[i]);printf("!!\n");}int fac[30005],inv[30005];int C(int n,int m){return n<m?0:fac[n]*inv[m]%p*inv[n-m]%p;}int L,dp[105][30005],tmp[30005],cur;void init(){fac[0]=1;for(int i=1;i<p;i++) fac[i]=fac[i-1]*i%p;inv[p-1]=kth[p-1-rk[fac[p-1]]];for(int i=p-1;i;i--) inv[i-1]=inv[i]*i%p;for(int i=0;i<p;i++) dp[0][i]=i==1;for(int i=1;i<L;i++){for(int j=0;j<p;j++) tmp[j]=0;for(int j=0;j<p;j++) tmp[C(j,n.v[i])]++;multi(dp[i-1],tmp,dp[i]);}}void solve(bigint a,int*ans){//for(int i=0;i<L;i++,printf("\n")) for(int j=0;j<p;j++) printf("%d ",dp[i][j]);//printf("solving...\n");for(int i=0;i<p;i++) ans[i]=0;cur=1;for(int i=L;i;i--){for(int j=0;j<p;j++) tmp[j]=0;for(int j=0;j<a.v[i];j++) tmp[C(j,n.v[i])]++;//for(int j=0;j<p;j++) printf("%d ",tmp[j]);printf("\n");multi(tmp,dp[i-1],tmp);//for(int j=0;j<p;j++) printf("%d ",tmp[j]);printf("::\n");for(int j=0;j<p;j++) (ans[j*cur%p]+=tmp[j])%=AwD;(cur*=C(a.v[i],n.v[i]))%=p;}//printf("solved\n");}int ans1[30005],ans2[30005]; int main(){scanf("%d",&p);read(n);read(l);read(r);r=incr(r);//printf("reading ok\n");trs(n);trs(l);trs(r);//printf("transforming ok\n");L=std::max(n.len,std::max(l.len,r.len));//printf("L=%d\n",L);exp0(n,L);exp0(l,L);exp0(r,L);//for(int i=L;i;i--) printf("%d ",n.v[i]);printf("\n");//for(int i=L;i;i--) printf("%d ",l.v[i]);printf("\n");//for(int i=L;i;i--) printf("%d ",r.v[i]);printf("\n");//printf("0-expanding ok\n");findR();//printf("%d\n",R);//for(int i=0;i<p;i++) printf("%d ",kth[i]);printf("\n");//for(int i=1;i<p;i++) printf("%d ",rk[i]);printf("\n");//printf("---\n");init();solve(r,ans1);solve(l,ans2);for(int i=0;i<p;i++) printf("%d\n",(ans1[i]-ans2[i]+AwD)%AwD);}

0 0