解题报告:HDU_6061 RXD and functions NTT

来源:互联网 发布:国信蓝点java培训 编辑:程序博客网 时间:2024/06/07 04:07

题目链接


题意及官方题解:



思路:

先感谢Q巨指点Qrz...


先求得系数式:

拆开组合数:

把系数提取出来稍作变换:

整理一下:

得到:


然后就可以直接进行NTT了。。


代码;

#include<bits/stdc++.h>const int mod = 998244353;const int N = 4e5+10;const int g = 3;using namespace std;long long F1[N],F2[N],qp[30];int getLen(int a){  a<<=1;  int res = 1;  while(res<a)res<<=1;  return res;}long long qpow(long long x,long long y){   long long res = 1;   while(y){      if(y&1){         res = res * x % mod;      }x = x * x % mod;      y >>= 1;   }return res;}void brc(long long *a,int l){    for(int i=1,j=l/2;i<l-1;i++){        if(i<j)swap(a[i],a[j]);        int k=l/2;        while(j>=k){            j-=k;            k>>=1;        }        if(j<k)j+=k;    }}void ntt(long long *y,int l,int on){    brc(y,l);    int id = 0;    long long u,t,tmp;    for(int h=2;h<=l;h<<=1){        id++;        for(int j=0;j<l;j+=h){            long long w = 1;            for(int k=j;k<j+h/2;k++){                u = y[k];                t = w*y[k+h/2]%mod;                y[k] = u+t;                if(y[k]>=mod)y[k]-=mod;                y[tmp=k+h/2]=u-t;                if(y[tmp]<0)y[tmp]+=mod;                w = w*qp[id]%mod;            }        }    }if(on<0){        for(int i=1;i<l/2;i++)swap(y[i],y[l-i]);        long long ni = qpow(l,mod-2);        for(int i=0;i<l;i++)y[i] = (y[i]*ni)%mod;    }}int K[N];long long fro[N];long long ni[N];void init(){   for(int i=0;i<21;i++){        int t=1<<i;        qp[i]=qpow(g,(mod-1)/t);    }   fro[0] = ni[0] = 1;   for(int i=1;i<N;i++){      fro[i] = fro[i-1] * i % mod;   }ni[N-1] = qpow(fro[N-1],mod-2);   for(int i=N-1;i;i--){      ni[i-1] = ni[i] * i % mod;   }}long long ans[N];int main(){   init();   int n,m;   //freopen("1006.in","r",stdin);   //freopen("my_1006.out","w",stdout);   while(scanf("%d",&n)==1){      long long a = 0;      for(int i=0;i<=n;i++)scanf("%d",&K[i]);      scanf("%d",&m);      for(int i=0,x;i<m;i++){         scanf("%d",&x);         a -= x;         if(a<0)a+=mod;      }int l = getLen(n+1);      if(a==0){         for(int i=0;i<=n;i++){            printf("%d ",K[i]);         }printf("\n");         continue;      }      for(int i=0,aa=1;i<l;i++){         if(i<=n){            F1[i] = aa * ni[i] % mod;            F2[i] = fro[n-i] * K[n-i] % mod;         }else {            F1[i] = F2[i] = 0;         }aa = 1LL * aa * a % mod;      }ntt(F1,l,1);ntt(F2,l,1);      for(int i=0;i<l;i++){         F1[i] = F1[i] * F2[i] % mod;      }ntt(F1,l,-1);      for(int i=0;i<=n;i++){         ans[i] = F1[i]  ;         ans[i] = ( ans[i] % mod + mod ) * ni[n-i] % mod;      }for(int i=0;i<=n;i++){         printf("%I64d ",ans[n-i]);      }printf("\n");   }return 0;}








原创粉丝点击