【BZOJ4944】【NOI2017】泳池 概率DP 常系数线性递推 特征多项式 多项式取模

来源:互联网 发布:恋舞刷m币软件 编辑:程序博客网 时间:2024/05/01 00:00

题目大意

  有一个1001×n的的网格,每个格子有q的概率是安全的,1q的概率是危险的。

  定义一个矩形是合法的当且仅当:

  • 这个矩形中每个格子都是安全的
  • 必须紧贴网格的下边界

  问你最大的合法子矩形大小为k的概率是多少。

  n109,k1000

  吉老师:这题本来是k20000

题解

  一道好题。

  我们计算最大子矩形不超过i的答案si,那么答案就是sksk1

  显然最后一行连续的安全格子不会超过k个。

  设gi,j表示长度为j,高度为i的海域全部是安全的,剩下的部分未知,最大子矩形k的概率。

  设hi,j表示长度为j,高度为i+1的海域中,前i行全部是安全的,剩下的未知且(i+1,j)是危险的,最大子矩形k的概率。

  边界:

gk,1gi,0hi,0=qk(1q)=1=1

  那么我们从k11DP,对于ij列,枚举第i+1行的下一个危险的格子在哪个地方,然后转移:
gi,jhi,j=k=0jhi,kgi+1,jk=k=0j1hi,kgi+1,jk1qi(1q)

  因为第i行的宽度不会超过ki,所以的暴力的时间复杂度是ki=1ki2=O(k2)

  这已经足够了,但我们可以做的更好。

  设

Ai(x)Bi(x)ci=j0gi,jxj=j0hi,jxj=qi(1q)

那么
Ai(x)Bi(x)Bi(x)=Bi(x)Ai+1(x)=cixAi+1(x)Bi(x)+1=11cixAi+1(x)

  时间复杂度是ki=1kilogki=O(klog2k)

  设fi为前i列最大子矩形k的概率,那么

fi=j=1kfij1g1,j(1q)

  这就是一个常系数线性递推。
aifi=g1,i1(1q)=j=1kfijaj

  时间复杂度:

  • 暴力:O(nk)70pts
  • 矩阵快速幂:O(k3logn)90pts
  • 特征多项式+暴力:O(k2logn)100pts
  • 特征多项式+NTT取模:O(klogklogn)100pts

  这里简单讲一下最后一个做法

  矩阵快速幂是给你一个矩阵A,求(An)1,1

  设矩阵的大小为k

  根据Cayley-Hamilton定理,|λIA|是一个关于λk次多项式,记为g(λ)。对于任意矩阵A,有g(A)=0

  对于常系数线性递推的矩阵,设fi=kj=1fijajg(λ)=λkk1i=0ai+1λi

  所以我们只需要求Anmodg(A)。可以用快速幂(倍增取模)求解。

  然后还要求出f1fk,可以通过其他方法计算(多项式求逆或者题目给你了)。

  最后一次卷积可以得到答案。

  如果要求fnk+1fn,那就把f1f2k带进去卷积。

  总时间复杂度:O(klog2k+klogklogn)

代码

  暴力取模

#include<cstdio>#include<cstring>#include<algorithm>#include<cstdlib>#include<ctime>#include<utility>#include<cmath>#include<functional>using namespace std;typedef long long ll;typedef unsigned long long ull;typedef pair<int,int> pii;typedef pair<ll,ll> pll;void sort(int &a,int &b){    if(a>b)        swap(a,b);}void open(const char *s){#ifndef ONLINE_JUDGE    char str[100];    sprintf(str,"%s.in",s);    freopen(str,"r",stdin);    sprintf(str,"%s.out",s);    freopen(str,"w",stdout);#endif}int rd(){    int s=0,c;    while((c=getchar())<'0'||c>'9');    do    {        s=s*10+c-'0';    }    while((c=getchar())>='0'&&c<='9');    return s;}int upmin(int &a,int b){    if(b<a)    {        a=b;        return 1;    }    return 0;}int upmax(int &a,int b){    if(b>a)    {        a=b;        return 1;    }    return 0;}ll p=998244353;void add(ll &a,ll b){    a=(a+b)%p;}ll fp(ll a,ll b){    ll s=1;    for(;b;b>>=1,a=a*a%p)        if(b&1)            s=s*a%p;    return s;}ll inv(ll a){    return fp(a,p-2);}ll pw1[1010];ll pw2[1010];ll q;ll q2;ll g[1010][1010];ll h[1010][1010];ll f[2010];ll a[2010];ll c[2010];ll d[2010];ll final[2010];void mul(ll *a,ll *b,ll *e,int len){    static ll c[2010];    int i,j;    for(i=0;i<=2*len;i++)        c[i]=0;    for(i=0;i<=len;i++)        for(j=0;j<=len;j++)            add(c[i+j],a[i]*b[j]);    for(i=2*len;i>=len;i--)    {        ll v=c[i]*inv(e[len]);        if(v)            for(j=0;j<=len;j++)                c[i-len+j]=(c[i-len+j]-e[j]*v)%p;    }    for(i=0;i<=len;i++)        a[i]=c[i];}ll solve(int n,int k){    if(!k)        return fp(q2,n);    memset(g,0,sizeof g);    memset(h,0,sizeof h);    g[k][1]=q2*pw1[k]%p;    g[k][0]=1;    int i,j,l;    for(i=k-1;i>=1;i--)    {        int m=k/i;        g[i][0]=1;        h[i][0]=1;        for(j=0;j<=m;j++)        {            for(l=j+1;l<=m;l++)                add(h[i][l],h[i][j]*g[i+1][l-j-1]%p*q2%p*pw1[i]%p);            for(l=j;l<=m;l++)                if(l)                    add(g[i][l],h[i][j]*g[i+1][l-j]%p);        }    }    memset(f,0,sizeof f);    f[0]=1;    for(i=1;i<=2*(k+1);i++)        for(j=0;j<i&&j<=k;j++)            add(f[i],f[i-j-1]*q2%p*g[1][j]);    if(n<=2*(k+1))    {        ll s=0;        for(i=0;i<=n&&i<=k;i++)            add(s,f[n-i]*g[1][i]);        return s;    }    int len=k+1;    for(i=0;i<len;i++)        a[i]=-q2*g[1][len-i-1]%p;    a[len]=1;    memset(c,0,sizeof c);    c[1]=1;    memset(d,0,sizeof d);    d[0]=1;    int m=n-k-1;    while(m)    {        if(m&1)            mul(d,c,a,len);        mul(c,c,a,len);        m>>=1;    }    memset(final,0,sizeof final);    for(i=1;i<=k+1;i++)        for(j=0;j<=k;j++)            add(final[i],d[j]*f[i+j]);    ll s=0;    for(i=1;i<=k+1;i++)        add(s,final[i]*g[1][k+1-i]);    return s;}int main(){    open("bzoj4944");    int n,k,x,y;    scanf("%d%d%d%d",&n,&k,&x,&y);    q=x*inv(y)%p;    q2=(y-x)*inv(y)%p;    pw1[0]=pw2[0]=1;    int i;    for(i=1;i<=k;i++)    {        pw1[i]=pw1[i-1]*q%p;        pw2[i]=pw2[i-1]*q2%p;    }    ll ans1=solve(n,k);    ll ans2=solve(n,k-1);    ll ans=((ans1-ans2)%p+p)%p;    printf("%lld\n",ans);    return 0;}

  NTT取模

#include<cstdio>#include<cstring>#include<algorithm>#include<cstdlib>#include<ctime>#include<utility>#include<cmath>#include<functional>using namespace std;typedef long long ll;typedef unsigned long long ull;typedef pair<int,int> pii;typedef pair<ll,ll> pll;void sort(int &a,int &b){    if(a>b)        swap(a,b);}void open(const char *s){#ifndef ONLINE_JUDGE    char str[100];    sprintf(str,"%s.in",s);    freopen(str,"r",stdin);    sprintf(str,"%s.out",s);    freopen(str,"w",stdout);#endif}int rd(){    int s=0,c;    while((c=getchar())<'0'||c>'9');    do    {        s=s*10+c-'0';    }    while((c=getchar())>='0'&&c<='9');    return s;}int upmin(int &a,int b){    if(b<a)    {        a=b;        return 1;    }    return 0;}int upmax(int &a,int b){    if(b>a)    {        a=b;        return 1;    }    return 0;}const ll p=998244353;const int maxn=300000;ll fp(ll a,ll b){    ll s=1;    for(;b;b>>=1,a=a*a%p)        if(b&1)            s=s*a%p;    return s;}namespace ntt{    const ll g=3;    ll w1[maxn];    ll w2[maxn];    int rev[maxn];    int n;    void init(int m)    {        n=1;        while(n<m)            n<<=1;        int i;        for(i=2;i<=n;i<<=1)        {            w1[i]=fp(g,(p-1)/i);            w2[i]=fp(w1[i],p-2);        }        rev[0]=0;        for(i=1;i<n;i++)            rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));    }    void ntt(ll *a,int t)    {        int i,j,k;        ll u,v,w,wn;        for(i=0;i<n;i++)            if(rev[i]<i)                swap(a[i],a[rev[i]]);        for(i=2;i<=n;i<<=1)        {            wn=(t==1?w1[i]:w2[i]);            for(j=0;j<n;j+=i)            {                w=1;                for(k=j;k<j+i/2;k++)                {                    u=a[k];                    v=a[k+i/2]*w%p;                    a[k]=(u+v)%p;                    a[k+i/2]=(u-v)%p;                    w=w*wn%p;                }            }        }        if(t==-1)        {            u=fp(n,p-2);                for(i=0;i<n;i++)                a[i]=a[i]*u%p;        }    }    ll x[maxn];    ll y[maxn];    ll z[maxn];    void copy_clear(ll *a,ll *b,int m)    {        int i;        for(i=0;i<m;i++)            a[i]=b[i];        for(i=m;i<n;i++)            a[i]=0;    }    void copy(ll *a,ll *b,int m)    {        int i;        for(i=0;i<m;i++)            a[i]=b[i];    }    void mul(ll *a,ll *b,ll *c,int m)    {        init(m<<1);        copy_clear(x,a,m);        copy_clear(y,b,m);        ntt(x,1);        ntt(y,1);        int i;        for(i=0;i<n;i++)            x[i]=x[i]*y[i]%p;        ntt(x,-1);        copy(c,x,m);    }    void inverse(ll *a,ll *b,int m)    {        if(m==1)        {            b[0]=fp(a[0],p-2);            return;        }        inverse(a,b,m>>1);        init(m<<1);        copy_clear(x,a,m);        copy_clear(y,b,m>>1);        ntt(x,1);        ntt(y,1);        int i;        for(i=0;i<n;i++)            x[i]=y[i]*(2-x[i]*y[i]%p)%p;        ntt(x,-1);        copy(b,x,m);    }    ll c[maxn],d[maxn],e[maxn],f[maxn];    void sqrt(ll *a,ll *b,int m)    {        if(m==1)        {            if(a[0]==1)                b[0]=1;            else if(a[0]==0)                b[0]=0;            else                //我也不会                ;            return;        }        sqrt(a,b,m>>1);//      copy_clear(c,b,m>>1);        int i;        for(i=m;i<m<<1;i++)            b[i]=0;        inverse(b,d,m);        init(m<<1);        for(i=m;i<m<<1;i++)            b[i]=d[i]=0;        ll inv2=fp(2,p-2);        copy_clear(x,a,m);        ntt(x,1);        ntt(d,1);        for(i=0;i<n;i++)            x[i]=x[i]*d[i]%p;        ntt(x,-1);        for(i=0;i<m;i++)            b[i]=((b[i]+x[i])%p*inv2)%p;    }    void derivative(ll *a,ll *b,int m)    {        int i;        for(i=0;i<m-1;i++)            b[i]=(i+1)*a[i+1]%p;        b[m-1]=0;    }    void differential(ll *a,ll *b,int m)    {//      int i;//      for(i=m-1;i>=1;i--)//          b[i]=a[i-1]*inv[i]%p;        b[0]=0;    }    void ln(ll *a,ll *b,int m)    {        static ll c[maxn],d[maxn];        derivative(a,c,m);        inverse(a,d,m);        init(m<<1);        int i;        for(i=m;i<n;i++)            c[i]=d[i]=0;        ntt(c,1);        ntt(d,1);        for(i=0;i<n;i++)            c[i]=c[i]*d[i]%p;        ntt(c,-1);        differential(c,b,m);    }    void exp(ll *a,ll *b,int m)    {        if(m==1)        {            b[0]=1;            return;        }        exp(a,b,m>>1);        int i;        for(i=m>>1;i<m;i++)            b[i]=0;        ln(b,y,m);        init(m<<1);        copy_clear(x,a,m);        x[0]++;        for(i=0;i<m;i++)            x[i]=(x[i]-y[i])%p;        copy_clear(y,b,m);        ntt(x,1);        ntt(y,1);        for(i=0;i<n;i++)            x[i]=x[i]*y[i]%p;        ntt(x,-1);        copy(b,x,m);    }    void module(ll *a,ll *b,ll *c,int n1,int n2)    {        int k=1;        while(k<=n1-n2+1)            k<<=1;        int i;        for(i=0;i<=n1;i++)            d[i]=a[i];        for(i=0;i<=n2;i++)            e[i]=b[i];        reverse(d,d+n1+1);        reverse(e,e+n2+1);        for(i=n1-n2+1;i<k<<1;i++)            d[i]=e[i]=0;        inverse(e,f,k);        for(i=n1-n2+1;i<k<<1;i++)            f[i]=0;        init(k<<1);        ntt::ntt(d,1);        ntt::ntt(f,1);        for(i=0;i<n;i++)            e[i]=d[i]*f[i]%p;        ntt::ntt(e,-1);        for(i=0;i<=n1-n2;i++)            c[i]=e[i];        reverse(c,c+n1-n2+1);    }};void add(ll &a,ll b){    a=(a+b)%p;}ll inv(ll a){    return fp(a,p-2);}ll pw1[maxn];ll pw2[maxn];ll q;ll q2;ll f[maxn];ll a[maxn];ll c[maxn];ll d[maxn];ll final[maxn];ll g[2][maxn];ll h[maxn];ll e[maxn];void mul(ll *a,ll *b,ll *c,int n){    static ll d[maxn],e[maxn];    int k=1;    while(k<=n)        k<<=1;    ntt::init(k<<1);    int i;    for(i=0;i<k<<1;i++)        d[i]=e[i]=0;    for(i=0;i<=n;i++)    {        d[i]=a[i];        e[i]=b[i];    }    ntt::ntt(d,1);    ntt::ntt(e,1);    for(i=0;i<k<<1;i++)        d[i]=d[i]*e[i]%p;    ntt::ntt(d,-1);    //d=a*b    for(i=0;i<k<<1;i++)        e[i]=0;    int n2=(k<<1)-1;    while(!d[n2])        n2--;    ntt::module(d,c,e,n2,n);    for(i=0;i<n;i++)        a[i]=d[i];    for(i=0;i<k;i++)        d[i]=c[i];    for(i=k;i<k<<1;i++)        d[i]=0;    ntt::init(k<<1);    ntt::ntt(d,1);    ntt::ntt(e,1);    for(i=0;i<k<<1;i++)        d[i]=d[i]*e[i]%p;    ntt::ntt(d,-1);    for(i=0;i<n;i++)        a[i]=(a[i]-d[i])%p;}void powmod(ll *a,ll *b,ll *c,int m,int n){    if(!n)        return;    powmod(a,b,c,m,n>>1);    mul(a,a,c,m);    if(n&1)        mul(a,b,c,m);}ll solve(int n,int k){    memset(g,0,sizeof g);    memset(h,0,sizeof h);    int now=0;    g[now][1]=q2*pw1[k]%p;    g[now][0]=1;    h[0]=1;    int i,j;    for(i=k-1;i>=1;i--)    {        now^=1;        int m=k/i;        ll c=q2*pw1[i]%p;        int len=1;        while(len<=m)            len<<=1;        for(j=1;j<len;j++)            e[j]=-c*g[now^1][j-1];        e[0]=1;        ntt::inverse(e,h,len);        for(j=m+1;j<len<<1;j++)            h[j]=0;        ntt::init(len<<1);        ntt::ntt(g[now^1],1);        ntt::ntt(h,1);        for(j=0;j<len<<1;j++)            g[now][j]=g[now^1][j]*h[j]%p;        ntt::ntt(g[now],-1);        for(j=m+1;j<len<<1;j++)            g[now][j]=0;    }    memset(a,0,sizeof a);    for(i=0;i<=k;i++)        a[i+1]=-g[now][i]*q2%p;    a[0]=1;    int len=1;    while(len<=k+1)        len<<=1;    ntt::inverse(a,f,len<<1);    if(n<=2*(k+1))    {        ll s=0;        for(i=0;i<=n&&i<=k;i++)            add(s,f[n-i]*g[now][i]);        return s;    }    memset(a,0,sizeof a);    memset(c,0,sizeof c);    memset(d,0,sizeof d);    for(i=0;i<=k;i++)        a[i]=-g[now][k-i]*q2%p;    a[k+1]=1;    if(k)        c[1]=1;    else        c[0]=-a[0];    d[0]=1;    int m=n-k;    powmod(d,c,a,k+1,m);//  while(m)//  {//      if(m&1)//          mul(d,c,a,k+1);//      mul(c,c,a,k+1);//      m>>=1;////        for(i=0;i<=k;i++)////            printf("%lld ",(d[i]+p)%p);////        printf("\n");//  }    reverse(d,d+k+1);    ntt::init(len<<2);    ntt::ntt(d,1);    ntt::ntt(f,1);    for(i=0;i<len<<2;i++)        final[i]=d[i]*f[i]%p;    ntt::ntt(final,-1);    ll s=0;    for(i=0;i<=k;i++)        add(s,g[now][i]*final[2*k-i]);    return s;//  for(i=0;i<=k;i++)//      g[now][i]=(g[now][i]+p)%p;//  memset(f,0,sizeof f);//  f[0]=1;//  for(i=1;i<=2*(k+1);i++)//      for(j=0;j<i&&j<=k;j++)//          add(f[i],f[i-j-1]*q2%p*g[now][j]);//  if(n<=2*(k+1))//  {//      ll s=0;//      for(i=0;i<=n&&i<=k;i++)//          add(s,f[n-i]*g[now][i]);//      return s;//  }//  int len=k+1;//  for(i=0;i<len;i++)//      a[i]=-q2*g[now][len-i-1]%p;//  a[len]=1;//  memset(c,0,sizeof c);//  c[1]=1;//  memset(d,0,sizeof d);//  d[0]=1;//  int m=n-k-1;//  while(m)//  {//      if(m&1)//          mul(d,c,a,len);//      mul(c,c,a,len);//      m>>=1;//  }//  memset(final,0,sizeof final);//  for(i=1;i<=k+1;i++)//      for(j=0;j<=k;j++)//          add(final[i],d[j]*f[i+j]);//  ll s=0;//  for(i=1;i<=k+1;i++)//      add(s,final[i]*g[now][k+1-i]);//  return s;}int main(){    open("bzoj4944");    int n,k,x,y;    scanf("%d%d%d%d",&n,&k,&x,&y);    q=x*inv(y)%p;    q2=(y-x)*inv(y)%p;    pw1[0]=pw2[0]=1;    int i;    for(i=1;i<=k;i++)    {        pw1[i]=pw1[i-1]*q%p;        pw2[i]=pw2[i-1]*q2%p;    }    ll ans1=solve(n,k);    ll ans2=solve(n,k-1);    ll ans=((ans1-ans2)%p+p)%p;    printf("%lld\n",ans);    return 0;}
阅读全文
0 0