[BZOJ4818][SDOI2017]序列计数(DP+容斥原理+矩乘)

来源:互联网 发布:全看网软件下载 编辑:程序博客网 时间:2024/05/29 12:29

1、大意做法

首先发现,DP的统计方案数中,「至少一个质数」的条件不容易直接DP,但是「不含任意一个质数」的条件容易DP。所以这里使用容斥原理,用总方案数减去不含任意一个质数的方案数。

2、预处理

cnt0,cnt1,cnt2,...cntp1分别为1...m的数中模p0,1,2,...p1的数的个数。
pri0,pri1,pri2,...prip1分别为1...m的数中的合数p0,1,2,...p1的数的个数。

3、建立朴素DP

f[i][j]表示i个数的序列,总和模pj的方案总数,
g[i][j]i个数的序列,总和模pj且不含质数的方案总数。
可以推出DP方程:
f[i][j]=p1k=0(f[i1][(jk+p)modp]cntk)
g[i][j]=p1k=0(g[i1][(jk+p)modp]prik)
边界为f[1][j]=cntjg[1][j]=prij
最后结果为f[n][0]g[n][0]

4、矩阵优化

可以发现,上面的朴素DP无法承受n<=109的数据范围。
这里考虑构建pp的矩阵PQp1的矩阵VW
矩阵P的内容为:

cnt0cnt1cnt2cnt3...cntp1cntp1cnt0cnt1cnt2...cntp2cntp2cntp1cnt0cnt1...cntp3cntp3cntp2cntp1cnt0...cntp4cntp4cntp3cntp2cntp1...cntp5..................cnt1cnt2cnt3cnt4...cnt0

矩阵Q的内容为:
pri0pri1pri2pri3...prip1prip1pri0pri1pri2...prip2prip2prip1pri0pri1...prip3prip3prip2prip1pri0...prip4prip4prip3prip2prip1...prip5..................pri1pri2pri3pri4...pri0

矩阵V的内容为:
cnt0cnt1cnt2...cntp1

矩阵W的内容为:
pri0pri1pri2...prip1

然后进行矩阵乘方:
F=Pn1V
G=Qn1W
最后结果即为F[1][1]G[1][1]
以上操作注意取模。
代码:

#include <cmath>#include <cstdio>#include <cstring>#include <iostream>#include <algorithm>using namespace std;const int N = 105, CYX = 20170408, M = 2e7 + 5;struct cyx {    int m, n, v[N][N];    cyx() {}    cyx(int _m, int _n) :        m(_m), n(_n) {memset(v, 0, sizeof(v));}    friend inline cyx operator * (cyx a, cyx b) {        cyx res = cyx(a.m, b.n);        int i, j, k;        for (i = 1; i <= res.m; i++)            for (j = 1; j <= res.n; j++)                for (k = 1; k <= a.n; k++)                    res.v[i][j] = (res.v[i][j] + 1ll * a.v[i][k]                        * b.v[k][j] % CYX) % CYX;        return res;    }    friend inline cyx operator ^ (cyx a, int b) {        cyx c = a, res = cyx(a.m, a.n);        int i, j; for (i = 1; i <= res.m; i++)            res.v[i][i] = 1;        while (b) {            if (b & 1) res = res * c;            c = c * c;            b >>= 1;        }        return res;    }} P, Q, V, W;int n0, m0, p0, cnt[N], cnt0[N];bool prime[M];inline int read() {    int res = 0; bool bo = 0; char c;    while (((c = getchar()) < '0' || c > '9') && c != '-');    if (c == '-') bo = 1; else res = c - 48;    while ((c = getchar()) >= '0' && c <= '9')        res = (res << 3) + (res << 1) + (c - 48);    return bo ? ~res + 1 : res;}int main() {    int i, j; n0 = read(); m0 = read(); p0 = read();    for (i = 0; i < p0; i++) cnt[i] = m0 / p0;    for (i = 1; i <= m0 % p0; i++) cnt[i]++;    P = cyx(p0, p0); P.v[1][1] = cnt[0];    for (i = 2; i <= p0; i++) P.v[1][i] = cnt[p0 - i + 1];    for (i = 2; i <= p0; i++) {        for (j = 2; j <= p0; j++)            P.v[i][j] = P.v[i - 1][j - 1];        P.v[i][1] = P.v[i - 1][p0];    }    memset(prime, true, sizeof(prime));    prime[0] = prime[1] = 0;    for (i = 2; i <= m0; i++) {        if (!prime[i]) continue;        if (i * i > m0) break;        for (j = i * i; j <= m0; j += i)            prime[j] = 0;    }    for (i = 1; i <= m0; i++)        if (!prime[i]) cnt0[i % p0]++;    Q = cyx(p0, p0); Q.v[1][1] = cnt0[0];    for (i = 2; i <= p0; i++) Q.v[1][i] = cnt0[p0 - i + 1];    for (i = 2; i <= p0; i++) {        for (j = 2; j <= p0; j++)            Q.v[i][j] = Q.v[i - 1][j - 1];        Q.v[i][1] = Q.v[i - 1][p0];    }    for (i = 1; i <= p0; i++) for (j = 1; j <= p0; j++)        P.v[i][j] %= CYX, Q.v[i][j] %= CYX;    V = cyx(p0, 1); W = cyx(p0, 1);    for (i = 1; i <= p0; i++) V.v[i][1] = cnt[i - 1] % CYX;    for (i = 1; i <= p0; i++) W.v[i][1] = cnt0[i - 1] % CYX;    P = (P ^ n0 - 1) * V; Q = (Q ^ n0 - 1) * W;    printf("%d\n", (P.v[1][1] - Q.v[1][1] + CYX) % CYX);    return 0;}
原创粉丝点击