BZOJ 4555 求和(生成函数+FFT)

来源:互联网 发布:mac可以装谷歌浏览器吗 编辑:程序博客网 时间:2024/06/18 05:47

Description

在2016年,佳媛姐姐刚刚学习了第二类斯特林数,非常开心。

现在他想计算这样一个函数的值:f(n)=i=0nj=0nS(i,j)×2j×j!

S(i,j)表示第二类斯特林数,递推公式为:

S(i,j)=jS(i1,j)+S(i1,j1),1ji1

边界条件为:S(i,i)=1(i0),S(i,0)=0(i1)

你能帮帮他吗?

Input

输入只有一个正整数n(1n100000)

Output

输出f(n)。由于结果会很大,输出f(n)998244353(7×17×223+1)取模的结果即可

Sample Input

3

Sample Output

87

Solution

g(i)=j=0nS(i,j)×2j×j!,则由第二类斯特林数的定义知g(i)的意义是将i个数分到j个不同的非空集合,且每多一个集合其对答案的贡献就乘2,考虑第一个集合中元素个数为k,则有g(i)=k=1i2Ckig(ik)

化简得到卷积形式g(i)i!=j=1ig(ij)(ij)!2j!

F(x)=i=0g(i)i!xi,G(x)=i=12i!xi,则有F(x)=F(x)G(x)+1,故有F(x)=11G(x)

多项式求逆得到F(x),进而得到g(0),...,g(n),得到答案ans=i=0ng(i)

Code

#include<cstdio>#include<iostream>#include<cstring>#include<algorithm>#include<cmath>#include<vector>#include<queue>#include<map>#include<set>#include<ctime>using namespace std;typedef long long ll;#define maxn 100005#define maxfft 262144+5#define mod 998244353const double pi=acos(-1.0);struct cp {    double a,b;    cp operator +(const cp &o)const {return (cp){a+o.a,b+o.b};}    cp operator -(const cp &o)const {return (cp){a-o.a,b-o.b};}    cp operator *(const cp &o)const {return (cp){a*o.a-b*o.b,b*o.a+a*o.b};}    cp operator *(const double &o)const {return (cp){a*o,b*o};}    cp operator !() const{return (cp){a,-b};}}w[maxfft];int pos[maxfft];void fft_init(int len){    int j=0;    while((1<<j)<len)j++;    j--;    for(int i=0;i<len;i++)        pos[i]=pos[i>>1]>>1|((i&1)<<j);}void fft(cp *x,int len,int sta){    for(int i=0;i<len;i++)        if(i<pos[i])swap(x[i],x[pos[i]]);    w[0]=(cp){1,0};    for(unsigned i=2;i<=len;i<<=1)    {        cp g=(cp){cos(2*pi/i),sin(2*pi/i)*sta};        for(int j=i>>1;j>=0;j-=2)w[j]=w[j>>1];        for(int j=1;j<i>>1;j+=2)w[j]=w[j-1]*g;        for(int j=0;j<len;j+=i)        {            cp *a=x+j,*b=a+(i>>1);            for(int l=0;l<i>>1;l++)            {                cp o=b[l]*w[l];                b[l]=a[l]-o;                a[l]=a[l]+o;            }        }    }    if(sta==-1)for(int i=0;i<len;i++)x[i].a/=len,x[i].b/=len;}cp x[maxfft],y[maxfft],z[maxfft];int temp[maxfft];void FFT(int *a,int *b,int n,int m,int *c){    if(n<=100&&m<=100||min(n,m)<=5)    {        for(int i=0;i<n+m-1;i++)temp[i]=0;        for(int i=0;i<n;i++)            for(int j=0;j<m;j++)            {                temp[i+j]+=(ll)a[i]*b[j]%mod;                if(temp[i+j]>=mod)temp[i+j]-=mod;            }        for(int i=0;i<n+m-1;i++)c[i]=temp[i];        return ;    }    int len=1;    while(len<n+m)len<<=1;    fft_init(len);    for(int i=0;i<len;i++)    {        int aa=i<n?a[i]:0,bb=i<m?b[i]:0;        x[i]=(cp){(aa>>15),(aa&32767)},y[i]=(cp){(bb>>15),(bb&32767)};    }    fft(x,len,1),fft(y,len,1);    for(int i=0;i<len;i++)    {        int j=len-1&len-i;        z[i]=((x[i]+!x[j])*(y[i]-!y[j])+(x[i]-!x[j])*(y[i]+!y[j]))*(cp){0,-0.25};    }    fft(z,len,-1);    for(int i=0;i<n+m-1;i++)    {        ll ta=(ll)(z[i].a+0.5)%mod;        ta=(ta<<15)%mod;        c[i]=ta;    }    for(int i=0;i<len;i++)    {        int j=len-1&len-i;        z[i]=(x[i]-!x[j])*(y[i]-!y[j])*(cp){-0.25,0}+(x[i]+!x[j])*(y[i]+!y[j])*(cp){0,0.25};    }    fft(z,len,-1);    for(int i=0;i<n+m-1;i++)    {        ll ta=(ll)(z[i].a+0.5)%mod,tb=(ll)(z[i].b+0.5)%mod;        ta=(ta+(tb<<30))%mod;        c[i]=(c[i]+ta)%mod;    }}int inv[maxn],finv[maxn],fact[maxn];void init(int n=100001){    inv[1]=1;    for(int i=2;i<=n;i++)inv[i]=mod-(ll)(mod/i)*inv[mod%i]%mod;    finv[0]=1;    for(int i=1;i<=n;i++)finv[i]=(ll)finv[i-1]*inv[i]%mod;    fact[0]=1;    for(int i=1;i<=n;i++)fact[i]=(ll)fact[i-1]*i%mod;}int temp1[maxfft];void Poly_Inv(int *poly,int n,int *ans){    ans[0]=inv[poly[0]];    for(int i=2;i<=n;i<<=1)    {        FFT(poly,ans,i,i/2,temp1);        FFT(ans,temp1+i/2,i/2,i/2,temp1);        for(int j=0;j<i/2;j++)ans[j+i/2]=temp1[j]==0?0:mod-temp1[j];    }}int f[maxfft],g[maxfft];int main(){    init();    int n;    while(~scanf("%d",&n))    {        int len=1;        while(len<=n)len<<=1;        g[0]=1;        for(int i=1;i<=n;i++)g[i]=(mod-(finv[i]<<1)%mod)%mod;        for(int i=n+1;i<len;i++)g[i]=0;        Poly_Inv(g,len,f);        int ans=0;        for(int i=0;i<=n;i++)        {            ans+=(ll)f[i]*fact[i]%mod;            if(ans>=mod)ans-=mod;        }        printf("%d\n",ans);    }    return 0;}
原创粉丝点击