codeforces 755G 多项式

来源:互联网 发布:tengine windows安装 编辑:程序博客网 时间:2024/06/11 23:54

题意:n个球,分成一些组,一个组里可以有1个球或相邻的两个球。一个球只能在一个组里或不在任何组里。求组数为1,2,…m时的方案数。
正常的递推式:f[i][j]表示i个球分成j组的方案数。
f[i][j]=f[i1][j]+f[i1][j1]+f[i2][j1]

那么 f[n] 的生成函数满足:
fn(x)=fn1(x)+xfn1(x)+xfn2(x)
=(x+1)fn1(x)+xfn2(x)

然后设fi=C1T1(x)i+C2T2(x)i

其中T0(x),T1(x) 为方程T2(x)=(x+1)T(x)+x的两个根

解得:T1(x)=1+x+(x2+6x+1)2,T2(x)=1+x(x2+6x+1)2

f0=1,f1=1+x 代入解出C1,C2

然后通分:
fn(x)=(1+x+(x2+6x+1)2)n+1(1+x(x2+6x+1)2)n+1(x2+6x+1)

然后(1+x(x2+6x+1)2)n+1 最低项次数大于n可以直接舍掉。
然后就是多项式操作了。

#include <bits/stdc++.h>using namespace std;#define N (1<<16)+10#define ll long long#define mod 998244353const int inv2=499122177;int n,m,len;int tmp[N],sq[N],inv_sq[N],rt1[N],ln1[N],a[N],ans[N],inv[N];int test1[N],test2[N],test3[N];int qpow(int x,int y){    int ret=1;    while(y)    {        if(y&1)ret=(ll)ret*x%mod;        x=(ll)x*x%mod;y>>=1;    }    return ret;}void NTT(int *a,int len,int type){    for(int i=0,t=0;i<len;i++)    {        if(i<t)swap(a[i],a[t]);        for(int j=len>>1;(t^=j)<j;j>>=1);    }    for(int i=2;i<=len;i<<=1)    {        int wn=qpow(3,(mod-1)/i);        for(int j=0;j<len;j+=i)        {            int w=1,t;            for(int k=0;k<i>>1;k++,w=(ll)w*wn%mod)            {                t=(ll)a[j+k+(i>>1)]*w%mod;                a[j+k+(i>>1)]=(a[j+k]-t+mod)%mod;                a[j+k]=(a[j+k]+t)%mod;            }        }    }    if(type==-1)    {        for(int i=1;i<len>>1;i++)swap(a[i],a[len-i]);        int t=qpow(len,mod-2);        for(int i=0;i<len;i++)a[i]=(ll)a[i]*t%mod;    }}////////////////void test_root(int *a,int len){    memset(test1,0,sizeof(test1));    for(int i=0;i<len;i++)        test1[i]=a[i];    NTT(test1,len<<1,1);    for(int i=0;i<len<<1;i++)        test1[i]=(ll)test1[i]*test1[i]%mod;    NTT(test1,len<<1,-1);    for(int i=0;i<len;i++)        printf("#%d ",test1[i]);    puts("");}void test_inv(int *a,int *b,int len){    memset(test1,0,sizeof(test1));    memset(test2,0,sizeof(test2));    for(int i=0;i<len;i++)        test1[i]=a[i],test2[i]=b[i];    NTT(test1,len<<1,1);    NTT(test2,len<<1,1);    for(int i=0;i<len<<1;i++)        test3[i]=(ll)test1[i]*test2[i]%mod;    NTT(test3,len<<1,-1);    for(int i=0;i<len;i++)        printf("#%d ",test3[i]);    puts("");}////////////////void get_inv(int *a,int *b,int len){    static int tmp[N];    if(len==1)    {        b[0]=qpow(a[0],mod-2);        return;    }    get_inv(a,b,len>>1);    for(int i=0;i<len;i++)tmp[i]=a[i];    for(int i=len;i<len<<1;i++)tmp[i]=0;    NTT(tmp,len<<1,1);    NTT(b,len<<1,1);    for(int i=0;i<len<<1;i++)        b[i]=(ll)b[i]*(2-(ll)b[i]*tmp[i]%mod+mod)%mod;    NTT(b,len<<1,-1);    for(int i=len;i<len<<1;i++)b[i]=0;}void get_root(int *a,int *b,int len){    static int invb[N],tmp[N];    if(len==1){b[0]=1;return;}    get_root(a,b,len>>1);    for(int i=0;i<len<<1;i++)invb[i]=0;    get_inv(b,invb,len);    for(int i=0;i<len;i++)tmp[i]=a[i];    for(int i=len;i<len<<1;i++)tmp[i]=0;    NTT(tmp,len<<1,1);    NTT(b,len<<1,1);    NTT(invb,len<<1,1);    for(int i=0;i<len<<1;i++)        b[i]=(ll)inv2*(b[i]+(ll)tmp[i]*invb[i]%mod)%mod;    NTT(b,len<<1,-1);    for(int i=len;i<len<<1;i++)b[i]=0;}void get_ln(int *a,int *b,int len){    static int inva[N],a1[N];    for(int i=0;i<len<<1;i++)inva[i]=0;    get_inv(a,inva,len);    for(int i=0;i<len;i++)a1[i]=(ll)(i+1)*a[i+1]%mod;    for(int i=len;i<len<<1;i++)a1[i]=0;    NTT(a1,len<<1,1);    NTT(inva,len<<1,1);    for(int i=0;i<len<<1;i++)a1[i]=(ll)a1[i]*inva[i]%mod;    NTT(a1,len<<1,-1);    b[0]=0;    for(int i=1;i<len;i++)        b[i]=(ll)a1[i-1]*inv[i]%mod;    for(int i=len;i<len<<1;i++)b[i]=0;}void get_exp(int *a,int *b,int len){    static int lnb[N],tmp[N];    if(len==1){b[0]=1;return;}    get_exp(a,b,len>>1);    for(int i=0;i<len<<1;i++)lnb[i]=0;    get_ln(b,lnb,len);    for(int i=0;i<len;i++)tmp[i]=(a[i]-lnb[i]+mod)%mod;    tmp[0]++;    for(int i=len;i<len<<1;i++)tmp[i]=0;    NTT(b,len<<1,1);    NTT(tmp,len<<1,1);    for(int i=0;i<len<<1;i++)        b[i]=(ll)b[i]*tmp[i]%mod;    NTT(b,len<<1,-1);    for(int i=len;i<len<<1;i++)b[i]=0;}int main(){    //freopen("tt.in","r",stdin);    scanf("%d%d",&n,&m);    for(len=1;len<=m;len<<=1);    for(int i=1;i<len;i++)inv[i]=qpow(i,mod-2);    tmp[0]=1;tmp[1]=6;tmp[2]=1;    get_root(tmp,sq,len);    get_inv(sq,inv_sq,len);    rt1[0]=rt1[1]=1;    for(int i=0;i<len;i++)        rt1[i]=(ll)(rt1[i]+sq[i])%mod*inv2%mod;    get_ln(rt1,ln1,len);    for(int i=0;i<len;i++)        ln1[i]=(ll)ln1[i]*(n+1)%mod;    get_exp(ln1,a,len);    NTT(inv_sq,len<<1,1);    NTT(a,len<<1,1);    for(int i=0;i<len<<1;i++)        ans[i]=(ll)a[i]*inv_sq[i]%mod;    NTT(ans,len<<1,-1);    for(int i=1;i<=m;i++)        printf("%d ",i>n ? 0:ans[i]);    return 0;}
0 0