【XSY2666】排列问题 DP 容斥原理 分治FFT

来源:互联网 发布:vb net从入门到精通 编辑:程序博客网 时间:2024/06/08 19:03

题目大意

  有n种颜色的球,第i种有ai个。设m=ai。你要把这m个小球排成一排。有q个询问,每次给你一个x,问你有多少种方案使得相邻的小球同色的对数为x

  n10000,m200000

题解

  我们考虑把这些小球分段,每段内所有小球颜色相同,但相邻两段的小球颜色可以相同。

  设第i种颜色有bi段,那么分j段的方案数是(bi)!(bi!)=j!(bi!)

  那么先DP,设fi,j为前i种颜色,分了j段的方案数÷bi!显然枚举第i中颜色分k段得

fi,j+=fi1,jk×(ai1k1)×1k!

  那个组合数是插板法得到的。

  这个DP的时间复杂度是O(m2)(因为枚举第i种颜色时j=1ai,k=1sisa的前缀和))

  然后这个东西可以分治FFT优化到O(mlogmlogn)

  这样我们得到了分成i段的方案数gi=fn,i×i!,但相邻两端可能颜色相同。我们还要减掉这种情况。

  答案ansi=gij<igj(mjij)

  可以简单暴力的通过分治FFT优化到O(mlog2m)。但我们有更好的做法。

  考虑容斥。其实总的gjansi的贡献就是(1)ij(mjij)。直接FFT一次就可以得到答案。

ansk>i=j=ki1(1)jk(mkjk)(mjij)=j=ki1(1)jk(mk)!(mj)!(jk)!(mj)!(ij)!(mi)!=j=ki1(1)jk(mk)!(jk)!(ij)!(mi)!=(mk)!(mi)!(ik)!j=ki1(1)jk(ik)!(ij)!(jk)!=(mkik)j=ki1(1)jk(ikjk)=(mkik)(1)ik

  那么相邻的小球同色的对数为x的答案就是ansmx

  时间复杂度:O(mlogmlogn+q)

代码

#include<cstdio>#include<cstring>#include<algorithm>#include<cstdlib>#include<ctime>#include<utility>#include<cmath>#include<functional>#include<vector>#include<queue>using namespace std;typedef long long ll;typedef unsigned long long ull;typedef pair<int,int> pii;typedef pair<ll,ll> pll;void sort(int &a,int &b){    if(a>b)        swap(a,b);}void open(const char *s){#ifndef ONLINE_JUDGE    char str[100];    sprintf(str,"%s.in",s);    freopen(str,"r",stdin);    sprintf(str,"%s.out",s);    freopen(str,"w",stdout);#endif}int rd(){    int s=0,c;    while((c=getchar())<'0'||c>'9');    do    {        s=s*10+c-'0';    }    while((c=getchar())>='0'&&c<='9');    return s;}void put(int x){    if(!x)    {        putchar('0');        return;    }    static int c[20];    int t=0;    while(x)    {        c[++t]=x%10;        x/=10;    }    while(t)        putchar(c[t--]+'0');}int upmin(int &a,int b){    if(b<a)    {        a=b;        return 1;    }    return 0;}int upmax(int &a,int b){    if(b>a)    {        a=b;        return 1;    }    return 0;}const int p=998244353;int fp(int a,int b){    int s=1;    for(;b;b>>=1,a=1ll*a*a%p)        if(b&1)            s=1ll*s*a%p;    return s;}int inv[600010];int fac[600010];int ifac[600010];namespace ntt{    const int g=3;    int rev[600010];    int w1[600010];    int w2[600010];    int n;    void init(int m)    {        n=1;        while(n<=m)            n<<=1;        int i;        rev[0]=0;        for(i=1;i<n;i++)            rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);        for(i=1;i<=n;i<<=1)        {            w1[i]=fp(g,(p-1)/i);            w2[i]=fp(w1[i],p-2);        }    }    void ntt(int *a,int t)    {        int i,j,k;        int u,v,w,wn;        for(i=0;i<n;i++)            if(rev[i]<i)                swap(a[i],a[rev[i]]);        for(i=2;i<=n;i<<=1)        {            wn=(t==1?w1[i]:w2[i]);            for(j=0;j<n;j+=i)            {                w=1;                for(k=j;k<j+i/2;k++)                {                    u=a[k];                    v=1ll*a[k+i/2]*w%p;                    a[k]=(u+v)%p;                    a[k+i/2]=(u-v)%p;                    w=1ll*w*wn%p;                }            }        }        if(t==-1)        {            int inv=fp(n,p-2);            for(i=0;i<n;i++)                a[i]=1ll*a[i]*inv%p;        }    }};int g[600010];int h[600010];int ans[600010];int a[600010];int s[600010];int n,m;void add(int &a,int b){    a=(a+b)%p;}typedef vector<int> vec;vec mul(vec &a,vec &b){    static int c[600010],d[600010];    int n1=a.size()-1;    int n2=b.size()-1;    int m=n1+n2+1;    ntt::init(m);    int i;    for(i=0;i<=n1;i++)        c[i]=a[i];    for(i=n1+1;i<ntt::n;i++)        c[i]=0;    for(i=0;i<=n2;i++)        d[i]=b[i];    for(i=n2+1;i<ntt::n;i++)        d[i]=0;    ntt::ntt(c,1);    ntt::ntt(d,1);    for(i=0;i<ntt::n;i++)        c[i]=1ll*c[i]*d[i]%p;    ntt::ntt(c,-1);    vec s(n1+n2+1);    for(i=1;i<=n1+n2;i++)        s[i]=c[i];    return s;}vec solve(int l,int r){    if(l==r)    {        vec s(a[l]+1);        int i;        for(i=1;i<=a[l];i++)            s[i]=1ll*ifac[i-1]*ifac[i]%p*ifac[a[l]-i]%p;        return s;    }    int mid=(l+r)>>1;    vec s1=solve(l,mid);    vec s2=solve(mid+1,r);    return mul(s1,s2);}int c[600010];int d[600010];priority_queue<pii,vector<pii>,greater<pii> > q;void gao(){    int i;    c[0]=0;    for(i=1;i<=m;i++)        c[i]=g[i];    for(i=0;i<=m;i++)    {        d[i]=ifac[i];        if(i&1)            d[i]=-d[i];    }    ntt::init(2*m);    for(i=m+1;i<ntt::n;i++)        c[i]=d[i]=0;    ntt::ntt(c,1);    ntt::ntt(d,1);    for(i=0;i<ntt::n;i++)        c[i]=1ll*c[i]*d[i]%p;    ntt::ntt(c,-1);    for(i=1;i<=m;i++)        g[i]=c[i];}int t=0;vec f[20010];int main(){    open("c");    scanf("%d",&n);    int i;    for(i=1;i<=n;i++)    {        scanf("%d",&a[i]);        s[i]=s[i-1]+a[i];    }    m=s[n];    inv[0]=inv[1]=fac[0]=fac[1]=ifac[0]=ifac[1]=1;    for(i=2;i<=m;i++)    {        inv[i]=-1ll*p/i*inv[p%i]%p;#ifndef ONLINE_JUDGE        inv[i]=(inv[i]+p)%p;#endif        fac[i]=1ll*fac[i-1]*i%p;        ifac[i]=1ll*ifac[i-1]*inv[i]%p;    }//  f[0][0]=1;    int times=1;    for(i=1;i<=n;i++)        times=1ll*times*fac[a[i]-1]%p;//  for(i=1;i<=n;i++)//  {//      times=times*fac[a[i]-1]%p;//      for(j=1;j<=s[i];j++)//      {//          for(k=1;k<=a[i]&&k<=j;k++)//              add(f[i][j],f[i-1][j-k]*ifac[k-1]%p*ifac[a[i]-k]%p*ifac[k]%p);////                add(f[i][j],f[i-1][j-k]*c(a[i]-1,k-1)%p*ifac[k]%p);////            f[i][j]=f[i][j]*fac[a[i]-1]%p;//      }//  }    int j;    for(i=1;i<=n;i++)    {        f[i].resize(a[i]+1);        for(j=1;j<=a[i];j++)            f[i][j]=1ll*ifac[j-1]*ifac[j]%p*ifac[a[i]-j]%p;        q.push(pii(a[i],i));    }    t=n;    for(i=1;i<n;i++)    {        int n1=q.top().first;        int x=q.top().second;        q.pop();        int n2=q.top().first;        int y=q.top().second;        q.pop();        f[++t]=mul(f[x],f[y]);        f[x].clear();        f[y].clear();        q.push(pii(n1+n2+1,t));    }    vec ss=f[t];//  vec ss=solve(1,n);    for(i=1;i<=m;i++)        g[i]=1ll*ss[i]*fac[i]%p*times%p;#ifndef ONLINE_JUDGE    for(i=1;i<=m;i++)        add(g[i],p);#endif//      g[i]=f[n][i]*fac[i]%p*times%p;      for(i=1;i<=m;i++)        g[i]=1ll*g[i]*fac[m-i]%p;    gao();    for(i=1;i<=m;i++)    {        g[i]=1ll*g[i]*ifac[m-i]%p;        add(g[i],p);    }//  for(i=1;i<=m;i++)//  {//      for(j=1;j<i;j++)//          add(ans[i],h[j]%p*ifac[i-j]%p);//      ans[i]=-ans[i]*ifac[m-i]%p;//      ans[i]=(ans[i]+g[i])%p;//          add(ans[i],-ans[j]*c(m-j,i-j));//      add(ans[i],p);//      h[i]=ans[i]*fac[m-i]%p;//  }    int q;    int x;    scanf("%d",&q);    while(q--)    {        scanf("%d",&x);        printf("%lld\n",g[m-x]);    }    return 0;}