Tyvj4879(dp+倍增+NTT)

来源:互联网 发布:telnet 的端口号是什么 编辑:程序博客网 时间:2024/06/17 22:10

题面
题意:有n组数,每个数为1到6,每组有m个,每组内不计顺序。问总共至少有x个6的方案数。模998244353。n,m≤400。

看到这个熟悉的数字,大概就是个NTT了,然后考虑哪里有卷积。设g[i]为一组内有i个6的方案数,f[i]为两组内有i个6的方案数,有

f[i]=j=0ig[j]g[ij]

就是个卷积了。

对于有n组,我们考虑倍增,先求出n/2组的g数组,然后ntt合并即可。

考虑怎么求g[i]。由于每组内无关顺序,根据套路,我们可以强行设定单调。
设h[i][j]为第1~i个数非严格单调,第i个数为j的方案数。字面意思转移即可。

我的直觉认为,在倍增途中,若每一次卷积的次数界都动态变化,则总复杂度可以省一个log。

#include <iostream>#include <fstream>#include <algorithm>#include <cmath>#include <ctime>#include <cstdio>#include <cstdlib>#include <cstring>using namespace std;#define mmst(a, b) memset(a, b, sizeof(a))#define mmcp(a, b) memcpy(a, b, sizeof(b))typedef long long LL;const int N=800400;const LL p=998244353,g=3;int n,rev[N];LL cheng(LL a,LL b){    LL res=1ll;    for(;b;b>>=1,a=a*a%p)    if(b&1)    res=res*a%p;    return res;}void init(int lim){    int k=-1;    n=1;    while(n<lim)    n<<=1,k++;    for(int i=0;i<n;i++)    rev[i]=(rev[i>>1]>>1) | ((i&1)<<k);}void ntt(LL *a,int ops){    for(int i=0;i<n;i++)    if(i<rev[i])    swap(a[i],a[rev[i]]);    for(int l=2;l<=n;l<<=1)    {        int m=(l>>1);        LL wn;        if(ops)        wn=cheng(g,(p-1)/l);        else        wn=cheng(g,p-1-(p-1)/l);        for(int i=0;i<n;i+=l)        {            LL w=1ll;            for(int k=0;k<m;k++)            {                LL t=a[i+k+m]*w%p;                a[i+k+m]=(a[i+k]-t+p)%p;                a[i+k]=(a[i+k]+t)%p;                w=w*wn%p;            }        }    }    if(!ops)    {        LL Inv=cheng(n,p-2);        for(int i=0;i<n;i++)        a[i]=a[i]*Inv%p;    }}int nn,mm,x,y;LL f[401][6],gg[N],ans[N],by[N];void work(int b){    int csj=mm+1;    for(;b;b>>=1)    {        init(csj*2);        if(b&1)        {            ntt(ans,1);            for(int i=0;i<n;i++)            by[i]=gg[i];            ntt(by,1);            for(int i=0;i<n;i++)            ans[i]=(ans[i]*by[i])%p;            ntt(ans,0);        }        ntt(gg,1);        for(int i=0;i<n;i++)        gg[i]=gg[i]*gg[i]%p;        ntt(gg,0);        csj*=2;    }}int main(){    cin>>nn>>mm>>x>>y;    for(int i=1;i<=mm;i++)    {        int tu;        scanf("%d",&tu);        if(tu==y)        x--;    }    f[0][1]=1ll;    for(int i=1;i<=mm;i++)    for(int j=1;j<=5;j++)    for(int k=j;k<=5;k++)    f[i][k]=(f[i][k]+f[i-1][j])%p;    for(int i=0;i<=mm;i++)    {        LL hy=0;        for(int j=1;j<=5;j++)        hy=(hy+f[i][j])%p;        gg[mm-i]=hy;    }    ans[0]=1ll;    work(nn);    LL biu=0;    for(int i=x;i<n;i++)    biu=(biu+ans[i])%p;    cout<<biu<<endl;    return 0;}
原创粉丝点击