bzoj4818 [Sdoi2017]序列计数

来源:互联网 发布:丽江知乎 编辑:程序博客网 时间:2024/06/08 12:32

传送门

矩阵优化dp、容斥原理。
先写出dp转移方程:f[i&1][(j+k)%p][1]=f[(i&1)^1][j][0]+f[(i&1)^1][j][1] (这是我当时在考场上写的20分暴力)
上面的方程对有无素数进行了分类,其实可以进行容斥,用所有方案数减去没有素数的方案数。又因为题目要求序列之和为p的倍数,所以上面的方程对序列之和模p的余数进行了存储。而余数让我们想到了什么呢?矩阵!矩阵可以对不同的余数进行记录,并且可以很好地处理转移。题目中p的最大值为100,所以我们可以开100*100的矩阵,矩阵中存储余数为某一定值的时候的方案数。(这种方法是我自己yy的,所以我还是举个例子比较好)

以p=4为例:
这里写图片描述
上面左边的矩阵是转移矩阵,中间的是初始矩阵,右边的是结果矩阵,转移矩阵中写的数字是序列之和对p取模后的余数,但事实上转移矩阵中存储的是序列之和为该余数时的方案数,初始矩阵和结果矩阵中的f[i]表示序列之和模p后余数为i的方案数
我们可以看出,通过将初始矩阵不断地用转移矩阵进行矩阵乘法,就可以得到将序列增长后可行的方案数。所以我们对转移矩阵进行n-1次乘法,然后乘上初始矩阵就可以得出答案。这样求出所有方案数和没有素数的方案数就好了。

CODE:

#include<cstdio>#include<cstring>#define mod 20170408#define N 20000005int prime[1500000];bool b[N];int n,m,p,tot,ans;struct Matrix{    int a[105][105];    Matrix(){memset(a,0,sizeof(a));}    inline Matrix operator *(const Matrix &x)const    {        Matrix ans;        for(int i=0;i<p;i++)          for(int j=0;j<p;j++)            if(a[i][j]) for(int k=0;k<p;k++)              ans.a[i][k]=(1ll*ans.a[i][k]+1ll*a[i][j]*x.a[j][k])%mod;        return ans;    }}m1,m2,tmp1,tmp2;inline void euler(){    for(int i=2;i<=m;i++)    {        if(!b[i]) prime[++tot]=i;        for(int j=1;j<=tot&&i*prime[j]<=m;j++)        {            b[i*prime[j]]=1;            if(i%prime[j]==0) break;        }    }}inline void Matrix_init(){    int num=m/p,rest=m%p;    for(int i=0;i<p;i++)      m1.a[0][i]=num;    for(int i=p-1;rest;i--,rest--)      m1.a[0][i]++;    for(int i=1;i<p;i++)    {        m1.a[i][0]=m1.a[i-1][p-1];        for(int j=1;j<p;j++)          m1.a[i][j]=m1.a[i-1][j-1];    }    m2=m1;    for(int i=1;i<=tot;i++)    {        int rest=prime[i]%p;        int pos=p-rest;        if(!rest) pos=0;        for(int j=0;j<p;j++)        {            m2.a[j][pos]--;            pos++;            if(pos==p) pos=0;        }    }    tmp1.a[0][0]=m1.a[0][0];    tmp2.a[0][0]=m2.a[0][0];    for(int i=1,pos=p-1;i<p;i++,pos--)      tmp1.a[i][0]=m1.a[0][pos],      tmp2.a[i][0]=m2.a[0][pos];}inline Matrix ksm(Matrix a,int b){    Matrix ans=a;    for(b--;b;b>>=1,a=a*a)      if(b&1) ans=ans*a;    return ans;}int main(){    scanf("%d%d%d",&n,&m,&p);    euler(),Matrix_init();    m1=ksm(m1,n-1),m2=ksm(m2,n-1);    m1=m1*tmp1,m2=m2*tmp2;    ans=m1.a[0][0]-m2.a[0][0];    if(ans<0) ans+=mod;    printf("%d",ans);    return 0;}